Add trtrt.enqueue_alloc op
jhalakpatel committed Oct 11, 2024
1 parent 41142c4 commit 36c143c
Showing 6 changed files with 277 additions and 22 deletions.
@@ -116,4 +116,57 @@ def TensorRTRuntime_EnqueueOp : TensorRTRuntime_Op<"enqueue", [
}];
}

//===----------------------------------------------------------------------===//
// EnqueueAllocOp
//===----------------------------------------------------------------------===//

def Output_Desc : TypeDef<CUDA_Dialect, "Output_Desc", []> {
  let mnemonic = "Output_Desc";
  let description = [{
    An opaque object which represents an output descriptor used by the
    TensorRT runtime to return information about result buffers that are
    allocated during enqueue.
  }];
}

def TensorRTRuntime_EnqueueAllocOp : TensorRTRuntime_Op<"enqueue_alloc", [
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
  ]> {
let description = [{
Asynchronously executes the computation represented by the
`execution_context` on the specified CUDA stream. Unlike `enqueue`, which
writes into caller-provided `outs` buffers, this op allocates its result
buffers. Inputs may be of either tensor or memref types, and results are
returned as either tensor or memref types.
}];

let arguments = (ins
TensorRTRuntime_Context:$execution_context,
CUDA_Stream:$stream,
Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs,
OptionalAttr<DenseI64ArrayAttr>:$host_tensor_args
);

let results = (outs Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$results);

let assemblyFormat = [{
$execution_context `stream` `(` $stream `)` ` `
(`host_tensor_args` $host_tensor_args^ ` ` )?
`(` $inputs `)`
attr-dict `:` functional-type($inputs, $results)
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Return true if the operand at the specified index is a host tensor
/// argument. The index is into the full operand list, where operand 0 is
/// the execution context and operand 1 is the stream, so it is offset by
/// two from the `host_tensor_args` indices, which are relative to
/// `inputs`.
bool isOperandOnHost(int64_t operandIdx) {
if (std::optional<ArrayRef<int64_t>> indices = getHostTensorArgs()) {
return llvm::is_contained(*indices, operandIdx - 2);
}
return false;
}
}];
}

#endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTRUNTIMEOPS_TD
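
As a quick illustration of the new op's generated C++ builder, here is a minimal, hypothetical sketch; `builder`, `loc`, `ctx`, `stream`, and the operand/type values are assumed to be in scope, and the call mirrors the builder invocation in the bufferization code below:

// Build a `trtrt.enqueue_alloc` op with one result, marking input #0 as a
// host tensor (`host_tensor_args` indices are relative to `inputs`).
SmallVector<Type> resultTypes = {resultTensorType};
auto allocOp = builder.create<trtrt::EnqueueAllocOp>(
    loc, TypeRange(resultTypes),
    /*execution_context=*/ctx, /*stream=*/stream,
    /*inputs=*/ValueRange{shapeOperand, dataOperand},
    /*host_tensor_args=*/builder.getDenseI64ArrayAttr({0}));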
@@ -43,7 +43,7 @@ LogicalResult EnqueueOp::inferReturnTypes(
SmallVectorImpl<Type> &inferredReturnTypes) {
EnqueueOp::Adaptor adaptor(operands, attributes, properties, regions);

-// If the `outs` operands are tensor types, then we shoudl return those as
+// If the `outs` operands are tensor types, then we should return those as
// results. Otherwise, for memref outs, we do not return results.
for (Type t : TypeRange(adaptor.getOuts())) {
auto tensorType = dyn_cast<TensorType>(t);
@@ -113,6 +113,62 @@ void EnqueueOp::getEffects(
}
}

//===----------------------------------------------------------------------===//
// EnqueueAllocOp
//===----------------------------------------------------------------------===//

LogicalResult EnqueueAllocOp::verify() {
// Verify host tensor indices.
if (std::optional<ArrayRef<int64_t>> hostTensorIndices =
getHostTensorArgs()) {
// We don't count the context and stream arguments here.
const int64_t numInputArgs = getInputs().size();
for (int64_t idx : *hostTensorIndices) {
if (idx >= numInputArgs || idx < 0)
return emitOpError("host_tensor_args value ")
<< idx << " is out of bounds";
Value operand = getInputs()[idx];
Type elType = mlir::getElementTypeOrSelf(operand.getType());
if (!elType.isInteger(32))
return emitOpError("host tensor arguments must have element type i32, "
"but input arg ")
<< idx << " has type " << operand.getType();
}
}

// Verify that the results are either all tensors or all memrefs.
if (getNumResults() > 0) {
bool allTensors = isa<TensorType>(getResult(0).getType());
for (auto result : getResults()) {
if (isa<TensorType>(result.getType()) != allTensors) {
return emitOpError("all results must be of the same type (all tensors "
"or all memrefs)");
}
}
}
return success();
}

void EnqueueAllocOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
// This op allocates memory for its results
effects.emplace_back(MemoryEffects::Allocate::get(), 0,
/*effectOnFullRegion=*/true);

for (OpOperand &operand : getInputsMutable()) {
if (!llvm::isa<MemRefType>(operand.get().getType()))
continue;
effects.emplace_back(MemoryEffects::Read::get(), &operand,
SideEffects::DefaultResource::get());
}
for (OpResult result : getResults()) {
effects.emplace_back(MemoryEffects::Write::get(), result,
SideEffects::DefaultResource::get());
}
}
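
Because the op reports an `Allocate` effect that is not tied to any particular value, generic analyses such as dead-code elimination will conservatively treat the op as side-effecting even when its results are unused. A minimal sketch of how a client might inspect these effects through the standard `MemoryEffectOpInterface` (the `op` pointer is a hypothetical `Operation *`):

if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
  // Collect every effect the op reports.
  SmallVector<MemoryEffects::EffectInstance> effects;
  memOp.getEffects(effects);
  // True for trtrt.enqueue_alloc because of the Allocate effect above.
  bool allocates = llvm::any_of(effects, [](const auto &effect) {
    return isa<MemoryEffects::Allocate>(effect.getEffect());
  });
  (void)allocates;
}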


//===----------------------------------------------------------------------===//
// TensorRTRuntimeDialect Interfaces
//===----------------------------------------------------------------------===//
@@ -41,6 +41,13 @@ static bool isInMemorySpace(Type memrefType, plan::MemorySpace memType) {
return space == plan::MemorySpaceAttr::get(type.getContext(), memType);
}

/// Return true if `input` is a scalar (neither a tensor nor a memref).
static bool isScalar(Value input) {
  return !isa<TensorType, MemRefType>(input.getType());
}

namespace {
struct EnqueueOpInterface
: public bufferization::DstBufferizableOpInterfaceExternalModel<
@@ -49,16 +56,16 @@ struct EnqueueOpInterface
/// outputs in our use-case.
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const bufferization::AnalysisState &state) const {
-EnqueueOp callOp = cast<EnqueueOp>(op);
-return callOp.isDpsInput(&opOperand);
+EnqueueOp enqueueOp = cast<EnqueueOp>(op);
+return enqueueOp.isDpsInput(&opOperand);
}

/// Only dps inits are written.
bool
bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const bufferization::AnalysisState &state) const {
-EnqueueOp callOp = cast<EnqueueOp>(op);
-return callOp.isDpsInit(&opOperand);
+EnqueueOp enqueueOp = cast<EnqueueOp>(op);
+return enqueueOp.isDpsInit(&opOperand);
}

// TensorRT will guarantee that the input will be read before the result
Expand All @@ -72,21 +79,21 @@ struct EnqueueOpInterface
/// Bufferize the `trtrt.enqueue` operation.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
-EnqueueOp callOp = cast<EnqueueOp>(op);
+EnqueueOp enqueueOp = cast<EnqueueOp>(op);
MLIRContext *ctx = op->getContext();
Location loc = op->getLoc();
-rewriter.setInsertionPoint(callOp);
+rewriter.setInsertionPoint(enqueueOp);

// For the inputs, check the memory space and insert a copy if it is not in
// the correct space.
SmallVector<Value> newInputBuffers;
-newInputBuffers.reserve(callOp.getNumDpsInputs());
+newInputBuffers.reserve(enqueueOp.getNumDpsInputs());
for (auto [idx, opOperand] :
-llvm::enumerate(callOp.getDpsInputOperands())) {
+llvm::enumerate(enqueueOp.getDpsInputOperands())) {

// The context and stream operands are considered "DPS inputs" and
// therefore they'll be skipped here.
-if (callOp.isScalar(opOperand)) {
+if (enqueueOp.isScalar(opOperand)) {
newInputBuffers.push_back(opOperand->get());
continue;
}
@@ -99,7 +106,7 @@
// Check if this input is a host tensor. Insert a copy if required. Note
// that we subtract two from the index to account for context/stream
// arguments.
-if (callOp.isOperandOnHost(idx) &&
+if (enqueueOp.isOperandOnHost(idx) &&
!isInMemorySpace(memRefType, plan::MemorySpace::host_pinned)) {
FailureOr<Value> pinnedAlloc = options.createAlloc(
rewriter, op->getLoc(),
@@ -117,7 +124,7 @@
}

// If we are in host space, then copy to the device.
-if (!callOp.isOperandOnHost(idx) &&
+if (!enqueueOp.isOperandOnHost(idx) &&
!isInMemorySpace(memRefType, plan::MemorySpace::device)) {
FailureOr<Value> devAlloc = options.createAlloc(
rewriter, op->getLoc(),
@@ -139,10 +146,10 @@
}

SmallVector<Value> newOutputBuffers;
-newOutputBuffers.reserve(callOp.getNumDpsInits());
+newOutputBuffers.reserve(enqueueOp.getNumDpsInits());
for (OpResult opResult : op->getOpResults()) {
OpOperand *opOperand =
-callOp.getDpsInitOperand(opResult.getResultNumber());
+enqueueOp.getDpsInitOperand(opResult.getResultNumber());
FailureOr<Value> resultBuffer =
getBuffer(rewriter, opOperand->get(), options);
if (failed(resultBuffer))
@@ -153,18 +160,144 @@
rewriter.create<EnqueueOp>(
op->getLoc(), newInputBuffers[0], newInputBuffers[1],
ValueRange(newInputBuffers).drop_front(2), newOutputBuffers,
-callOp.getHostTensorArgsAttr());
+enqueueOp.getHostTensorArgsAttr());
replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
return success();
}
};


struct EnqueueAllocOpInterface
: public bufferization::BufferizableOpInterface::ExternalModel<
EnqueueAllocOpInterface, EnqueueAllocOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const bufferization::AnalysisState &state) const {
// Only the variadic `inputs` are read; the context and stream operands
// are not memory operands.
auto enqueueAllocOp = cast<EnqueueAllocOp>(op);
return llvm::is_contained(enqueueAllocOp.getInputs(), opOperand.get());
}

bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const bufferization::AnalysisState &state) const {
return false; // This op doesn't write to its inputs
}

bufferization::AliasingValueList
getAliasingValues(Operation *op, OpOperand &opOperand,
const bufferization::AnalysisState &state) const {
// Results are freshly allocated, so they never alias the inputs.
return {};
}

bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
const bufferization::AnalysisState &state) const {
// EnqueueAllocOp creates new outputs, doesn't modify inputs in-place
return false;
}

SmallVector<OpResult>
getAliasingOpResult(Operation *op, OpOperand &opOperand,
const bufferization::AnalysisState &state) const {
// This op doesn't alias its inputs to its outputs.
return {};
}

bool bufferizesToElementwiseAccess(Operation *op,
const bufferization::AnalysisState &state,
ArrayRef<OpOperand *> opOperands) const {
// The TensorRT engine's access pattern is opaque, so conservatively do
// not claim elementwise access.
return false;
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto enqueueAllocOp = cast<EnqueueAllocOp>(op);
MLIRContext *ctx = op->getContext();
Location loc = op->getLoc();
rewriter.setInsertionPoint(enqueueAllocOp);

// Handle inputs
SmallVector<Value> newInputBuffers;
newInputBuffers.reserve(enqueueAllocOp.getInputs().size());
for (auto [idx, input] : llvm::enumerate(enqueueAllocOp.getInputs())) {
if (isScalar(input)) {
newInputBuffers.push_back(input);
continue;
}
FailureOr<Value> buffer = getBuffer(rewriter, input, options);
if (failed(buffer))
return failure();

MemRefType memRefType = cast<MemRefType>(buffer->getType());

// Handle host tensor inputs. `isOperandOnHost` expects an index into the
// full operand list, so shift past the context and stream operands.
if (enqueueAllocOp.isOperandOnHost(idx + 2) &&
!isInMemorySpace(memRefType, plan::MemorySpace::host_pinned)) {
FailureOr<Value> pinnedAlloc = options.createAlloc(
rewriter, loc,
MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
memRefType.getLayout(),
plan::MemorySpaceAttr::get(
ctx, plan::MemorySpace::host_pinned)),
ValueRange{});
if (failed(pinnedAlloc))
return failure();
if (failed(options.createMemCpy(rewriter, loc, *buffer, *pinnedAlloc)))
return failure();
newInputBuffers.push_back(*pinnedAlloc);
continue;
}

// Handle device tensor inputs (again shifting the index past the context
// and stream operands).
if (!enqueueAllocOp.isOperandOnHost(idx + 2) &&
!isInMemorySpace(memRefType, plan::MemorySpace::device)) {
FailureOr<Value> devAlloc = options.createAlloc(
rewriter, loc,
MemRefType::get(
memRefType.getShape(), memRefType.getElementType(),
memRefType.getLayout(),
plan::MemorySpaceAttr::get(ctx, plan::MemorySpace::device)),
ValueRange{});
if (failed(devAlloc))
return failure();
if (failed(options.createMemCpy(rewriter, loc, *buffer, *devAlloc)))
return failure();
newInputBuffers.push_back(*devAlloc);
continue;
}

newInputBuffers.push_back(*buffer);
}

// Compute memref result types: tensor results become device memrefs, and
// memref results are kept as-is.
SmallVector<Type> outputBufferTypes;
outputBufferTypes.reserve(enqueueAllocOp.getNumResults());
for (Type resultType : enqueueAllocOp->getResultTypes()) {
MemRefType memRefType;
if (auto tensorType = dyn_cast<TensorType>(resultType)) {
memRefType = MemRefType::get(
tensorType.getShape(), tensorType.getElementType(),
MemRefLayoutAttrInterface(),
plan::MemorySpaceAttr::get(ctx, plan::MemorySpace::device));
} else {
memRefType = cast<MemRefType>(resultType);
}
outputBufferTypes.push_back(memRefType);
}

// Create the bufferized op and replace the original op's results with the
// new op's memref results.
auto newOp = rewriter.create<EnqueueAllocOp>(
loc, TypeRange(outputBufferTypes), enqueueAllocOp.getExecutionContext(),
enqueueAllocOp.getStream(), newInputBuffers,
enqueueAllocOp.getHostTensorArgsAttr());
replaceOpWithBufferizedValues(rewriter, op, newOp.getResults());

return success();
}
};

} // namespace

void trtrt::registerBufferizableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(
+[](MLIRContext *ctx, trtrt::TensorRTRuntimeDialect *dialect) {
trtrt::EnqueueOp::attachInterface<EnqueueOpInterface>(*ctx);
trtrt::EnqueueAllocOp::attachInterface<EnqueueAllocOpInterface>(*ctx);
});
}
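
For context, a minimal sketch of wiring these external models into a one-shot bufferization run. Only `registerBufferizableOpInterfaceExternalModels` and the `trtrt` dialect come from this change; `module` and the option settings are assumptions:

// Hypothetical driver: `module` is a parsed ModuleOp.
DialectRegistry registry;
registry.insert<trtrt::TensorRTRuntimeDialect>();
trtrt::registerBufferizableOpInterfaceExternalModels(registry);
MLIRContext context(registry);

bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true; // assumption: other dialects remain in the IR
if (failed(bufferization::runOneShotBufferize(module, options)))
  llvm::report_fatal_error("bufferization failed");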
@@ -427,7 +427,7 @@ static Status enqueueV3Wrapper(AllocTracker &tracker,
}

static Status
-allocEnqueueV3Wrapper(AllocTracker &tracker, ResourceTracker &resourceTracker,
+enqueueAllocV3Wrapper(AllocTracker &tracker, ResourceTracker &resourceTracker,
OutputAllocatorTracker &outputAllocatorTracker,
NvInferExecContextWrapper &context, CudaStreamPtr stream,
sol::table &va, OutputDescriptor outputDesc) {
@@ -568,16 +568,16 @@ void mlirtrt::runtime::registerExecutorTensorRTModuleLuaRuntimeMethods(
SET_LUA_ERROR_IF_ERROR(result, state);
};

-lua["_trtrt_alloc_enqueue"] =
+lua["_trtrt_enqueue_alloc"] =
[allocTracker, resourceTracker, outputAllocatorTracker](
sol::this_state state,
std::shared_ptr<NvInferExecContextWrapper> context,
CudaStreamPtr stream, uintptr_t outputDesc, sol::table va) {
-ADD_TENSORRT_MODULE_RANGE("trtrt_alloc_enqueue");
+ADD_TENSORRT_MODULE_RANGE("trtrt_enqueue_alloc");
sol::state_view luaState(state);
assert(context != nullptr);
assert(stream != nullptr && "expected valid stream");
-Status result = allocEnqueueV3Wrapper(*allocTracker, *resourceTracker,
+Status result = enqueueAllocV3Wrapper(*allocTracker, *resourceTracker,
*outputAllocatorTracker, *context,
stream, va, outputDesc);
SET_LUA_ERROR_IF_ERROR(result, state);