diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/TensorRTRuntime/IR/TensorRTRuntimeOps.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/TensorRTRuntime/IR/TensorRTRuntimeOps.td
index c23280453..405c7f82d 100644
--- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/TensorRTRuntime/IR/TensorRTRuntimeOps.td
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/TensorRTRuntime/IR/TensorRTRuntimeOps.td
@@ -161,49 +161,4 @@ def TensorRTRuntime_EnqueueAllocOp : TensorRTRuntime_Op<"enqueue_alloc", [
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// EnqueueAllocOp
-//===----------------------------------------------------------------------===//
-
-def TensorRTRuntime_EnqueueAllocOp : TensorRTRuntime_Op<"enqueue_alloc", [
-    DeclareOpInterfaceMethods,
-    DeclareOpInterfaceMethods,
-]> {
-  let description = [{
-    Asynchronously executes the computation represented by the
-    `execution_context` on the specified CUDA stream. This operation
-    can accept inputs of either tensor or memref types and returns
-    results of either tensor or memref types.
-  }];
-
-  let arguments = (ins
-    TensorRTRuntime_Context:$execution_context,
-    CUDA_Stream:$stream,
-    Variadic>:$inputs,
-    OptionalAttr:$host_tensor_args
-  );
-
-  let results = (outs Variadic>:$results);
-
-  let assemblyFormat = [{
-    $execution_context `stream` `(` $stream `)` ` `
-    (`host_tensor_args` $host_tensor_args^ ` ` )?
-    `(` $inputs `)`
-    attr-dict `:` functional-type($inputs, $results)
-  }];
-
-  let hasVerifier = 1;
-
-  let extraClassDeclaration = [{
-    /// Return true if the operand is a host tensor argument.
-    bool isOperandOnHost(OpOperand *operand) {
-      unsigned operandIdx = operand->getOperandNumber();
-      if(std::optional> indices = getHostTensorArgs()) {
-        return llvm::is_contained(*indices, operandIdx - 2);
-      }
-      return false;
-    }
-  }];
-}
-
 #endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTRUNTIMEOPS_TD
diff --git a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
index 5bc1c7e5b..6d2aaacfe 100644
--- a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
+++ b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
@@ -223,7 +223,7 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions(
       llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
                      "calculations (but they can still be inputs)"));
   addOption(
-      "use-non-dps-call-conv", useNonDPSCallConv, llvm::cl::init(false),
+      "enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false),
       llvm::cl::desc(
           "allow tensorrt based output allocations using output allocator"));
   addOption("executor-index-bitwidth", executorIndexBitwidth,
@@ -307,7 +307,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
   plan::StablehloClusteringPassOptions clusteringOpts{};
   clusteringOpts.disallowHostTensorsInTensorRTClusters =
       opts.disallowHostTensorsInTensorRTClusters;
-  clusteringOpts.useNonDPSCallConv = opts.useNonDPSCallConv;
+  clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
   clusteringOpts.entrypoint = opts.entrypoint;
   plan::buildPlanSegmentationPipeline(pm, clusteringOpts);
@@ -342,7 +342,7 @@ void StableHloToExecutableTask::buildPostClusteringPipeline(
   // Perform bufferization.
   pm.addPass(createMemRefCastEliminationPass());
   plan::PlanAllocTensorsPassOptions allocTensorsOpts{};
-  allocTensorsOpts.useNonDPSCallConv = opts.useNonDPSCallConv;
+  allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
   pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts));
   pm.addPass(plan::createPlanBufferizePass());
   pm.addPass(createMemRefCastEliminationPass());
@@ -532,8 +532,8 @@ struct ClusteringPipelineCliOpts
       *this, "device-compute-capability",
       llvm::cl::desc("target device compute capability (SM version)"),
       llvm::cl::init(60)};
-  Option<bool> useNonDPSCallConv{
-      *this, "use-non-dps-call-conv",
+  Option<bool> enableNonDPSReturns{
+      *this, "enable-non-dps-returns",
       llvm::cl::desc(
           "allow tensorrt based output allocations using output allocator"),
       llvm::cl::init(false)};
@@ -564,7 +564,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
   opts.deviceComputeCapability = cliOpts.deviceComputeCapability;
   opts.deviceMaxSharedMemoryPerBlockKb =
       cliOpts.deviceMaxSharedMemoryPerBlockKb;
-  opts.useNonDPSCallConv = cliOpts.useNonDPSCallConv;
+  opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns;
   opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost;
   opts.entrypoint = cliOpts.entrypoint;
   return opts;
diff --git a/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp b/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp
index 5023176ac..52f838eac 100644
--- a/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp
+++ b/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp
@@ -93,6 +93,8 @@ convertCallOp(Operation *op, IRRewriter &rewriter,
   SmallVector<int64_t> hostTensorArgs;
   for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) {
     const TensorKindLattice *kind = solver.lookupState<TensorKindLattice>(arg);
+    if (!isa<RankedTensorType>(arg.getType()))
+      continue;
     RankedTensorType rtt = cast<RankedTensorType>(arg.getType());
     // To be conservative, we only do this if type is i32 and num elements
     // <= 8.
diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp
index eb29494c9..393802270 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp
+++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp
@@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern {
 } // namespace
 
 /// Get a map from `tensorrt.func` functions to associated `tensorrt.call`
-/// operations.
-static llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>>
+/// and `tensorrt.call_alloc` operations.
+static llvm::DenseMap<func::FuncOp, SmallVector<Operation *>>
 getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
-  llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>> map;
-  op->walk([&](tensorrt::CallOp callOp) {
-    func::FuncOp func = callOp.getFuncCallee(collection);
-    if (map.contains(func)) {
-      map[func].push_back(callOp);
+  llvm::DenseMap<func::FuncOp, SmallVector<Operation *>> map;
+  op->walk([&](Operation *callOp) {
+    if (!isa<tensorrt::CallOp, tensorrt::CallAllocOp>(callOp))
       return;
-    }
-    map.insert(std::make_pair(func, SmallVector<tensorrt::CallOp>{callOp}));
+
+    func::FuncOp func;
+    if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
+      func = call.getFuncCallee(collection);
+    else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
+      func = callAlloc.getFuncCallee(collection);
+    else
+      return;
+
+    if (map.count(func))
+      map[func].push_back(callOp);
+    else
+      map.insert({func, SmallVector<Operation *>{callOp}});
   });
   return map;
 }
@@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
 /// `tensorrt.call` operations.
 static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
                                       ModuleOp op, func::FuncOp funcOp,
-                                      ArrayRef<tensorrt::CallOp> callOps) {
+                                      ArrayRef<Operation *> callOps) {
   llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0);
   for (BlockArgument arg : funcOp.getArguments()) {
     if (arg.use_empty())
@@ -99,8 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
     funcOp.eraseArgument(i);
 
     // Update the call ops.
-    for (tensorrt::CallOp callOp : callOps)
-      callOp.getInputsMutable().erase(i);
+    for (Operation *callOp : callOps) {
+      if (auto call = dyn_cast<tensorrt::CallOp>(callOp)) {
+        call.getInputsMutable().erase(i);
+      } else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp)) {
+        callAlloc.getInputsMutable().erase(i);
+      } else {
+        llvm::errs() << "Unexpected operation type in callOps\n";
+        callOp->dump();
+      }
+    }
   }
 
   return success();
diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp
index d593393f6..0eae7f594 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp
+++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp
@@ -268,8 +268,82 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
 
 static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
                                            plan::InlineClosedAllocGroupOp op) {
-  return op.emitError("outlinining inline closed alloc group ops to tensorrt "
-                      "dialect is not yet implemented");
+  tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
+  auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));
+  FailureOr<FuncOp> func = createOutlinedFunc(
+      rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
+      "cluster.tensorrt", TypeRange(op.getInputs()),
+      op.getYield()->getOperandTypes());
+  if (failed(func))
+    return failure();
+  assert(func->getFunctionBody().getBlocks().size() == 1 &&
+         "expected body with one block");
+  func->setPublic();
+
+  rewriter.setInsertionPoint(op);
+
+  auto callOp = rewriter.create<tensorrt::CallAllocOp>(
+      op.getLoc(), op.getResultTypes(), op.getInputs(),
+      SymbolRefAttr::get(trtModuleOp.getNameAttr(),
+                         {FlatSymbolRefAttr::get(*func)}));
+
+  // Populate the function arguments attributes.
+  for (unsigned i = 0; i < (*func).getNumArguments(); i++) {
+    BoundsAttr srcAttr = cast<BoundsAttr>(op.getInputAttrs()[i]);
+    // We may have scalar (index|signless int)-typed values since we haven't
+    // eliminated `plan.(with_shape|with_values)` ops yet.
+    if (!op.argHasTensorType(i) || srcAttr.isNone())
+      continue;
+    FailureOr<tensorrt::ShapeProfileAttr> boundAttr =
+        getTensorRTShapeProfile(srcAttr, op.getInputs()[i]);
+    if (failed(boundAttr))
+      return op->emitOpError("failed to create TensorRT shape profile "
+                             "attribute from Plan BoundsAttr for argument #")
+             << i << " (" << srcAttr << ")";
+    if (srcAttr.isShapeBound()) {
+      func->setArgAttr(i,
+                       tensorrt::TensorRTDialect::getShapeProfileArgAttrName(),
+                       *boundAttr);
+      continue;
+    }
+    assert(srcAttr.isValueBound() && "expected value bound or shape bound");
+    func->setArgAttr(
+        i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(),
+        *boundAttr);
+    func->setArgAttr(i, mlir::getHostTensorArgAttrName(),
+                     rewriter.getUnitAttr());
+  }
+
+  // Populate the function entry block.
+  rewriter.eraseBlock(&func->getFunctionBody().front());
+
+  // Move private decomposition funcs associated with all `stablehlo.composite`
+  // ops to the `tensorrt.module` op. This is needed since `tensorrt.module` op
+  // has its own symbol table.
+  SymbolTableCollection symbolTable;
+  for (auto compositeOp : op.getBody().getOps<stablehlo::CompositeOp>()) {
+    auto decompositionFunc = dyn_cast_if_present<func::FuncOp>(
+        symbolTable.lookupSymbolIn(op->getParentOfType<ModuleOp>(),
+                                   compositeOp.getDecompositionAttr()));
+    if (!decompositionFunc)
+      return emitError(compositeOp.getLoc())
+             << "failed to lookup stablehlo.composite decomposition "
+                "function: "
+             << compositeOp.getDecompositionAttr();
+    rewriter.moveOpAfter(decompositionFunc, func->getOperation());
+  }
+
+  // Move region op operations to the func body.
+  Operation *regionYieldOp = op.getYield();
+  rewriter.inlineRegionBefore(op.getRegion(), func->getFunctionBody(),
+                              func->getFunctionBody().end());
+  rewriter.setInsertionPoint(regionYieldOp);
+  rewriter.replaceOpWithNewOp<func::ReturnOp>(regionYieldOp,
+                                              regionYieldOp->getOperands());
+
+  // Replace the original region results.
+  rewriter.replaceOp(op, callOp);
+  return success();
 }
 
 /// Create outlined functions for each `scf.execute_region` operation within
diff --git a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp
index 8722c3c60..49e61fab8 100644
--- a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp
+++ b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp
@@ -410,7 +410,7 @@ void AllocTracker::incrementExternalCount(uintptr_t ptr) {
               llvm::formatv("Untracked pointer {0}", ptr).str().c_str());
   std::unique_ptr const &metadata = map.at(ptr);
   int32_t ref = ++metadata->externalReferenceCount;
-  MTRT_DBG("Incremented external reference for pointer %d to %d", ptr, ref);
+  MTRT_DBGF("Incremented external reference for pointer 0x%lx to %d", ptr, ref);
 }
 
 void AllocTracker::decrementExternalCount(uintptr_t ptr) {
@@ -422,11 +422,12 @@ void AllocTracker::decrementExternalCount(uintptr_t ptr) {
          llvm::formatv("External reference count cannot be negative: {0}", ref)
              .str()
              .c_str());
-  MTRT_DBG("Decremented external reference for pointer %d to %d", ptr, ref);
+  MTRT_DBGF("Decremented external reference for pointer 0x%lx to %d", ptr, ref);
   if (ref == 0 && metadata->releasedInternally) {
-    MTRT_DBG("External reference to an internally released pointer %d is 0, "
-             "try deallocating pointer memory of size %lu",
-             ptr, ref, metadata->info.size);
+    MTRT_DBGF(
+        "External reference to an internally released pointer 0x%lx is 0, "
+        "try deallocating pointer memory of size %lu",
+        ptr, metadata->info.size);
     Status s = safeDeallocate(*this, metadata->info.ptr);
     if (!s.isOk())
       MTRT_DBGF("error while deallocating dangling memory: %s",
diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp
index 3ef238d18..02fdcfc1b 100644
--- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp
+++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp
@@ -502,30 +502,34 @@ static Status validateArgsTypesAgainstFuncArgs(const RuntimeValue *runArg,
   return getOkStatus();
 }
 
-static constexpr int MEMREF_FIXED_FIELDS = 3; // allocPtr, alignedPtr, offset
+[[maybe_unused]] static constexpr int MEMREF_FIXED_FIELDS =
+    3; // allocPtr, alignedPtr, offset
 
 // MemRefTableReader encapsulates the logic for reading MemRef data from a Lua
 // table
 class MemRefTableReader {
 public:
-  MemRefTableReader(const sol::protected_function_result &pfr,
+  MemRefTableReader(const sol::protected_function_result &pfr, int resultIndex,
                     impl::CallingConvention conv)
       : mPfr(pfr), mConv(conv), mIndex(1) {
     // Currently, we only support unpacked calling convention
     assert(mConv == CallingConvention::unpacked &&
            "Only unpacked calling convention is supported");
+
+    // Assume result is always a memref.
+    sol::object obj = mPfr[resultIndex];
+    assert(obj.is<sol::table>() && "Expected a table for MemRefValue");
+    mMemRefTable = obj.as<sol::table>();
   }
 
   // Retrieves the next value of type T from the MemRef table
   // This method advances the internal index automatically
   template <typename T> T getNextValue() {
-    sol::object obj = mPfr[0];
-    assert(obj.is<sol::table>() && "Expected a table for MemRefValue");
-    sol::table memRefTable = obj.as<sol::table>();
-    return memRefTable.get<T>(mIndex++);
+    return mMemRefTable.get<T>(mIndex++);
   }
 
+  // TODO: This may not be required since each pfr index stores a memref.
   // Moves to the next MemRef in the table
   // This is called after processing all data for the current MemRef
   void nextMemRef(int offset) {
@@ -537,6 +541,7 @@ class MemRefTableReader {
 private:
   const sol::protected_function_result &mPfr;
   impl::CallingConvention mConv;
+  sol::table mMemRefTable;
   int mIndex;
 };
 
@@ -571,9 +576,10 @@ parseResults(const sol::protected_function_result &pfr,
              const FunctionSignatureView &sig,
              std::optional<RuntimeClient *> client) {
   llvm::SmallVector<std::unique_ptr<RuntimeValue>> results;
-  MemRefTableReader reader(pfr, sig.getCConv());
-
   for (unsigned i = 0; i < sig.getNumResults(); ++i) {
+
+    MemRefTableReader reader(pfr, i, sig.getCConv());
+
     if (sig.getResult(i).isa()) {
       auto scalar = getScalarValue(pfr, i, sig);
       if (!scalar.isOk())
@@ -607,16 +613,16 @@ parseResults(const sol::protected_function_result &pfr,
       return getInvalidArgStatus("Runtime client cannot be nullptr");
 
     // Create MemRefValue from extracted data
-    auto memref = MemRefValue::create(
-        *client, resultView.getAddressSpace(),
-        resultView.getElementType().getBitWidth(), allocPtr, offset, shape,
-        strides, (*client)->getDevices()[0].get(), resultView.getElementType());
+
+    auto memref = (*client)->createExternalMemRef(
+        resultView.getAddressSpace(), resultView.getElementType().getBitWidth(),
+        allocPtr, offset, shape, strides, (*client)->getDevices()[0].get(),
+        resultView.getElementType());
 
     if (!memref.isOk())
       return memref.getStatus();
 
     results.push_back(std::move(*memref));
-    reader.nextMemRef(MEMREF_FIXED_FIELDS + rank * 2);
   }
 
   return results;
diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp
index e9e9e847a..322fc1870 100644
--- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp
+++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp
@@ -282,55 +282,52 @@ void OutputAllocatorImpl::setTensorName(const std::string &name) {
   mTensorName = name;
 }
 
-void OutputAllocatorImpl::setCurrentMemory(void *memory, int64_t size) {
-  mCurrentMemory = memory;
+void OutputAllocatorImpl::setCurrentMemory(uintptr_t memory, int64_t size) {
+  mOutputPtr = memory;
   mOutputSize = size;
 }
 
-void *OutputAllocatorImpl::reallocateOutputAsync(const char *name,
-                                                 void *memory, uint64_t size,
+void *OutputAllocatorImpl::reallocateOutputAsync(const char *name, void *memory,
+                                                 uint64_t size,
                                                  uint64_t alignment,
-                                                 cudaStream_t /*stream*/) {
-  assert(memory == mCurrentMemory && "Output buffer mismatch");
+                                                 cudaStream_t stream) {
+  assert((!mOutputPtr || reinterpret_cast<uintptr_t>(memory) == mOutputPtr) &&
+         "Output buffer mismatch");
   assert(name == mTensorName && "Tensor name mismatch");
-  assert(!mReallocateOutputCalled && "Duplicate call to reallocateOutput");
-  mReallocateOutputCalled = true;
 
   size = std::max(size, static_cast<uint64_t>(1));
   if (size > mOutputSize) {
     size = roundUp(size, alignment);
 
-    if (mOutputPtr) {
-      cudaFree(mOutputPtr);
-    }
+    if (mOutputPtr)
+      mlirtrt::runtime::safeDeallocate(*mTracker, mOutputPtr,
+                                       CudaStreamPtr(stream));
 
-    mOutputPtr = nullptr;
+    mOutputPtr = 0;
     mOutputSize = 0;
 
-    void *newMemory;
-    if (cudaMalloc(&newMemory, size) == cudaSuccess) {
-      mOutputPtr = newMemory;
-      mOutputSize = size;
+    StatusOr<PointerInfo> memory = mlirtrt::runtime::allocate(
+        *mTracker, PointerType::device, size, alignment, CudaStreamPtr(stream));
+    if (memory.isOk()) {
+      mOutputPtr = (*memory).ptr;
+      mOutputSize = memory->size;
     }
 
-    return mOutputPtr;
+    return reinterpret_cast<void *>(mOutputPtr);
   }
 
-  return mCurrentMemory;
+  return reinterpret_cast<void *>(mOutputPtr);
 }
 
 void OutputAllocatorImpl::notifyShape(const char *name,
                                       const nvinfer1::Dims &dims) {
-  assert(mReallocateOutputCalled &&
-         "TensorRT must invoke reallocateOutput first");
-  assert(!mNotifyShapeCalled && "Duplicate call to notifyShape");
   assert(name == mTensorName && "Tensor name mismatch");
-  mNotifyShapeCalled = true;
   mOutputDims = dims;
 }
 
 /// OutputAllocator - Manages multiple OutputAllocatorImpl instances.
 class OutputAllocator {
 public:
-  explicit OutputAllocator(int64_t nbResults);
+  explicit OutputAllocator(AllocTracker *tracker, int64_t nbResults);
 
   // Disable copy and move operations
   OutputAllocator(const OutputAllocator &) = delete;
@@ -338,7 +335,7 @@ class OutputAllocator {
   OutputAllocator(OutputAllocator &&) = delete;
   OutputAllocator &operator=(OutputAllocator &&) = delete;
 
-  void registerAllocator(size_t index, const char *name, void *ptr,
+  void registerAllocator(size_t index, const char *name, uintptr_t ptr,
                          int64_t size, nvinfer1::IExecutionContext *context);
 
   OutputAllocatorImpl *getAllocator(size_t index);
 
   std::vector<std::unique_ptr<OutputAllocatorImpl>> mAllocators;
 };
 
-OutputAllocator::OutputAllocator(int64_t nbResults) {
+OutputAllocator::OutputAllocator(AllocTracker *tracker, int64_t nbResults) {
   mAllocators.reserve(nbResults);
   for (int64_t i = 0; i < nbResults; ++i) {
-    mAllocators.push_back(std::make_unique<OutputAllocatorImpl>());
+    mAllocators.push_back(std::make_unique<OutputAllocatorImpl>(tracker));
   }
 }
 
 void OutputAllocator::registerAllocator(size_t index, const char *name,
-                                        void *ptr, int64_t size,
+                                        uintptr_t ptr, int64_t size,
                                         nvinfer1::IExecutionContext *context) {
   assert(index >= 0 && index < mAllocators.size() && "Index out of bounds");
   mAllocators[index]->setTensorName(name);
@@ -751,11 +748,6 @@ void mlirtrt::runtime::registerExecutorTensorRTModuleLuaRuntimeMethods(
     return *ctx;
   };
 
-  lua["_trtrt_create_allocator"] =
-      [](int64_t nbResults) -> std::unique_ptr<OutputAllocator> {
-    return std::make_unique<OutputAllocator>(nbResults);
-  };
-
   lua["_trtrt_enqueue"] =
       [allocTracker,
        resourceTracker](sol::this_state state,
diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py
index 228019d7d..4e9b9ceff 100644
--- a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py
+++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py
@@ -26,7 +26,7 @@ def stablehlo_add(use_non_dps=False, debug=False):
         "--tensorrt-strongly-typed=false",
     ]
     if use_non_dps:
-        c_opts.append("--use-non-dps-call-conv")
+        c_opts.append("--enable-non-dps-returns")
     if debug:
         c_opts.append("--debug=true")
         c_opts.append(f"--mlir-print-ir-tree-dir=mlir-dumps-add-no-clone")
diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py
index 0429e20f1..364156667 100644
--- a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py
+++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py
@@ -89,21 +89,23 @@ def infer_output_shape(client, session, exe, input_shape):
     return output_shape
 
 
-def test_program(program: str, input_shape: Iterable[int], debug: bool = True):
+def test_program(
+    program: str, input_shape: Iterable[int], use_non_dps=False, debug: bool = False
+):
     # Build/parse the main function.
     with ir.Context() as context:
         m = ir.Module.parse(program)
 
     # Use the compiler API to compile to executable.
     client = compiler.CompilerClient(context)
-    opts = compiler.StableHLOToExecutableOptions(
-        client,
-        [
-            "--tensorrt-builder-opt-level=3",
-            "--tensorrt-strongly-typed=false",
-            "--entrypoint=main",
-        ],
-    )
+    c_opts = [
+        "--tensorrt-builder-opt-level=3",
+        "--tensorrt-strongly-typed=false",
+        "--entrypoint=main",
+    ]
+    if use_non_dps:
+        c_opts.append("--enable-non-dps-returns")
+    opts = compiler.StableHLOToExecutableOptions(client, c_opts)
     if debug:
         opts.set_debug_options(False, [], "tmp")
     exe = compiler.compiler_stablehlo_to_executable(client, m.operation, opts)
@@ -129,29 +131,66 @@ def test_program(program: str, input_shape: Iterable[int], debug: bool = True):
         np.ones(input_shape, dtype=np.float32).data, device=devices[0], stream=stream
     )
 
-    output_shape = infer_output_shape(client, session, exe, input_shape)
-
-    arg2 = client.create_memref(
-        np.zeros(output_shape, dtype=np.float32).data,
-        device=devices[0],
-        stream=stream,
-    )
-
-    session.execute_function(
-        "main", in_args=[arg0, arg1], out_args=[arg2], stream=stream, client=client
-    )
-    data = np.asarray(client.copy_to_host(arg2, stream=stream))
+    result = None
+    if use_non_dps:
+        results = session.execute_function(
+            "main", in_args=[arg0, arg1], stream=stream, client=client
+        )
+        result = results[0]
+    else:
+        output_shape = infer_output_shape(client, session, exe, input_shape)
+        result = client.create_memref(
+            np.zeros(output_shape, dtype=np.float32).data,
+            device=devices[0],
+            stream=stream,
+        )
+        session.execute_function(
+            "main",
+            in_args=[arg0, arg1],
+            out_args=[result],
+            stream=stream,
+            client=client,
+        )
+    data = np.asarray(client.copy_to_host(result, stream=stream))
     stream.sync()
 
     print(data)
 
 
 if __name__ == "__main__":
+    print("DPS style execution:")
     print("Test (3, ?, 2)")
     test_program(program1, (3, 4, 2))
     print("Test (?, 2)")
     test_program(program2, (4, 2))
+    print("Non DPS style execution:")
+    print("Test (3, ?, 2)")
+    test_program(program1, (3, 4, 2), use_non_dps=True)
+    print("Test (?, 2)")
+    test_program(program2, (4, 2), use_non_dps=True)
+
+# CHECK-LABEL: DPS style execution:
+# CHECK-LABEL: Test (3, ?, 2)
+# CHECK: [{{\[}}[2. 2.]
+# CHECK: [2. 2.]
+# CHECK: [2. 2.]
+# CHECK: [2. 2.]]
+# CHECK: {{\[}}[2. 2.]
+# CHECK: [2. 2.]
+# CHECK: [2. 2.]
+# CHECK: [2. 2.]]
+# CHECK: {{\[}}[2. 2.]
+# CHECK: [2. 2.]
+# CHECK: [2. 2.]
+# CHECK: [2. 2.]]]
+
+# CHECK-LABEL: Test (?, 2)
+# CHECK: {{\[}}[2. 2.]
+# CHECK: [2. 2.]
+# CHECK: [2. 2.]
+# CHECK: [2. 2.]]
+# CHECK-LABEL: Non DPS style execution:
 # CHECK-LABEL: Test (3, ?, 2)
 # CHECK: [{{\[}}[2. 2.]
 # CHECK: [2. 2.]
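Usage note (illustrative, not part of the patch): the integration-test changes above show the two calling conventions side by side. The sketch below condenses that pattern into one helper; it assumes `numpy` is imported as `np` and that `client`, `session`, `devices`, `stream`, and the input memrefs are set up exactly as in test_stablehlo_dynamic.py, and it uses a placeholder `(4, 2)` output shape where the test would call `infer_output_shape`.

import numpy as np

def run_main(session, client, devices, stream, arg0, arg1, use_non_dps=False):
    if use_non_dps:
        # Non-DPS path (compiled with --enable-non-dps-returns): the runtime
        # allocates the outputs via the output allocator and returns them, so
        # no out_args are passed.
        results = session.execute_function(
            "main", in_args=[arg0, arg1], stream=stream, client=client
        )
        result = results[0]
    else:
        # DPS path: the caller pre-allocates the output buffer and passes it
        # through out_args.
        result = client.create_memref(
            np.zeros((4, 2), dtype=np.float32).data,  # placeholder output shape
            device=devices[0],
            stream=stream,
        )
        session.execute_function(
            "main", in_args=[arg0, arg1], out_args=[result], stream=stream, client=client
        )
    data = np.asarray(client.copy_to_host(result, stream=stream))
    stream.sync()
    return data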
diff --git a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py
index 68ca684e8..4d379c3e6 100644
--- a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py
+++ b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py
@@ -49,7 +49,9 @@ def stablehlo_add():
         device=devices[0],
         stream=stream,
     )
-    session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)
+    session.execute_function(
+        "main", in_args=[arg0], out_args=[arg1], stream=stream, client=client
+    )
     data = np.asarray(client.copy_to_host(arg1, stream=stream))
     stream.sync()
 
diff --git a/tripy/tripy/backend/mlir/compiler.py b/tripy/tripy/backend/mlir/compiler.py
index ccd6eac45..d36ab0ed9 100644
--- a/tripy/tripy/backend/mlir/compiler.py
+++ b/tripy/tripy/backend/mlir/compiler.py
@@ -58,7 +58,7 @@ def _make_mlir_opts(self, trt_builder_opt_level):
             f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}",
             f"--tensorrt-builder-opt-level={trt_builder_opt_level}",
             "--tensorrt-strongly-typed=True",
-            "--use-non-dps-call-conv",
+            "--enable-non-dps-returns",
         ]
         if config.enable_mlir_debug or config.enable_tensorrt_debug:
            opts.append("--debug=true")
@@ -74,7 +74,9 @@ def compile_stabehlo_program(self, code: str) -> compiler.Executable:
         with self.mlir_context:
             module = ir.Module.parse(code)
             opts = self._make_mlir_opts(self.trt_builder_opt_level)
-            return compiler.compiler_stablehlo_to_executable(self.compiler_client, module.operation, opts)
+            return compiler.compiler_stablehlo_to_executable(
+                self.compiler_client, module.operation, opts
+            )
 
     @utils.log_time
     def infer_shapes(self, mlir_module: ir.Module, flat_ir: Optional["FlatIR"] = None):
@@ -94,8 +96,12 @@ def infer_shapes(self, mlir_module: ir.Module, flat_ir: Optional["FlatIR"] = Non
     # The optional flat_ir parameter is used to generate nicer error messages.
     @utils.log_time
-    def compile(self, mlir_module: ir.Module, flat_ir: Optional["FlatIR"] = None) -> compiler.Executable:
-        logger.mlir(lambda: f"{mlir_module.operation.get_asm(large_elements_limit=32)}\n")
+    def compile(
+        self, mlir_module: ir.Module, flat_ir: Optional["FlatIR"] = None
+    ) -> compiler.Executable:
+        logger.mlir(
+            lambda: f"{mlir_module.operation.get_asm(large_elements_limit=32)}\n"
+        )
         opts = self._make_mlir_opts(self.trt_builder_opt_level)
         try: