diff --git a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp index 3b8308449..45357387c 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp @@ -223,6 +223,10 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions( disallowHostTensorsInTensorRTClusters, llvm::cl::init(false), llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor " "calculations (but they can still be inputs)")); + addOption( + "enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false), + llvm::cl::desc( + "allow tensorrt based output allocations using output allocator")); addOption("executor-index-bitwidth", executorIndexBitwidth, llvm::cl::init(64)); addOption("device-compute-capability", deviceComputeCapability, @@ -307,6 +311,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline( plan::StablehloClusteringPassOptions clusteringOpts{}; clusteringOpts.disallowHostTensorsInTensorRTClusters = opts.disallowHostTensorsInTensorRTClusters; + clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns; clusteringOpts.entrypoint = opts.entrypoint; plan::buildPlanSegmentationPipeline(pm, clusteringOpts); @@ -340,7 +345,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline( // Perform bufferization. pm.addPass(createMemRefCastEliminationPass()); - pm.addPass(plan::createPlanAllocTensorsPass()); + plan::PlanAllocTensorsPassOptions allocTensorsOpts{}; + allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns; + pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts)); pm.addPass(plan::createPlanBufferizePass()); pm.addPass(createMemRefCastEliminationPass()); pm.addPass(createCanonicalizerPass()); @@ -529,6 +536,11 @@ struct ClusteringPipelineCliOpts *this, "device-compute-capability", llvm::cl::desc("target device compute capability (SM version)"), llvm::cl::init(60)}; + Option enableNonDPSReturns{ + *this, "enable-non-dps-returns", + llvm::cl::desc( + "allow tensorrt based output allocations using output allocator"), + llvm::cl::init(false)}; Option deviceMaxSharedMemoryPerBlockKb{ *this, "device-max-smem-per-block", llvm::cl::desc("max shared memory per block (in kilobytes)"), @@ -556,6 +568,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts( opts.deviceComputeCapability = cliOpts.deviceComputeCapability; opts.deviceMaxSharedMemoryPerBlockKb = cliOpts.deviceMaxSharedMemoryPerBlockKb; + opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns; opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost; opts.entrypoint = cliOpts.entrypoint; return opts; diff --git a/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp b/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp index 2a585a43d..8c10b35d4 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp +++ b/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp @@ -379,20 +379,22 @@ struct ConvertEnqueueAllocToCall // Create output memrefs from output descriptors SmallVector results; + // Initialize output descriptor offset to skip number of results. + // `outputDescOffset` is used to retrieve rank, ptr, shapes, and strides per + // result. 
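For reference, the output descriptor this lowering indexes is a flat array of i64 slots: slot 0 holds the number of results, and each result then occupies 2 + 2 * rank slots (rank, device pointer, then shapes and strides). A minimal host-side sketch of walking that layout (the struct and function names here are illustrative, not part of the runtime API):

```cpp
#include <cstdint>
#include <vector>

// Illustrative decoded view of one result in an enqueue_alloc output
// descriptor laid out as [numResults, (rank, ptr, shape..., stride...), ...].
struct DecodedResult {
  int64_t rank;
  uintptr_t devicePtr;
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
};

std::vector<DecodedResult> decodeOutputDescriptor(const int64_t *desc) {
  std::vector<DecodedResult> results;
  const int64_t numResults = desc[0];
  int64_t offset = 1; // Skip the leading result count, as the pass does.
  for (int64_t i = 0; i < numResults; ++i) {
    DecodedResult r;
    r.rank = desc[offset++];
    r.devicePtr = static_cast<uintptr_t>(desc[offset++]);
    for (int64_t d = 0; d < r.rank; ++d)
      r.shape.push_back(desc[offset++]);
    for (int64_t d = 0; d < r.rank; ++d)
      r.strides.push_back(desc[offset++]);
    results.push_back(std::move(r));
  }
  return results;
}
```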
+ outputDescOffset = 1; for (unsigned i = 0; i < op->getNumResults(); ++i) { unsigned rank = cast(op->getResult(i).getType()).getRank(); - unsigned offset = - 1 + - i * (2 * rank + 2); // num res, (i * (rank, ptr, [shape], [stride])) - Value rankOffset = b.create( b.getI64Type(), structType, - ArrayRef{this->createIndexConstant(b, 0), - rewriter.getI64IntegerAttr(offset++)}); + ArrayRef{ + this->createIndexConstant(b, 0), + rewriter.getI64IntegerAttr(outputDescOffset++)}); Value devicePtrOffset = b.create( b.getI64Type(), structType, - ArrayRef{this->createIndexConstant(b, 0), - rewriter.getI64IntegerAttr(offset++)}); + ArrayRef{ + this->createIndexConstant(b, 0), + rewriter.getI64IntegerAttr(outputDescOffset++)}); [[maybe_unused]] Value rankValue = b.create( b.getI64Type(), outputDescriptors, rankOffset); @@ -406,8 +408,9 @@ struct ConvertEnqueueAllocToCall for (unsigned r = 0; r < rank; ++r) { Value shapeOffset = b.create( b.getI64Type(), structType, - ArrayRef{this->createIndexConstant(b, 0), - rewriter.getI64IntegerAttr(offset++)}); + ArrayRef{ + this->createIndexConstant(b, 0), + rewriter.getI64IntegerAttr(outputDescOffset++)}); Value shape = b.create( b.getI64Type(), outputDescriptors, shapeOffset); shapes.push_back(shape); @@ -416,8 +419,9 @@ struct ConvertEnqueueAllocToCall for (unsigned r = 0; r < rank; ++r) { Value strideOffset = b.create( b.getI64Type(), structType, - ArrayRef{this->createIndexConstant(b, 0), - rewriter.getI64IntegerAttr(offset++)}); + ArrayRef{ + this->createIndexConstant(b, 0), + rewriter.getI64IntegerAttr(outputDescOffset++)}); Value shape = b.create( b.getI64Type(), outputDescriptors, strideOffset); shapes.push_back(shape); diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt index 65add5f43..2f0bb5c11 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt @@ -37,6 +37,7 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanTransforms MLIRTensorRTStablehloScalarToArith MLIRTensorRTStablehloToTensorRT MLIRTensorRTTensorRTRuntimeDialect + MLIRBufferizationToMemRef MLIRTransforms StablehloOps ) diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp index eb29494c9..84d67b037 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp @@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern { } // namespace /// Get a map from `tensorrt.func` functions to associated `tensorrt.call` -/// operations. -static llvm::DenseMap> +/// and `tensorrt.call_alloc` operations. 
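+/// Both op kinds are collected behind generic `Operation *` handles so that
+/// the unused-argument rewriting below can update them uniformly.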
+static llvm::DenseMap<func::FuncOp, SmallVector<Operation *>>
 getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
-  llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>> map;
-  op->walk([&](tensorrt::CallOp callOp) {
-    func::FuncOp func = callOp.getFuncCallee(collection);
-    if (map.contains(func)) {
-      map[func].push_back(callOp);
+  llvm::DenseMap<func::FuncOp, SmallVector<Operation *>> map;
+  op->walk([&](Operation *callOp) {
+    if (!isa<tensorrt::CallOp, tensorrt::CallAllocOp>(callOp))
+      return;
+
+    func::FuncOp func;
+    if (auto call = dyn_cast<tensorrt::CallOp>(callOp)) {
+      func = call.getFuncCallee(collection);
+    } else {
+      auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp);
+      func = callAlloc.getFuncCallee(collection);
     }
-    map.insert(std::make_pair(func, SmallVector<tensorrt::CallOp>{callOp}));
+
+    if (map.count(func))
+      map[func].push_back(callOp);
+    else
+      map.insert({func, SmallVector<Operation *>{callOp}});
   });
   return map;
 }
@@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
-/// `tensorrt.call` operations.
+/// `tensorrt.call` and `tensorrt.call_alloc` operations.
 static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
                                       ModuleOp op, func::FuncOp funcOp,
-                                      ArrayRef<tensorrt::CallOp> callOps) {
+                                      ArrayRef<Operation *> callOps) {
   llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0);
   for (BlockArgument arg : funcOp.getArguments()) {
     if (arg.use_empty())
@@ -99,10 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
       funcOp.eraseArgument(i);

     // Update the call ops.
-    for (tensorrt::CallOp callOp : callOps)
-      callOp.getInputsMutable().erase(i);
+    for (Operation *callOp : callOps) {
+      if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
+        call.getInputsMutable().erase(i);
+      else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
+        callAlloc.getInputsMutable().erase(i);
+      else
+        return emitError(funcOp->getLoc())
+               << "unexpected operation type in callOps";
+    }
   }
-
   return success();
 }
diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp
index d593393f6..391d86578 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp
+++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp
@@ -156,28 +156,32 @@ getTensorRTShapeProfile(plan::BoundsAttr attr, Value v) {
   return getProfileAttr(attr.getMinShape(), attr.getMaxShape());
 }

-static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
-                                           plan::InlineClosedGroupOp op) {
-  tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
-  auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));
-  FailureOr<FunctionOpInterface> func = createOutlinedFunc(
-      rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
-      "cluster.tensorrt", TypeRange(op.getInputs()),
-      op.getYield()->getOperandTypes());
-  if (failed(func))
-    return failure();
-  assert(func->getFunctionBody().getBlocks().size() == 1 &&
-         "expected body with one block");
-  func->setPublic();
-
-  rewriter.setInsertionPoint(op);
-  auto callOp = rewriter.create<tensorrt::CallOp>(
-      op.getLoc(), op.getResultTypes(), op.getInputs(), op.getOuts(),
-      SymbolRefAttr::get(trtModuleOp.getNameAttr(),
-                         {FlatSymbolRefAttr::get(*func)}));
+template <typename OpType>
+static auto createCallOp(RewriterBase &rewriter, OpType op,
+                         tensorrt::TensorRTModuleOp trtModuleOp,
+                         FunctionOpInterface func) {
+  static_assert(
+      std::is_same_v<OpType, plan::InlineClosedGroupOp> ||
+          std::is_same_v<OpType, plan::InlineClosedAllocGroupOp>,
+      "OpType must be either InlineClosedGroupOp or InlineClosedAllocGroupOp");
+  if constexpr (std::is_same_v<OpType, plan::InlineClosedGroupOp>)
+    return rewriter.create<tensorrt::CallOp>(
+        op.getLoc(), op.getResultTypes(), op.getInputs(), op.getOuts(),
+        SymbolRefAttr::get(trtModuleOp.getNameAttr(),
+                           {FlatSymbolRefAttr::get(func)}));
+  else if constexpr (std::is_same_v<OpType, plan::InlineClosedAllocGroupOp>)
+    return rewriter.create<tensorrt::CallAllocOp>(
+        op.getLoc(),
op.getResultTypes(), op.getInputs(), + SymbolRefAttr::get(trtModuleOp.getNameAttr(), + {FlatSymbolRefAttr::get(func)})); +} +template +static LogicalResult populateFunctionAttributes(RewriterBase &rewriter, + OpType op, + FunctionOpInterface *func) { // Populate the function arguments attributes. - for (unsigned i = 0; i < (*func).getNumArguments(); i++) { + for (unsigned i = 0; i < func->getNumArguments(); i++) { BoundsAttr srcAttr = cast(op.getInputAttrs()[i]); // We may have scalar (index|signless int)-typed values since we haven't // eliminated `plan.(with_shape|with_values)` ops yet. @@ -202,30 +206,57 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, func->setArgAttr(i, mlir::getHostTensorArgAttrName(), rewriter.getUnitAttr()); } - // Populate the function result attributes. - for (unsigned i = 0; i < (*func).getNumResults(); i++) { - BoundsAttr srcAttr = cast(op.getResAttrs()[i]); - if (srcAttr.isNone()) - continue; - FailureOr boundsAttr = - getTensorRTShapeProfile(srcAttr, op.getResults()[i]); - if (failed(boundsAttr)) - return op->emitOpError("failed to create TensorRT shape profile " - "attribute from Plan BoundsAttr for result #") - << i << " (" << srcAttr << ")"; - if (srcAttr.isShapeBound()) { + // Populate the function result attributes for DPS call op. + if constexpr (std::is_same_v) { + for (unsigned i = 0; i < func->getNumResults(); i++) { + BoundsAttr srcAttr = cast(op.getResAttrs()[i]); + if (srcAttr.isNone()) + continue; + FailureOr boundsAttr = + getTensorRTShapeProfile(srcAttr, op.getResults()[i]); + if (failed(boundsAttr)) + return op->emitOpError("failed to create TensorRT shape profile " + "attribute from Plan BoundsAttr for result #") + << i << " (" << srcAttr << ")"; + if (srcAttr.isShapeBound()) { + func->setResultAttr( + i, tensorrt::TensorRTDialect::getShapeProfileArgAttrName(), + *boundsAttr); + continue; + } + assert(srcAttr.isValueBound() && "expected value bound or shape bound"); func->setResultAttr( - i, tensorrt::TensorRTDialect::getShapeProfileArgAttrName(), + i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(), *boundsAttr); - continue; + func->setResultAttr(i, mlir::getHostTensorArgAttrName(), + rewriter.getUnitAttr()); } - assert(srcAttr.isValueBound() && "expected value bound or shape bound"); - func->setResultAttr( - i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(), - *boundsAttr); - func->setResultAttr(i, mlir::getHostTensorArgAttrName(), - rewriter.getUnitAttr()); } + return success(); +} + +template +static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, OpType op) { + tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op); + auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs())); + + FailureOr func = createOutlinedFunc( + rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster", + "cluster.tensorrt", TypeRange(op.getInputs()), + op.getYield()->getOperandTypes()); + + if (failed(func)) + return failure(); + + assert(func->getFunctionBody().getBlocks().size() == 1 && + "expected body with one block"); + func->setPublic(); + + rewriter.setInsertionPoint(op); + auto callOp = createCallOp(rewriter, op, trtModuleOp, *func); + + if (failed(populateFunctionAttributes(rewriter, op, &(*func)))) + return failure(); // Populate the function entry block. rewriter.eraseBlock(&func->getFunctionBody().front()); @@ -234,14 +265,14 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, // ops to the `tensorrt.module` op. 
This is needed since `tensorrt.module` op
  // has its own symbol table.
   SymbolTableCollection symbolTable;
-  for (auto compositeOp : op.getBody().getOps<stablehlo::CompositeOp>()) {
+  for (auto compositeOp :
+       op.getBody().template getOps<stablehlo::CompositeOp>()) {
     auto decompositionFunc = dyn_cast_if_present<func::FuncOp>(
-        symbolTable.lookupSymbolIn(op->getParentOfType<ModuleOp>(),
+        symbolTable.lookupSymbolIn(op->template getParentOfType<ModuleOp>(),
                                    compositeOp.getDecompositionAttr()));
     if (!decompositionFunc)
       return emitError(compositeOp.getLoc())
-             << "failed to lookup stablehlo.composite decomposition "
-                "function: "
+             << "failed to lookup stablehlo.composite decomposition function: "
             << compositeOp.getDecompositionAttr();
     rewriter.moveOpAfter(decompositionFunc, func->getOperation());
   }
@@ -254,24 +285,20 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
   rewriter.replaceOpWithNewOp<func::ReturnOp>(regionYieldOp,
                                               regionYieldOp->getOperands());

-  // Erase the DPS arugments, which now should be unused.
-  if (llvm::any_of(func->getArguments().take_back(op.getOuts().size()),
-                   [](BlockArgument arg) { return !arg.use_empty(); }))
-    return failure();
-  func->getFunctionBody().front().eraseArguments(op.getInputs().size(),
-                                                 op.getOuts().size());
+  if constexpr (std::is_same_v<OpType, plan::InlineClosedGroupOp>) {
+    // Erase the DPS arguments, which now should be unused.
+    if (llvm::any_of(func->getArguments().take_back(op.getOuts().size()),
+                     [](BlockArgument arg) { return !arg.use_empty(); }))
+      return failure();
+    func->getFunctionBody().front().eraseArguments(op.getInputs().size(),
+                                                   op.getOuts().size());
+  }

   // replace the original region results.
   rewriter.replaceOp(op, callOp);
   return success();
 }

-static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
-                                           plan::InlineClosedAllocGroupOp op) {
-  return op.emitError("outlinining inline closed alloc group ops to tensorrt "
-                      "dialect is not yet implemented");
-}
-
 /// Create outlined functions for each `scf.execute_region` operation within
 /// `region`.
static FailureOr> @@ -302,12 +329,14 @@ createFunctionsFromRegions(RewriterBase &rewriter, Region ®ion, } if (auto group = dyn_cast(op)) { - if (failed(outlineTensorRTRegion(rewriter, group))) + if (failed(outlineTensorRTRegion(rewriter, + group))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto allocGroup = dyn_cast(op)) { - if (failed(outlineTensorRTRegion(rewriter, allocGroup))) + if (failed(outlineTensorRTRegion( + rewriter, allocGroup))) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp index 753bea88e..1a0a6415d 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp @@ -24,6 +24,7 @@ //===----------------------------------------------------------------------===// #include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h" #include "mlir-tensorrt/Transforms/Passes.h" +#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" #include "mlir/Dialect/Bufferization/Pipelines/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" @@ -80,6 +81,7 @@ void plan::buildPlanBufferDeallocationPipeline( pm.addPass(createCanonicalizerPass()); pm.addPass(bufferization::createBufferDeallocationSimplificationPass()); pm.addPass(bufferization::createLowerDeallocationsPass()); + pm.addPass(mlir::createBufferizationToMemRefPass()); pm.addPass(createCSEPass()); pm.addPass(createCanonicalizerPass()); } @@ -103,31 +105,21 @@ struct ClusteringPipelineCliOpts llvm::cl::init(NV_TENSORRT_MAJOR)}; }; -struct PlanBufferizationPipelineCliOpts - : public PassPipelineOptions { - Option enableNonDPSReturns{ - *this, "enable-non-dps-returns", - llvm::cl::desc("allow backend clusters to directly allocate outputs"), - llvm::cl::init(false)}; -}; - } // namespace // Register pipelines. 
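For context before the registration code, the simplified `plan-bufferize-pipeline` below is equivalent to assembling the same pass sequence by hand; a minimal sketch using only the builder entry points visible in this file (the pass-manager setup is assumed):

```cpp
// Sketch: build the pipeline programmatically instead of going through the
// registered "plan-bufferize-pipeline" name.
OpPassManager pm("builtin.module");
PlanAllocTensorsPassOptions allocTensorOpts{};
plan::buildPlanBufferizationPipeline(pm, allocTensorOpts);
plan::buildPlanBufferOptimizationPipeline(pm);
plan::buildPlanBufferDeallocationPipeline(
    pm, bufferization::DeallocationOptions{false});
```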
void plan::registerPlanDialectPipelines() { - PassPipelineRegistration - executorBufferizationPipeline( - "plan-bufferize-pipeline", - "perform bufferization and standard pre/post processing passes", - [](OpPassManager &pm, const PlanBufferizationPipelineCliOpts &opts) { - PlanAllocTensorsPassOptions allocTensorOpts{}; - allocTensorOpts.enableNonDPSReturns = opts.enableNonDPSReturns; - buildPlanBufferizationPipeline(pm, allocTensorOpts); - buildPlanBufferOptimizationPipeline(pm); - buildPlanBufferDeallocationPipeline( - pm, bufferization::DeallocationOptions{false}); - }); + PassPipelineRegistration<> executorBufferizationPipeline( + "plan-bufferize-pipeline", + "perform bufferization and standard pre/post processing passes", + [](OpPassManager &pm) { + PlanAllocTensorsPassOptions allocTensorOpts{}; + buildPlanBufferizationPipeline(pm, allocTensorOpts); + buildPlanBufferOptimizationPipeline(pm); + buildPlanBufferDeallocationPipeline( + pm, bufferization::DeallocationOptions{false}); + }); PassPipelineRegistration<> bufferOptPipeline( "plan-buffer-opt-pipeline", "perform post-bufferization optimizations", diff --git a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h index 807d3dc23..35041e312 100644 --- a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h +++ b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h @@ -53,7 +53,7 @@ extern "C" { /// caller must be sure to delete errors via mtrtStatusDestroy. //===----------------------------------------------------------------------===// -typedef struct MTRT_RuntimeClient MTRT_Runtimeclient; +typedef struct MTRT_RuntimeClient MTRT_RuntimeClient; //===----------------------------------------------------------------------===// // MTRT_GlobalDebug @@ -87,7 +87,7 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStreamCreate(MTRT_Stream *stream); static inline bool mtrtStreamIsNull(MTRT_Stream stream) { return !stream.ptr; } /// Returns null stream. -static inline MTRT_Stream mtrtStreamGetNull() { return MTRT_Stream{nullptr}; } +static inline MTRT_Stream mtrtStreamGetNull() { return MTRT_Stream{NULL}; } /// Synchronizes `MTRT_Stream` MLIR_CAPI_EXPORTED MTRT_Status mtrtStreamSynchronize(MTRT_Stream stream); @@ -108,7 +108,7 @@ static inline bool mtrtDeviceIsNull(MTRT_Device device) { return !device.ptr; } /// Return a null MTRT_Device. This should be used where MTRT_Device input /// arguments are optional in functions below. -static inline MTRT_Device mtrtDeviceGetNull() { return MTRT_Device{nullptr}; } +static inline MTRT_Device mtrtDeviceGetNull() { return MTRT_Device{NULL}; } //===----------------------------------------------------------------------===// // MTRT_MemRefValue @@ -215,6 +215,11 @@ static inline bool mtrtRuntimeClientIsNull(MTRT_RuntimeClient client) { return !client.ptr; } +/// Returns null client. +static inline MTRT_RuntimeClient mtrtRuntimeClientGetNull() { + return MTRT_RuntimeClient{NULL}; +} + /// Creates a `MTRT_RuntimeClient`. Client must be alive for the lifetime of the /// program execution. /// The `stream` passed to the client is used by all underlying CUDA methods @@ -308,6 +313,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) { return !value.ptr; } +// Returns whether the RuntimeValue is MemRef. +MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value); + +// Returns whether the RuntimeValue is Scalar. 
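+// For a valid non-null value, exactly one of these two predicates holds.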
+MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value);
+
 /// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue.
 MLIR_CAPI_EXPORTED MTRT_RuntimeValue
 mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref);
@@ -338,6 +349,9 @@ mtrtScalarValueCastToRuntimeValue(MTRT_ScalarValue v);
 MLIR_CAPI_EXPORTED MTRT_Status mtrtScalarValueGetType(MTRT_ScalarValue scalar,
                                                       MTRT_ScalarTypeCode *code);

+MLIR_CAPI_EXPORTED MTRT_Status mtrtScalarValueGet(MTRT_ScalarValue scalar,
+                                                  int64_t *data);
+
 //===----------------------------------------------------------------------===//
 // MTRT_RuntimeSessionOptions
 //===----------------------------------------------------------------------===//
@@ -391,16 +405,27 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) {
   return !session.ptr;
 }

-/// Using `session`, execute the pubic function with the specified name.
-/// The `inArgs` and `outArgs` are arrays for input arguments and destination
-/// arguments, respectively. Input arguments may be MemRefs or scalars, but
-/// destination arguments must be MemRefs.
+/// Using `session`, execute the public function with the specified name.
+/// The `inArgs`, `outArgs`, and `results` are arrays for input arguments,
+/// output arguments, and return values, respectively. Arguments and results
+/// can be MemRefs, scalars, or other supported types. Both `outArgs` and
+/// `results` can be used simultaneously, allowing for functions that both
+/// modify arguments and return values.
 /// A stream may optionally be specified, otherwise pass the result of
 /// `mtrtStreamGetNull()`.
+///
+/// The `results` array must point to an array with at least the number of
+/// elements returned by mtrtRuntimeSessionGetNumResults for the given
+/// function.
 MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction(
     MTRT_RuntimeSession session, MTRT_StringView name,
     const MTRT_RuntimeValue *inArgs, size_t numInArgs,
-    const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream);
+    const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
+    MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client);
+
+/// Return the number of results for the function with the given name. The
+/// name must refer to an exported function in the executable.
+MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNumResults(
+    MTRT_RuntimeSession session, MTRT_StringView name, int64_t *numResults);

 //===----------------------------------------------------------------------===//
 // DLPack
diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h
index 96aec51db..672a3f915 100644
--- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h
+++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h
@@ -427,7 +427,7 @@ class ExecutableView {

-  /// Return a function by name. This asserts that the function with the given
-  /// name exists.
+  /// Return a function by name, or an error status if no function with the
+  /// given name exists.
-  FunctionView getFunction(std::string_view name) const;
+  StatusOr<FunctionView> getFunction(std::string_view name) const;

   ConstantView getConstant(int64_t idx) const {
     assert(view->constants() && "expected valid constant pointer");
diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h
index 88c616bc7..a58b1d022 100644
--- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h
+++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h
@@ -104,13 +104,6 @@ executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name,
                               std::optional<CudaStream> stream = {},
                               std::optional<RuntimeClient *> client = {});

-// Parses the results of a function call, handling both scalar and MemRef return
-// types
-StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
-parseResults(const sol::protected_function_result &pfr,
-             const FunctionSignatureView &sig,
-             std::optional<RuntimeClient *> client);
-
 } // namespace mlirtrt::runtime

 #endif // MLIR_TENSORRT_RUNTIME_BACKEND_LUA_LUARUNTIME_H
diff --git a/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h b/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h
index cf5bd3169..7df229832 100644
--- a/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h
+++ b/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h
@@ -104,6 +104,9 @@ class PinnedMemoryAllocator {
   PinnedMemoryAllocator();
   ~PinnedMemoryAllocator();

+  /// Untrack the given pointer so that the allocator will not free it.
+  void untrack(uintptr_t ptr);
+
   StatusOr allocate(size_t size);

   /// Free the block associated with the given pointer on the given stream. An
@@ -114,6 +117,9 @@ class PinnedMemoryAllocator {
 private:
   EventPool eventPool;

+  /// Tracks pointers that must not be freed by the allocator.
+  static std::vector<uintptr_t> untrackedPtrs;
+
   /// Tracks all blocks allocated by the allocator.
struct BlockTracker; std::unique_ptr blockTracker; diff --git a/mlir-tensorrt/executor/lib/CAPI/Common/Common.cpp b/mlir-tensorrt/executor/lib/CAPI/Common/Common.cpp index 46bd564e3..d66659c36 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Common/Common.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Common/Common.cpp @@ -344,7 +344,7 @@ MTRT_Status mtrtBoundsGetMax(MTRT_Bounds bounds, MTRT_ArrayRefI64 *maxBounds) { MTRT_FunctionSignature mtrtGetFunctionSignature(MTRT_Executable exec, const char *name) { auto sig = const_cast( - unwrap(exec)->getFunction(name).getSignature().view); + (*unwrap(exec)->getFunction(name)).getSignature().view); return wrap(sig); } diff --git a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp index fa8bc850e..d2a1d8f62 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp @@ -37,6 +37,17 @@ #include "cuda_runtime_api.h" #endif +#if defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#endif +#include "cuda_bf16.h" +#include "cuda_fp16.h" +#include "cuda_fp8.h" +#if defined(__clang__) +#pragma GCC diagnostic pop +#endif + struct MTRT_StreamImpl; #define DEFINE_C_API_PTR_METHODS(name, cpptype) \ @@ -682,6 +693,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) { return wrap(static_cast(x)); } +bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) { + RuntimeValue *x = unwrap(value); + return x->getKind() == RuntimeValue::Kind::MemRef; +} + +bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) { + RuntimeValue *x = unwrap(value); + return x->getKind() == RuntimeValue::Kind::Scalar; +} + //===----------------------------------------------------------------------===// // MTRT_RuntimeSessionOptions //===----------------------------------------------------------------------===// @@ -728,7 +749,8 @@ MTRT_Status mtrtRuntimeSessionDestroy(MTRT_RuntimeSession session) { MTRT_Status mtrtRuntimeSessionExecuteFunction( MTRT_RuntimeSession session, MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs, - const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream) { + const MTRT_RuntimeValue *outArgs, size_t numOutArgs, + MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client) { LuaRuntimeSession *cppSession = static_cast(unwrap(session)); @@ -738,19 +760,38 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction( llvm::SmallVector outArgValues = llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs), [](MTRT_RuntimeValue arg) { return unwrap(arg); }); - - StatusOr>> result = + StatusOr>> resultValues = executeFunctionWithLuaBackend( *cppSession, std::string_view(name.data, name.length), inArgValues, outArgValues, !mtrtStreamIsNull(stream) ? std::optional(unwrap(stream)->getRawStream()) - : std::nullopt); - if (!result.isOk()) - return wrap(result.getStatus()); + : std::nullopt, + !mtrtRuntimeClientIsNull(client) ? 
std::optional(unwrap(client)) + : std::nullopt); + if (!resultValues.isOk()) + return wrap(resultValues.getStatus()); + + for (size_t i = 0; i < resultValues->size(); ++i) + results[i] = wrap((*resultValues)[i].release()); return mtrtStatusGetOk(); } + +MTRT_Status mtrtRuntimeSessionGetNumResults(MTRT_RuntimeSession session, + MTRT_StringView name, + int64_t *numResults) { + LuaRuntimeSession *cppSession = + static_cast(unwrap(session)); + StatusOr func = cppSession->getExecutable().getFunction( + std::string_view(name.data, name.length)); + if (func.isError()) { + return wrap(func.getStatus()); + } + *numResults = (*func).getSignature().getNumResults(); + return mtrtStatusGetOk(); +} + //===----------------------------------------------------------------------===// // MTRT_RuntimeClient //===----------------------------------------------------------------------===// @@ -796,3 +837,51 @@ MTRT_Status mtrtScalarValueGetType(MTRT_ScalarValue scalar, *code = static_cast(cppScalar->getType().getCode()); return mtrtStatusGetOk(); } + +MTRT_Status mtrtScalarValueGet(MTRT_ScalarValue scalar, int64_t *data) { + ScalarValue *cppScalar = unwrap(scalar); + ScalarTypeCode code = cppScalar->getType().getCode(); + switch (code) { + case ScalarTypeCode::f8e4m3fn: + *data = static_cast(cppScalar->get<__nv_fp8_e4m3>()); + break; + case ScalarTypeCode::f16: + *data = static_cast(cppScalar->get<__half>()); + break; + case ScalarTypeCode::bf16: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::f32: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::f64: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i1: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i4: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i8: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::ui8: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i16: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i32: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i64: + *data = cppScalar->get(); + break; + default: + return wrap(getInvalidArgStatus( + "function input argument with scalar type {0} is unsupported", + impl::EnumNameScalarTypeCode(code))); + } + return mtrtStatusGetOk(); +} diff --git a/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp b/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp index 90e4c9681..34e69188d 100644 --- a/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp +++ b/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp @@ -548,6 +548,7 @@ void executor::populateMemRefToExecutorPatterns( } namespace { + /// Pass to convert `memref` to `executor` dialect operrations. 
class ConvertMemRefToExecutorPass : public mlir::executor::impl::ConvertMemRefToExecutorPassBase< @@ -579,6 +580,7 @@ class ConvertMemRefToExecutorPass RewritePatternSet patterns(ctx); executor::populateMemRefToExecutorPatterns( patterns, typeConverter, allowUncheckedMemrefCastConversion); + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); diff --git a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp index 661c11b4e..4a13ef80a 100644 --- a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp @@ -149,14 +149,17 @@ static bool isHostVisible(PointerType type) { // ExecutableView //===----------------------------------------------------------------------===// -FunctionView ExecutableView::getFunction(std::string_view name) const { +StatusOr +ExecutableView::getFunction(std::string_view name) const { const flatbuffers::Vector> &functions = *view->functions(); auto it = std::find_if(functions.begin(), functions.end(), [&](const impl::Function *x) { return x->name()->string_view() == name; }); - assert(it != view->functions()->end()); + if (it == view->functions()->end()) + return getStatusWithMsg(StatusCode::InvalidArgument, "Function with name (", + name, ") is not present in the executable"); return FunctionView(*it); } @@ -367,6 +370,7 @@ RuntimeSession::RuntimeSession(RuntimeSessionOptions options, //===----------------------------------------------------------------------===// AllocTracker::~AllocTracker() { + MTRT_DBGF("Destroying alloc tracker %p", static_cast(this)); MTRT_DBGF("checking %u allocations", map.size()); llvm::SmallVector ptrsToFree; ptrsToFree.reserve(map.size()); @@ -452,12 +456,19 @@ void AllocTracker::track(PointerInfo info) { // (e.g. function argument), in which case it may have been deallocated, // allowing an internal allocator to pick up that same address. That case is // not an error. 
-  assert((!contains(info.ptr) || get(info.ptr).isExternallyManaged()) &&
-         "an internally managed pointer should not already be tracked");
+  if (contains(info.ptr) && get(info.ptr).isInternallyManaged()) {
+    MTRT_DBGF("Allocator %p: Internally managed pointer 0x%lx should not be "
+              "already tracked",
+              static_cast<void *>(this), info.ptr);
+    assert(0 &&
+           "an internally managed pointer should not already be tracked");
+  }
   }
-  MTRT_DBGF("AllocTracker is now tracking 0x%lx size=%lu space=%s ownership=%s",
-            info.ptr, info.size, runtime::impl::EnumNamePointerType(info.type),
-            runtime::impl::EnumNamePointerOwner(info.owner));
+  MTRT_DBGF(
+      "AllocTracker %p is now tracking 0x%lx size=%lu space=%s ownership=%s",
+      static_cast<void *>(this), info.ptr, info.size,
+      runtime::impl::EnumNamePointerType(info.type),
+      runtime::impl::EnumNamePointerOwner(info.owner));
   auto value = std::make_unique();
   value->externalReferenceCount.store(0);
   value->releasedInternally = false;
@@ -487,6 +498,8 @@ void AllocTracker::track(PointerInfo info) {
 }

 void AllocTracker::untrack(uintptr_t ptr) {
+  MTRT_DBGF("AllocTracker %p is now untracking 0x%lx",
+            static_cast<void *>(this), ptr);
   assert(llvm::is_contained(map, ptr) &&
          llvm::formatv("Untracked pointer {0}", ptr).str().c_str());
   map.erase(map.find(ptr));
@@ -596,7 +609,7 @@ mlirtrt::Status runtime::safeDeallocate(AllocTracker &tracker, uintptr_t ptr,

   PointerInfo obj = tracker.get(ptr);
   if (obj.owner == PointerOwner::external) {
-    MTRT_DBGF("Untracking externally managed pointer 0x%lx", ptr);
+    MTRT_DBGF("Untracking externally managed 0x%lx", ptr);
     tracker.untrack(obj.ptr);
     return mlirtrt::Status::getOk();
   }
@@ -747,9 +760,16 @@ StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
   if (!::getFootprintInBytes(shape, strides, bitsPerElement).isOk())
     return getInvalidArgStatus(
         "only memrefs with non-negative strides are allowed");
-  if (!ptr)
+
+  auto isEmptyTensor = [](llvm::ArrayRef<int64_t> shape) -> bool {
+    return std::any_of(shape.begin(), shape.end(),
+                       [](int64_t s) { return s == 0; });
+  };
+
+  if (!ptr && !isEmptyTensor(shape))
     return getInvalidArgStatus(
-        "MemRef objects must be created with a valid pointer");
+        "MemRef objects must be created with a valid pointer for a non-empty "
+        "tensor");

   if (isDeviceVisible(addressSpace) && (!device || !*device))
     return getInvalidArgStatus("a specific device must be provided for MemRefs "
diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp
index aea894610..0ff8f0406 100644
--- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp
+++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp
@@ -565,37 +565,45 @@ getScalarValue(const sol::protected_function_result &pfr, int index,
   }
 }

-// Parses the results of a function call, handling both scalar and MemRef return
-// types
-StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
-runtime::parseResults(const sol::protected_function_result &pfr,
-                      const FunctionSignatureView &sig,
-                      std::optional<RuntimeClient *> client) {
+/// Parses the results of a function call, handling both scalar and MemRef
+/// return types.
+///
+/// @param pfr The protected function result to parse.
+/// @param sig The function signature view.
+/// @param session The session whose allocation trackers currently own any
+///        returned buffers.
+/// @param client Optional runtime client pointer.
+/// @return A vector of unique pointers to RuntimeValue, or an error status.
+static StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
+parseResults(const sol::protected_function_result &pfr,
+             const FunctionSignatureView &sig, LuaRuntimeSession &session,
+             std::optional<RuntimeClient *> client) {
   llvm::SmallVector<std::unique_ptr<RuntimeValue>> results;
+  results.reserve(sig.getNumResults());
+
   for (unsigned i = 0; i < sig.getNumResults(); ++i) {
+    const auto &resultType = sig.getResult(i);

-    if (sig.getResult(i).isa<ScalarTypeView>()) {
-      auto scalar = getScalarValue(pfr, i, sig);
-      if (!scalar.isOk())
-        return scalar.getStatus();
-      results.push_back(std::move(*scalar));
+    if (resultType.isa<ScalarTypeView>()) {
+      auto scalarValue = getScalarValue(pfr, i, sig);
+      if (!scalarValue.isOk())
+        return scalarValue.getStatus();
+      results.push_back(std::move(*scalarValue));
       continue;
     }

-    MemRefTableReader reader(pfr, i);
-
-    if (!sig.getResult(i).isa<MemRefTypeView>())
+    if (!resultType.isa<MemRefTypeView>())
       return getInvalidArgStatus("Result can only be a memref or scalar");

     // Handle MemRef return values
-    const auto &resultView = sig.getResult(i).get<MemRefTypeView>();
-    unsigned rank = resultView.getRank();
+    const auto &memRefView = resultType.get<MemRefTypeView>();
+    MemRefTableReader reader(pfr, i);

     // Extract MemRef metadata
     uintptr_t allocPtr = reader.getNextValue<uintptr_t>();
     [[maybe_unused]] uintptr_t alignedPtr = reader.getNextValue<uintptr_t>();
     int64_t offset = reader.getNextValue<int64_t>();

+    unsigned rank = memRefView.getRank();
     llvm::SmallVector<int64_t> shape(rank);
     llvm::SmallVector<int64_t> strides(rank);
@@ -608,15 +616,43 @@
     if (!client)
       return getInvalidArgStatus("Runtime client cannot be nullptr");

-    // Create MemRefValue from extracted data
-    auto memref = (*client)->createExternalMemRef(
-        resultView.getAddressSpace(), resultView.getElementType().getBitWidth(),
-        allocPtr, offset, shape, strides, (*client)->getDevices()[0].get(),
-        resultView.getElementType());
+    // Transfer ownership of the result buffer from the session's trackers to
+    // the client's allocation tracker.
+    MTRT_DBGF("Creating external MemRef for ptr 0x%lx: "
+              "Session alloc tracker: %p, Session pinned memory allocator: %p, "
+              "Client: %p, Client tracker: %p. "
+              "This ptr is registered with the session and will now be tracked "
+              "by the client as well.",
+              allocPtr, static_cast<void *>(&session.getAllocTracker()),
+              static_cast<void *>(&session.getPinnedMemorAllocator()),
+              static_cast<void *>(*client),
+              static_cast<void *>(&(*client)->getAllocTracker()));
+
+    // What we need here is to "release" the pointer from session ownership
+    // and have the client assume ownership of it.
+    PointerInfo info = session.getAllocTracker().get(allocPtr);
+    session.getAllocTracker().untrack(info.ptr);
+
+    // It is possible that the pinned memory allocator also tracks the memory
+    // for deallocation.
+    session.getPinnedMemorAllocator().untrack(info.ptr);
+
+    AllocTracker &allocator = (*client)->getAllocTracker();
+    allocator.track(info);
+
+    // Create a memref so that the client now tracks it.
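+    // From this point the client's AllocTracker owns `info`; the external
+    // reference count incremented after creation keeps the buffer alive while
+    // the caller holds the returned MemRefValue.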
+    auto memref = MemRefValue::create(
+        *client, memRefView.getAddressSpace(),
+        memRefView.getElementType().getBitWidth(), allocPtr, offset, shape,
+        strides, (*client)->getDevices()[0].get(), memRefView.getElementType());

     if (!memref.isOk())
       return memref.getStatus();

+    // Increment the external reference count since we are returning the
+    // memref to the caller.
+    allocator.incrementExternalCount(info.ptr);
+
     results.push_back(std::move(*memref));
   }
@@ -630,8 +666,11 @@ runtime::executeFunctionWithLuaBackend(
     llvm::ArrayRef<RuntimeValue *> outputArgs,
     std::optional<CudaStream> stream, std::optional<RuntimeClient *> client) {

-  FunctionView meta = session.getExecutable().getFunction(name);
-  FunctionSignatureView sig = meta.getSignature();
+  StatusOr<FunctionView> func = session.getExecutable().getFunction(name);
+  if (func.isError())
+    return func.getStatus();
+
+  FunctionSignatureView sig = (*func).getSignature();

   // Call the main function, if present.
   sol::state &lua = session.getLuaState();
@@ -715,5 +754,5 @@ runtime::executeFunctionWithLuaBackend(
              "\": ", err.what());
   }

-  return parseResults(pfr, sig, client);
+  return parseResults(pfr, sig, session, client);
 }
diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp
index b24355eb3..943493f99 100644
--- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp
+++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp
@@ -141,9 +141,12 @@ class OutputAllocatorImpl : public nvinfer1::IOutputAllocator {
     size = std::max(size, static_cast<uint64_t>(1));
     if (size > mOutputSize) {
       size = roundUp(size, alignment);
-      if (mOutputPtr)
+      if (mOutputPtr) {
+        MTRT_DBGF("tensorrt module output allocator deallocating 0x%lx",
+                  mOutputPtr);
         mlirtrt::runtime::safeDeallocate(*mTracker, mOutputPtr,
                                          CudaStreamPtr(stream));
+      }
       mOutputPtr = 0;
       mOutputSize = 0;
       StatusOr memory =
@@ -152,6 +155,9 @@
       if (memory.isOk()) {
         mOutputPtr = (*memory).ptr;
         mOutputSize = memory->size;
+        MTRT_DBGF(
+            "tensorrt module output allocator allocating %lu bytes at 0x%lx",
+            mOutputSize, mOutputPtr);
       }
       return reinterpret_cast<void *>(mOutputPtr);
     }
diff --git a/mlir-tensorrt/executor/lib/Support/Allocators.cpp b/mlir-tensorrt/executor/lib/Support/Allocators.cpp
index ce7310ad7..1a542e187 100644
--- a/mlir-tensorrt/executor/lib/Support/Allocators.cpp
+++ b/mlir-tensorrt/executor/lib/Support/Allocators.cpp
@@ -206,6 +206,8 @@ static void cudaFreeHostWrapper(uintptr_t ptr) {
 #endif
 }

+std::vector<uintptr_t> PinnedMemoryAllocator::untrackedPtrs;
+
 struct PinnedMemoryAllocator::BlockTracker {
   std::set blocks;
   llvm::DenseMap pointerToBlock;
@@ -216,9 +218,13 @@
         "[PinnedMemoryAllocator] Releasing block tracker that has %lu blocks",
         blocks.size());
     for (Block *block : blocks) {
-      ALLOC_DBGF("[PinnedMemoryAllocator] releasing block %lu of size %lu",
-                 block->ptr, block->size);
-      cudaFreeHostWrapper(block->ptr);
+      if (std::find(PinnedMemoryAllocator::untrackedPtrs.begin(),
+                    PinnedMemoryAllocator::untrackedPtrs.end(),
+                    block->ptr) == PinnedMemoryAllocator::untrackedPtrs.end()) {
+        ALLOC_DBGF("[PinnedMemoryAllocator] releasing block %lu of size %lu",
+                   block->ptr, block->size);
+        cudaFreeHostWrapper(block->ptr);
+      }
     }
   }
 };
@@ -269,6 +275,13 @@ StatusOr PinnedMemoryAllocator::allocate(size_t size) {
 #endif
 }

+// Untrack the given pointer so that the block tracker will not free it.
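+// Pointers recorded here have had ownership transferred elsewhere (e.g. to a
+// RuntimeClient's AllocTracker in `parseResults`), so they must not be freed
+// when the block tracker is destroyed.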
+void PinnedMemoryAllocator::untrack(uintptr_t ptr) { + if (!llvm::is_contained(untrackedPtrs, ptr)) { + untrackedPtrs.emplace_back(ptr); + } +} + // Free the given block. Status PinnedMemoryAllocator::freeAsync(uintptr_t ptr, CudaStream stream) { #ifdef MLIR_EXECUTOR_ENABLE_CUDA @@ -296,4 +309,4 @@ Status PinnedMemoryAllocator::freeAsync(uintptr_t ptr, CudaStream stream) { return getInternalErrorStatus( "MLIR-Executor was not built with CUDA enabled"); #endif -} \ No newline at end of file +} diff --git a/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp b/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp index 9e8a9e50a..869373d44 100644 --- a/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp +++ b/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp @@ -53,6 +53,7 @@ class ExecutorBufferizationTestPass } } }; + } // namespace namespace mlir::executor { diff --git a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp index ddb0c74e2..f35085b85 100644 --- a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp +++ b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp @@ -244,7 +244,7 @@ class PyRuntimeClient using Base::Base; DECLARE_WRAPPER_CONSTRUCTORS(PyRuntimeClient); - static constexpr auto kMethodTable = CAPITable{ + static constexpr auto kMethodTable = CAPITable{ mtrtRuntimeClientIsNull, mtrtRuntimeClientDestroy}; }; @@ -600,6 +600,15 @@ static MTRT_RuntimeValue convertArgType(py::object obj) { throw std::runtime_error("argument must be MemRef or scalar"); } +/// Convert Runtime value to PyMemRefValue or PyScalarValue object. +static py::object convertGenericArgToPyObject(MTRT_RuntimeValue value) { + if (mtrtRuntimeValueIsMemRef(value)) + return py::cast(mtrtRuntimeValueDynCastToMemRef(value)); + if (mtrtRuntimeValueIsScalar(value)) + return py::cast(mtrtRuntimeValueDynCastToScalar(value)); + throw std::runtime_error("argument must be MemRef or scalar"); +} + //===----------------------------------------------------------------------===// // Declare the bindings. 
//===----------------------------------------------------------------------===// @@ -615,11 +624,19 @@ PYBIND11_MODULE(_api, m) { py::buffer_protocol()) .def_property_readonly(MTRT_PYTHON_CAPI_PTR_ATTR, &PyScalarValue::getCapsule) - .def_property_readonly("type", [](PyScalarValue &self) { - MTRT_ScalarTypeCode code; - MTRT_Status s = mtrtScalarValueGetType(self, &code); + .def_property_readonly("type", + [](PyScalarValue &self) { + MTRT_ScalarTypeCode code; + MTRT_Status s = + mtrtScalarValueGetType(self, &code); + THROW_IF_MTRT_ERROR(s); + return code; + }) + .def_property_readonly("data", [](PyScalarValue &self) { + int64_t data; + MTRT_Status s = mtrtScalarValueGet(self, &data); THROW_IF_MTRT_ERROR(s); - return code; + return data; }); py::class_(m, "MemRefValue", py::module_local(), py::buffer_protocol()) @@ -950,22 +967,45 @@ PYBIND11_MODULE(_api, m) { .def( "execute_function", [](PyRuntimeSession &self, std::string name, - std::vector inArgs, std::vector outArgs, - std::optional stream) { + std::vector inArgs, + std::optional> outArgs, + std::optional stream, + PyRuntimeClient *client = nullptr) { MTRT_StringView nameRef{name.data(), name.size()}; + int64_t numResults; + MTRT_Status s = + mtrtRuntimeSessionGetNumResults(self, nameRef, &numResults); + THROW_IF_MTRT_ERROR(s); + auto inArgsGeneric = llvm::map_to_vector(inArgs, convertArgType); - auto outArgsGeneric = llvm::map_to_vector(outArgs, convertArgType); + auto outArgsGeneric = + outArgs ? llvm::map_to_vector(*outArgs, convertArgType) + : llvm::SmallVector{}; + + std::vector resultsGeneric(numResults); - MTRT_Status s = mtrtRuntimeSessionExecuteFunction( + s = mtrtRuntimeSessionExecuteFunction( self, nameRef, inArgsGeneric.data(), inArgsGeneric.size(), outArgsGeneric.data(), outArgsGeneric.size(), - stream ? *stream : mtrtStreamGetNull()); + resultsGeneric.data(), stream ? *stream : mtrtStreamGetNull(), + client ? MTRT_RuntimeClient(*client) + : mtrtRuntimeClientGetNull()); THROW_IF_MTRT_ERROR(s); - }, - py::arg("name"), py::arg("in_args"), py::arg("out_args"), - py::arg("stream") = py::none()); + std::vector resultPyObject; + if (numResults > 0) { + for (const auto &arg : resultsGeneric) + resultPyObject.push_back(convertGenericArgToPyObject(arg)); + } + + return resultPyObject; + }, + py::arg("name"), py::arg("in_args"), py::arg("out_args") = py::none(), + py::arg("stream") = py::none(), py::arg("client") = nullptr, + "Execute a function given input and optional output arguments. 
" + "Return optional results as a Python object if output arguments are " + "not present."); py::class_(m, "GlobalDebug", py::module_local()) .def_property_static("flag", &PyGlobalDebugFlag::get, &PyGlobalDebugFlag::set, "LLVM-wide debug flag") @@ -977,4 +1017,4 @@ PYBIND11_MODULE(_api, m) { py::overload_cast &>( &PyGlobalDebugFlag::set_types), "Sets specific debug types to be produced by LLVM"); -} \ No newline at end of file +} diff --git a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp index 6ae989205..bea4391b1 100644 --- a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp +++ b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp @@ -520,25 +520,25 @@ static void packNonSplatInt4Tensor(ElementsAttr values, int64_t count, } } -static void serializeSplatElements(DenseIntOrFPElementsAttr values, - std::vector &data) { +static LogicalResult serializeSplatElements(DenseIntOrFPElementsAttr values, + std::vector &data) { assert(values.isSplat() && "expected SplatElementsAttr"); auto rtt = cast(values.getType()); if (rtt.getElementType().isInteger(32)) { std::fill_n(reinterpret_cast(data.data()), values.getNumElements(), values.getSplatValue()); - return; + return llvm::success(); } if (rtt.getElementType().isInteger(8)) { std::fill_n(reinterpret_cast(data.data()), values.getNumElements(), values.getSplatValue()); - return; + return llvm::success(); } if (rtt.getElementType().isF32()) { std::fill_n(reinterpret_cast(data.data()), values.getNumElements(), values.getSplatValue()); - return; + return llvm::success(); } if (rtt.getElementType().isF16() || rtt.getElementType().isBF16()) { APInt tmp = values.getSplatValue().bitcastToAPInt(); @@ -546,7 +546,7 @@ static void serializeSplatElements(DenseIntOrFPElementsAttr values, uint16_t fillValue = *reinterpret_cast(tmp.getRawData()); std::fill_n(reinterpret_cast(data.data()), values.getNumElements(), fillValue); - return; + return llvm::success(); } if (rtt.getElementType().isFloat8E4M3FN()) { APInt tmp = values.getSplatValue().bitcastToAPInt(); @@ -554,7 +554,7 @@ static void serializeSplatElements(DenseIntOrFPElementsAttr values, uint8_t fillValue = *reinterpret_cast(tmp.getRawData()); std::fill_n(reinterpret_cast(data.data()), values.getNumElements(), fillValue); - return; + return llvm::success(); } if (rtt.getElementType().isInteger(4)) { APInt tmp = values.getSplatValue(); @@ -566,11 +566,12 @@ static void serializeSplatElements(DenseIntOrFPElementsAttr values, packed |= ((value & 0x0F) << 4); // Fill `data` vector with `packed` std::fill_n(reinterpret_cast(data.data()), data.size(), packed); - return; + return llvm::success(); } - llvm_unreachable("unsupported data type to convert MLIR splat attribute to " - "TensorRT weights!"); + return emitError(UnknownLoc::get(values.getContext())) + << "unsupported data type to convert MLIR splat attribute to TensorRT " + "weights!"; } FailureOr @@ -615,8 +616,10 @@ NvInferNetworkEncoder::getNvInferWeights(ElementsAttr values) { weights.values = data.data(); if (values.isSplat() && isa(values)) { - serializeSplatElements(cast(values), - weightsMap[values]); + LogicalResult status = serializeSplatElements( + cast(values), weightsMap[values]); + if (failed(status)) + return failure(); return weights; } diff --git a/mlir-tensorrt/test/Conversion/TensorRTRuntimeToExecutor/tensorrt-runtime-to-executor.mlir 
b/mlir-tensorrt/test/Conversion/TensorRTRuntimeToExecutor/tensorrt-runtime-to-executor.mlir index 55e7535d0..28b8900c8 100644 --- a/mlir-tensorrt/test/Conversion/TensorRTRuntimeToExecutor/tensorrt-runtime-to-executor.mlir +++ b/mlir-tensorrt/test/Conversion/TensorRTRuntimeToExecutor/tensorrt-runtime-to-executor.mlir @@ -102,35 +102,93 @@ func.func @main(%arg0: memref<1x3x256x256xf32, #executor.memory_type>) - // CHECK: %[[v8:.*]] = cuda.stream.create : !cuda.stream // CHECK: %[[v9:.*]] = builtin.unrealized_conversion_cast %[[v8]] : !cuda.stream to !executor.ptr // CHECK: %[[v10:.*]] = executor.alloca %[[c1]] x !executor.table : (i64) -> !executor.ptr -// CHECK: %[[v11:.*]] = executor.getoffset[0, 0] : () -> i64, !executor.table -// CHECK: executor.store %[[c1]] to %[[v10]] + %[[v11]] : i64, !executor.ptr, i64 -// CHECK: %[[v12:.*]] = executor.getoffset[0, 1] : () -> i64, !executor.table -// CHECK: executor.store %[[c4]] to %[[v10]] + %[[v12]] : i64, !executor.ptr, i64 +// CHECK: %[[v11:.*]] = executor.getoffset[0, 0] +// CHECK: executor.store %[[c1]] to %[[v10]] + %[[v11]] +// CHECK: %[[v12:.*]] = executor.getoffset[0, 1] +// CHECK: executor.store %[[c4]] to %[[v10]] + %[[v12]] // CHECK: %[[v13:.*]] = executor.table.get %[[v6]][1] : , !executor.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64> // CHECK: %[[v14:.*]] = executor.table.create(%[[v13]], %[[c0]], %[[c4]], %[[c1]], %[[c3]], %[[c256]], %[[c256]] : !executor.ptr, i64, i64, i64, i64, i64, i64) : , i64, i64, i64, i64, i64, i64> // CHECK: executor.call @_trtrt_enqueue_alloc(%[[v7]], %[[v9]], %[[v10]], %[[v14]]) : (!executor.opaque<"trtrt_context">, !executor.ptr, !executor.ptr, !executor.table, i64, i64, i64, i64, i64, i64>) -> () -// CHECK: %[[v15:.*]] = executor.getoffset[0, 2] : () -> i64, !executor.table -// CHECK: %[[v16:.*]] = executor.load %[[v10]] + %[[v12]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v17:.*]] = executor.load %[[v10]] + %[[v15]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v18:.*]] = executor.inttoptr %[[v17]] : (i64) -> !executor.ptr -// CHECK: %[[v19:.*]] = executor.getoffset[0, 3] : () -> i64, !executor.table -// CHECK: %[[v20:.*]] = executor.load %[[v10]] + %[[v19]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v21:.*]] = executor.getoffset[0, 4] : () -> i64, !executor.table -// CHECK: %[[v22:.*]] = executor.load %[[v10]] + %[[v21]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v23:.*]] = executor.getoffset[0, 5] : () -> i64, !executor.table -// CHECK: %[[v24:.*]] = executor.load %[[v10]] + %[[v23]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v25:.*]] = executor.getoffset[0, 6] : () -> i64, !executor.table -// CHECK: %[[v26:.*]] = executor.load %[[v10]] + %[[v25]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v27:.*]] = executor.getoffset[0, 7] : () -> i64, !executor.table -// CHECK: %[[v28:.*]] = executor.load %[[v10]] + %[[v27]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v29:.*]] = executor.getoffset[0, 8] : () -> i64, !executor.table -// CHECK: %[[v30:.*]] = executor.load %[[v10]] + %[[v29]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v31:.*]] = executor.getoffset[0, 9] : () -> i64, !executor.table -// CHECK: %[[v32:.*]] = executor.load %[[v10]] + %[[v31]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v33:.*]] = executor.getoffset[0, 10] : () -> i64, !executor.table -// CHECK: %[[v34:.*]] = executor.load %[[v10]] + %[[v33]] : (!executor.ptr, i64) -> i64 +// CHECK: %[[v15:.*]] = executor.getoffset[0, 2] +// CHECK: %[[v16:.*]] = executor.load %[[v10]] + %[[v12]] +// CHECK: %[[v17:.*]] = executor.load %[[v10]] + 
%[[v15]] +// CHECK: %[[v18:.*]] = executor.inttoptr %[[v17]] +// CHECK: %[[v19:.*]] = executor.getoffset[0, 3] +// CHECK: %[[v20:.*]] = executor.load %[[v10]] + %[[v19]] +// CHECK: %[[v21:.*]] = executor.getoffset[0, 4] +// CHECK: %[[v22:.*]] = executor.load %[[v10]] + %[[v21]] +// CHECK: %[[v23:.*]] = executor.getoffset[0, 5] +// CHECK: %[[v24:.*]] = executor.load %[[v10]] + %[[v23]] +// CHECK: %[[v25:.*]] = executor.getoffset[0, 6] +// CHECK: %[[v26:.*]] = executor.load %[[v10]] + %[[v25]] +// CHECK: %[[v27:.*]] = executor.getoffset[0, 7] +// CHECK: %[[v28:.*]] = executor.load %[[v10]] + %[[v27]] +// CHECK: %[[v29:.*]] = executor.getoffset[0, 8] +// CHECK: %[[v30:.*]] = executor.load %[[v10]] + %[[v29]] +// CHECK: %[[v31:.*]] = executor.getoffset[0, 9] +// CHECK: %[[v32:.*]] = executor.load %[[v10]] + %[[v31]] +// CHECK: %[[v33:.*]] = executor.getoffset[0, 10] +// CHECK: %[[v34:.*]] = executor.load %[[v10]] + %[[v33]] // CHECK: %[[v35:.*]] = executor.table.create(%[[v18]], %[[v18]], %[[c0]], %[[v20]], %[[v22]], %[[v24]], %[[v26]], %[[v28]], %[[v30]], %[[v32]], %[[v34]] : !executor.ptr, !executor.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64) : , !executor.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64> // CHECK: %[[v36:.*]] = builtin.unrealized_conversion_cast %[[v35]] : !executor.table, !executor.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64> to memref> // CHECK: cuda.stream.sync %[[v8]] : !cuda.stream // CHECK: return %[[v36]] : memref> -// CHECK: } \ No newline at end of file +// CHECK: } + +// ----- + +func.func @main(%arg0: memref, %arg1: memref, %context: !trtrt.context, %stream: !cuda.stream) -> (memref, memref) attributes {executor.function_metadata = #executor.func_meta<[memref {#executor.dim_bounds}, memref {#executor.dim_bounds}], [memref {#executor.dim_bounds}, memref {#executor.dim_bounds}], num_output_args = 0>} { + %2:2 = trtrt.enqueue_alloc %context stream(%stream) (%arg1, %arg0) : (memref, memref) -> (memref, memref) + return %2#0, %2#1 : memref, memref +} + +// CHECK-LABEL: module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry, #dlti.dl_entry, 64 : i64>, #dlti.dl_entry, 64 : i64>>} { +// executor.func private @_trtrt_enqueue_alloc(!executor.opaque<"trtrt_context">, !executor.ptr, !executor.ptr, ...) 
+// CHECK-LABEL: func.func @main
+// CHECK-SAME: (%[[arg0:.+]]: memref, %[[arg1:.+]]: memref, %[[arg2:.+]]: !trtrt.context, %[[arg3:.+]]: !cuda.stream) -> (memref, memref) attributes {executor.function_metadata = #executor.func_meta<[memref {#executor.dim_bounds}, memref {#executor.dim_bounds}], [memref {#executor.dim_bounds}, memref {#executor.dim_bounds}], num_output_args = 0>} {
+// CHECK-DAG: %[[c2:.+]] = executor.constant 2 : i64
+// CHECK-DAG: %[[c0:.+]] = executor.constant 0 : i64
+// CHECK-DAG: %[[c1:.+]] = executor.constant 1 : i64
+// CHECK: %[[v0:.+]] = builtin.unrealized_conversion_cast %[[arg0]]
+// CHECK: %[[v1:.+]] = builtin.unrealized_conversion_cast %[[arg1]]
+// CHECK: %[[v2:.+]] = builtin.unrealized_conversion_cast %[[arg3]]
+// CHECK: %[[v3:.+]] = builtin.unrealized_conversion_cast %[[arg2]]
+// CHECK: %[[v4:.+]] = executor.alloca %[[c1]] x !executor.table
+// CHECK: %[[v5:.+]] = executor.getoffset[0, 0]
+// CHECK: executor.store %[[c2]] to %[[v4]] + %[[v5]]
+// CHECK: %[[v6:.+]] = executor.getoffset[0, 1]
+// CHECK: executor.store %[[c1]] to %[[v4]] + %[[v6]]
+// CHECK: %[[v7:.+]] = executor.getoffset[0, 5]
+// CHECK: executor.store %[[c2]] to %[[v4]] + %[[v7]]
+// CHECK: %[[v8:.+]] = executor.table.get %[[v1]][1]
+// CHECK: %[[v9:.+]] = executor.table.get %[[v1]][3]
+// CHECK: %[[v10:.+]] = executor.table.get %[[v1]][4]
+// CHECK: %[[v11:.+]] = executor.table.get %[[v0]][1]
+// CHECK: %[[v12:.+]] = executor.table.get %[[v0]][3]
+// CHECK: %[[v13:.+]] = executor.table.create(%[[v8]], %[[c0]], %[[c2]], %[[v9]], %[[v10]], %[[v11]], %[[c0]], %[[c1]], %[[v12]] : !executor.ptr, i64, i64, i64, i64, !executor.ptr, i64, i64, i64)
+// CHECK: executor.call @_trtrt_enqueue_alloc(%[[v3]], %[[v2]], %[[v4]], %[[v13]])
+// CHECK: %[[v14:.+]] = executor.getoffset[0, 2]
+// CHECK: %[[v15:.+]] = executor.load %[[v4]] + %[[v6]]
+// CHECK: %[[v16:.+]] = executor.load %[[v4]] + %[[v14]]
+// CHECK: %[[v17:.+]] = executor.inttoptr %[[v16]]
+// CHECK: %[[v18:.+]] = executor.getoffset[0, 3]
+// CHECK: %[[v19:.+]] = executor.load %[[v4]] + %[[v18]]
+// CHECK: %[[v20:.+]] = executor.getoffset[0, 4]
+// CHECK: %[[v21:.+]] = executor.load %[[v4]] + %[[v20]]
+// CHECK: %[[v22:.+]] = executor.table.create(%[[v17]], %[[v17]], %[[c0]], %[[v19]], %[[v21]] : !executor.ptr, !executor.ptr, i64, i64, i64)
+// CHECK: %[[v23:.+]] = executor.getoffset[0, 6]
+// CHECK: %[[v24:.+]] = executor.load %[[v4]] + %[[v7]]
+// CHECK: %[[v25:.+]] = executor.load %[[v4]] + %[[v23]]
+// CHECK: %[[v26:.+]] = executor.inttoptr %[[v25]]
+// CHECK: %[[v27:.+]] = executor.getoffset[0, 7]
+// CHECK: %[[v28:.+]] = executor.load %[[v4]] + %[[v27]]
+// CHECK: %[[v29:.+]] = executor.getoffset[0, 8]
+// CHECK: %[[v30:.+]] = executor.load %[[v4]] + %[[v29]]
+// CHECK: %[[v31:.+]] = executor.getoffset[0, 9]
+// CHECK: %[[v32:.+]] = executor.load %[[v4]] + %[[v31]]
+// CHECK: %[[v33:.+]] = executor.getoffset[0, 10]
+// CHECK: %[[v34:.+]] = executor.load %[[v4]] + %[[v33]]
+// CHECK: %[[v35:.+]] = executor.table.create(%[[v26]], %[[v26]], %[[c0]], %[[v28]], %[[v30]], %[[v32]], %[[v34]] : !executor.ptr, !executor.ptr, i64, i64, i64, i64, i64)
+// CHECK: %[[v36:.+]] = builtin.unrealized_conversion_cast %[[v35]]
+// CHECK: %[[v37:.+]] = builtin.unrealized_conversion_cast %[[v22]]
+// CHECK: return %[[v37]], %[[v36]] : memref, memref
\ No newline at end of file
diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_non_dps_cconv.py b/mlir-tensorrt/test/python/IntegrationTests/test_non_dps_cconv.py
new file mode 100644
index 000000000..242275c29
--- /dev/null
+++ b/mlir-tensorrt/test/python/IntegrationTests/test_non_dps_cconv.py
@@ -0,0 +1,323 @@
+# RUN: %PYTHON %s
+import time
+
+import mlir_tensorrt.compiler.api as compiler
+import mlir_tensorrt.compiler.ir as ir
+import mlir_tensorrt.runtime.api as runtime
+import numpy as np
+
+single_return = """
+func.func @main(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %1 = stablehlo.add %arg0, %arg0 : (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+  func.return %1 : tensor<2x3x4xf32>
+}
+"""
+
+scalar_return = """
+func.func @main(%arg0: tensor<2x3x4xf32>) -> index {
+  %1 = tensor.rank %arg0 : tensor<2x3x4xf32>
+  func.return %1 : index
+}
+"""
+
+mixed_return = """
+func.func @main(%arg0: tensor<2x3x4xf32>) -> (tensor<2x3x4xf32>, index) {
+  %1 = stablehlo.add %arg0, %arg0 : (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+  %2 = tensor.rank %1 : tensor<2x3x4xf32>
+  func.return %1, %2 : tensor<2x3x4xf32>, index
+}
+"""
+
+multiple_return = """
+func.func @main(%arg0: tensor<2x3x4xf32>) -> (tensor<2x3x4xf32>, tensor<2x3x4xf32>) {
+  %1 = stablehlo.add %arg0, %arg0 : (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+  %2 = stablehlo.add %arg0, %1 : (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+  func.return %1, %2 : tensor<2x3x4xf32>, tensor<2x3x4xf32>
+}
+"""
+
+dynamic_shape = """
+func.func @main(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile},
+                %arg1: tensor {tensorrt.shape_profile = #tensorrt.shape_profile})
+    -> tensor {
+  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor) -> tensor
+  %1 = stablehlo.reshape %0 : (tensor) -> tensor<1xi32>
+  %2 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor) -> tensor
+  %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32>
+  %4 = stablehlo.concatenate %1, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  %5 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor) -> tensor
+  %6 = stablehlo.reshape %5 : (tensor) -> tensor<1xi32>
+  %7 = stablehlo.get_dimension_size %arg1, dim = 1 : (tensor) -> tensor
+  %8 = stablehlo.reshape %7 : (tensor) -> tensor<1xi32>
+  %9 = stablehlo.concatenate %6, %8, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  %10 = stablehlo.maximum %4, %9 : tensor<2xi32>
+  %11 = stablehlo.dynamic_broadcast_in_dim %arg0, %10, dims = [0, 1] : (tensor, tensor<2xi32>) -> tensor
+  %12 = stablehlo.dynamic_broadcast_in_dim %arg1, %10, dims = [0, 1] : (tensor, tensor<2xi32>) -> tensor
+  %13 = stablehlo.add %11, %12 : tensor
+  return %13 : tensor
+}
+"""
+
+session_tracking_h2h = """
+func.func @main() -> (tensor> {tensorrt.host_tensor}) {
+  %c = stablehlo.constant dense<[1, 2]> : tensor<2xi32>
+  %0 = bufferization.alloc_tensor() {memory_space = #plan.memory_space} : tensor<2xi32, #plan.memory_space>
+  %1 = bufferization.materialize_in_destination %c in %0 : (tensor<2xi32>, tensor<2xi32, #plan.memory_space>) -> tensor<2xi32, #plan.memory_space>
+  %cast = tensor.cast %1 : tensor<2xi32, #plan.memory_space> to tensor>
+  return %cast : tensor>
+}
+"""
+
+empty_shape_tensor = """
+func.func @main() -> (tensor> {tensorrt.host_tensor}) {
+  %c = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+  %c_0 = stablehlo.constant dense<2> : tensor
+  %c_1 = stablehlo.constant dense<1> : tensor<1xi32>
+  %c_2 = stablehlo.constant dense<2> : tensor<1xi32>
+  %c_3 = stablehlo.constant dense<2> : tensor
+  %c_4 = stablehlo.constant dense<2> : tensor<1xi32>
+  %0 = stablehlo.concatenate %c_2, %c_4, %c_1, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+  %1 = stablehlo.dynamic_reshape %c, %0 : (tensor<2x2xi32>, tensor<3xi32>) -> tensor
+  %c_5 = stablehlo.constant dense<2> : tensor
+  %c_6 = stablehlo.constant dense<2> : tensor<1xi32>
+  %c_7 = stablehlo.constant dense<2> : tensor
+  %c_8 = stablehlo.constant dense<2> : tensor<1xi32>
+  %c_9 = stablehlo.constant dense<0> : tensor
+  %c_10 = stablehlo.constant dense<0> : tensor<1xi32>
+  %2 = stablehlo.concatenate %c_6, %c_8, %c_10, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+  %3 = stablehlo.dynamic_broadcast_in_dim %1, %2, dims = [0, 1, 2] : (tensor, tensor<3xi32>) -> tensor
+  %c_11 = stablehlo.constant dense<2> : tensor<1xi32>
+  %c_12 = stablehlo.constant dense<> : tensor<0xi32>
+  %c_13 = stablehlo.constant dense<> : tensor<0xi32>
+  %4 = stablehlo.compare EQ, %c_12, %c_13 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
+  %5 = stablehlo.select %4, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
+  %6 = stablehlo.dynamic_broadcast_in_dim %c_7, %5, dims = [] : (tensor, tensor<0xi32>) -> tensor
+  %7 = stablehlo.dynamic_broadcast_in_dim %c_9, %5, dims = [] : (tensor, tensor<0xi32>) -> tensor
+  %8 = stablehlo.multiply %6, %7 : tensor
+  %9 = stablehlo.reshape %8 : (tensor) -> tensor<1xi32>
+  %10 = stablehlo.concatenate %c_11, %9, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  %11 = stablehlo.dynamic_reshape %3, %10 : (tensor, tensor<2xi32>) -> tensor
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %11, %c0 : tensor
+  %c1 = arith.constant 1 : index
+  %dim_14 = tensor.dim %11, %c1 : tensor
+  %12 = bufferization.alloc_tensor(%dim, %dim_14) {memory_space = #plan.memory_space} : tensor>
+  %13 = bufferization.materialize_in_destination %11 in %12 : (tensor, tensor>) -> tensor>
+  %cast = tensor.cast %13 : tensor> to tensor>
+  return %cast : tensor>
+}
+"""
+
+
+# The RuntimeClient can and should persist across multiple Executables, RuntimeSessions, etc.
+# It is primarily an interface for creating and manipulating buffers.
+client = runtime.RuntimeClient()
+stream = client.create_stream()
+devices = client.get_devices()
+
+
+def compile_executable(program, debug=False):
+    # Build/parse the main function.
+    with ir.Context() as context:
+        m = ir.Module.parse(program)
+
+        # Use the compiler API to compile to executable.
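+        # Note: "--enable-non-dps-returns" below opts into the non-DPS calling
+        # convention exercised by this test: instead of the caller passing in
+        # preallocated destination buffers, `trtrt.enqueue_alloc` lets TensorRT
+        # allocate results through an output allocator, and the runtime returns
+        # freshly allocated memrefs (see the option description in
+        # StableHloToExecutable.cpp).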
+        client = compiler.CompilerClient(context)
+        c_opts = [
+            "--tensorrt-builder-opt-level=3",
+            "--tensorrt-strongly-typed=false",
+            "--entrypoint=main",
+            "--enable-non-dps-returns",
+        ]
+        opts = compiler.StableHLOToExecutableOptions(client, c_opts)
+        if debug:
+            opts.set_debug_options(False, [], "tmp")
+        exe = compiler.compiler_stablehlo_to_executable(client, m.operation, opts)
+        return exe
+
+
+def test_single_return():
+    exe = compile_executable(single_return)
+    session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
+    session = runtime.RuntimeSession(session_options, exe)
+    arg0 = client.create_memref(
+        np.arange(0.0, 24.0, dtype=np.float32).reshape(2, 3, 4).data,
+        device=devices[0],
+        stream=stream,
+    )
+    results = session.execute_function(
+        "main", in_args=[arg0], stream=stream, client=client
+    )
+
+    output = np.asarray(client.copy_to_host(results[0], stream=stream))
+    stream.sync()
+
+    print(output)
+
+
+def test_scalar_return():
+    exe = compile_executable(scalar_return)
+    session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
+    session = runtime.RuntimeSession(session_options, exe)
+    arg0 = client.create_memref(
+        np.arange(0.0, 24.0, dtype=np.float32).reshape(2, 3, 4).data,
+        device=devices[0],
+        stream=stream,
+    )
+    results = session.execute_function(
+        "main", in_args=[arg0], stream=stream, client=client
+    )
+
+    print(results[0].data)
+
+
+def test_mixed_return():
+    exe = compile_executable(mixed_return)
+    session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
+    session = runtime.RuntimeSession(session_options, exe)
+    arg0 = client.create_memref(
+        np.arange(0.0, 24.0, dtype=np.float32).reshape(2, 3, 4).data,
+        device=devices[0],
+        stream=stream,
+    )
+    results = session.execute_function(
+        "main", in_args=[arg0], stream=stream, client=client
+    )
+
+    assert type(results[0]) == runtime.MemRefValue
+    assert type(results[1]) == runtime.ScalarValue
+
+    output = np.asarray(client.copy_to_host(results[0], stream=stream))
+    stream.sync()
+
+    print(output)
+    print(results[1].data)
+
+
+def test_multiple_return():
+    exe = compile_executable(multiple_return)
+    session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
+    session = runtime.RuntimeSession(session_options, exe)
+    arg0 = client.create_memref(
+        np.arange(0.0, 24.0, dtype=np.float32).reshape(2, 3, 4).data,
+        device=devices[0],
+        stream=stream,
+    )
+    results = session.execute_function(
+        "main", in_args=[arg0], stream=stream, client=client
+    )
+
+    output_0 = np.asarray(client.copy_to_host(results[0], stream=stream))
+    output_1 = np.asarray(client.copy_to_host(results[1], stream=stream))
+
+    stream.sync()
+
+    print(output_0)
+    print(output_1)
+
+
+def test_dynamic_shape():
+    exe = compile_executable(dynamic_shape)
+    session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
+    session = runtime.RuntimeSession(session_options, exe)
+    arg0 = client.create_memref(
+        np.arange(0.0, 8.0, dtype=np.float32).reshape((4, 2)).data,
+        device=devices[0],
+        stream=stream,
+    )
+    arg1 = client.create_memref(
+        np.ones((4, 2), dtype=np.float32).data, device=devices[0], stream=stream
+    )
+
+    results = session.execute_function(
+        "main", in_args=[arg0, arg1], stream=stream, client=client
+    )
+
+    output = np.asarray(client.copy_to_host(results[0], stream=stream))
+    stream.sync()
+
+    print(output)
+
+
+def test_session_tracking_d2h():
+    exe = compile_executable(session_tracking_h2h)
+    session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
+    session = runtime.RuntimeSession(session_options, exe)
+    results = session.execute_function("main", in_args=[], stream=stream, client=client)
+    stream.sync()
+    print(np.asarray(results[0]))
+
+
+def test_empty_shape_tensor():
+    exe = compile_executable(empty_shape_tensor)
+    session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
+    session = runtime.RuntimeSession(session_options, exe)
+    results = session.execute_function("main", in_args=[], stream=stream, client=client)
+    stream.sync()
+    print(np.asarray(results[0]))
+
+
+if __name__ == "__main__":
+    print("Test: single return")
+    test_single_return()
+    # CHECK-LABEL: Test: single return
+    # CHECK: [[[ 0. 2. 4. 6.]
+    # CHECK: [ 8. 10. 12. 14.]
+    # CHECK: [16. 18. 20. 22.]]
+    # CHECK:
+    # CHECK: [[24. 26. 28. 30.]
+    # CHECK: [32. 34. 36. 38.]
+    # CHECK: [40. 42. 44. 46.]]]
+
+    print("Test: multiple return")
+    test_multiple_return()
+    # CHECK-LABEL: Test: multiple return
+    # CHECK: [[[ 0. 2. 4. 6.]
+    # CHECK: [ 8. 10. 12. 14.]
+    # CHECK: [16. 18. 20. 22.]]
+    # CHECK:
+    # CHECK: [[24. 26. 28. 30.]
+    # CHECK: [32. 34. 36. 38.]
+    # CHECK: [40. 42. 44. 46.]]]
+    # CHECK: [[[ 0. 3. 6. 9.]
+    # CHECK: [12. 15. 18. 21.]
+    # CHECK: [24. 27. 30. 33.]]
+    # CHECK:
+    # CHECK: [[36. 39. 42. 45.]
+    # CHECK: [48. 51. 54. 57.]
+    # CHECK: [60. 63. 66. 69.]]]
+
+    print("Test: dynamic shape")
+    test_dynamic_shape()
+    # CHECK-LABEL: Test: dynamic shape
+    # CHECK: [[1. 2.]
+    # CHECK: [3. 4.]
+    # CHECK: [5. 6.]
+    # CHECK: [7. 8.]]
+
+    print("Test: device to host copy")
+    test_session_tracking_d2h()
+    # CHECK-LABEL: Test: device to host copy
+    # CHECK: [1 2]
+
+    print("Test: empty shape tensor")
+    test_empty_shape_tensor()
+    # CHECK-LABEL: Test: empty shape tensor
+    # CHECK: []
+
+    print("Test: scalar return")
+    test_scalar_return()
+    # CHECK-LABEL: Test: scalar return
+    # CHECK: 3
+
+    print("Test: mixed return")
+    test_mixed_return()
+    # CHECK-LABEL: Test: mixed return
+    # CHECK: [[[ 0. 2. 4. 6.]
+    # CHECK: [ 8. 10. 12. 14.]
+    # CHECK: [16. 18. 20. 22.]]
+    # CHECK:
+    # CHECK: [[24. 26. 28. 30.]
+    # CHECK: [32. 34. 36. 38.]
+    # CHECK: [40. 42. 44. 46.]]]
+    # CHECK: 3