diff --git a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp index c2bb73bcd..57fd29bb7 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp @@ -222,6 +222,10 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions( disallowHostTensorsInTensorRTClusters, llvm::cl::init(false), llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor " "calculations (but they can still be inputs)")); + addOption( + "enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false), + llvm::cl::desc( + "allow tensorrt based output allocations using output allocator")); addOption("executor-index-bitwidth", executorIndexBitwidth, llvm::cl::init(64)); addOption("device-compute-capability", deviceComputeCapability, @@ -306,6 +310,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline( plan::StablehloClusteringPassOptions clusteringOpts{}; clusteringOpts.disallowHostTensorsInTensorRTClusters = opts.disallowHostTensorsInTensorRTClusters; + clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns; clusteringOpts.entrypoint = opts.entrypoint; plan::buildPlanSegmentationPipeline(pm, clusteringOpts); @@ -339,7 +344,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline( // Perform bufferization. pm.addPass(createMemRefCastEliminationPass()); - pm.addPass(plan::createPlanAllocTensorsPass()); + plan::PlanAllocTensorsPassOptions allocTensorsOpts{}; + allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns; + pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts)); pm.addPass(plan::createPlanBufferizePass()); pm.addPass(createMemRefCastEliminationPass()); pm.addPass(createCanonicalizerPass()); @@ -525,6 +532,11 @@ struct ClusteringPipelineCliOpts *this, "device-compute-capability", llvm::cl::desc("target device compute capability (SM version)"), llvm::cl::init(60)}; + Option enableNonDPSReturns{ + *this, "enable-non-dps-returns", + llvm::cl::desc( + "allow tensorrt based output allocations using output allocator"), + llvm::cl::init(false)}; Option deviceMaxSharedMemoryPerBlockKb{ *this, "device-max-smem-per-block", llvm::cl::desc("max shared memory per block (in kilobytes)"), @@ -552,6 +564,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts( opts.deviceComputeCapability = cliOpts.deviceComputeCapability; opts.deviceMaxSharedMemoryPerBlockKb = cliOpts.deviceMaxSharedMemoryPerBlockKb; + opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns; opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost; opts.entrypoint = cliOpts.entrypoint; return opts; diff --git a/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp b/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp index f8382596f..52f838eac 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp +++ b/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp @@ -93,6 +93,8 @@ convertCallOp(Operation *op, IRRewriter &rewriter, SmallVector hostTensorArgs; for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) { const TensorKindLattice *kind = solver.lookupState(arg); + if (!isa(arg.getType())) + continue; RankedTensorType rtt = cast(arg.getType()); // To be conservative, we only do this if type is i32 and num elements // <= 8. @@ -144,6 +146,8 @@ class ConvertTensorRTToRuntimePass } IRRewriter rewriter(ctx); + SymbolTableCollection symbolTable; + DataFlowSolver solver; SmallVector trtModules = llvm::to_vector(module.getOps()); diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp index eb29494c9..393802270 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp @@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern { } // namespace /// Get a map from `tensorrt.func` functions to associated `tensorrt.call` -/// operations. -static llvm::DenseMap> +/// and `tensorrt.call_alloc` operations. +static llvm::DenseMap> getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) { - llvm::DenseMap> map; - op->walk([&](tensorrt::CallOp callOp) { - func::FuncOp func = callOp.getFuncCallee(collection); - if (map.contains(func)) { - map[func].push_back(callOp); + llvm::DenseMap> map; + op->walk([&](Operation *callOp) { + if (!isa(callOp)) return; - } - map.insert(std::make_pair(func, SmallVector{callOp})); + + func::FuncOp func; + if (auto call = dyn_cast(callOp)) + func = call.getFuncCallee(collection); + else if (auto callAlloc = dyn_cast(callOp)) + func = callAlloc.getFuncCallee(collection); + else + return; + + if (map.count(func)) + map[func].push_back(callOp); + else + map.insert({func, SmallVector{callOp}}); }); return map; } @@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) { /// `tensorrt.call` operations. static LogicalResult removeUnusedArgs(SymbolTableCollection &collection, ModuleOp op, func::FuncOp funcOp, - ArrayRef callOps) { + ArrayRef callOps) { llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0); for (BlockArgument arg : funcOp.getArguments()) { if (arg.use_empty()) @@ -99,8 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection, funcOp.eraseArgument(i); // Update the call ops. - for (tensorrt::CallOp callOp : callOps) - callOp.getInputsMutable().erase(i); + for (Operation *callOp : callOps) { + if (auto call = dyn_cast(callOp)) { + call.getInputsMutable().erase(i); + } else if (auto callAlloc = dyn_cast(callOp)) { + callAlloc.getInputsMutable().erase(i); + } else { + llvm::errs() << "Unexpected operation type in callOps\n"; + callOp->dump(); + } + } } return success(); diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp index d593393f6..0eae7f594 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp @@ -268,8 +268,82 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, plan::InlineClosedAllocGroupOp op) { - return op.emitError("outlinining inline closed alloc group ops to tensorrt " - "dialect is not yet implemented"); + tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op); + auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs())); + FailureOr func = createOutlinedFunc( + rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster", + "cluster.tensorrt", TypeRange(op.getInputs()), + op.getYield()->getOperandTypes()); + if (failed(func)) + return failure(); + assert(func->getFunctionBody().getBlocks().size() == 1 && + "expected body with one block"); + func->setPublic(); + + rewriter.setInsertionPoint(op); + + auto callOp = rewriter.create( + op.getLoc(), op.getResultTypes(), op.getInputs(), + SymbolRefAttr::get(trtModuleOp.getNameAttr(), + {FlatSymbolRefAttr::get(*func)})); + + // Populate the function arguments attributes. + for (unsigned i = 0; i < (*func).getNumArguments(); i++) { + BoundsAttr srcAttr = cast(op.getInputAttrs()[i]); + // We may have scalar (index|signless int)-typed values since we haven't + // eliminated `plan.(with_shape|with_values)` ops yet. + if (!op.argHasTensorType(i) || srcAttr.isNone()) + continue; + FailureOr boundAttr = + getTensorRTShapeProfile(srcAttr, op.getInputs()[i]); + if (failed(boundAttr)) + return op->emitOpError("failed to create TensorRT shape profile " + "attribute from Plan BoundsAttr for argument #") + << i << " (" << srcAttr << ")"; + if (srcAttr.isShapeBound()) { + func->setArgAttr(i, + tensorrt::TensorRTDialect::getShapeProfileArgAttrName(), + *boundAttr); + continue; + } + assert(srcAttr.isValueBound() && "expected value bound or shape bound"); + func->setArgAttr( + i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(), + *boundAttr); + func->setArgAttr(i, mlir::getHostTensorArgAttrName(), + rewriter.getUnitAttr()); + } + + // Populate the function entry block. + rewriter.eraseBlock(&func->getFunctionBody().front()); + + // Move private decomposition funcs associated with all `stablehlo.composite` + // ops to the `tensorrt.module` op. This is needed since `tensorrt.module` op + // has its own symbol table. + SymbolTableCollection symbolTable; + for (auto compositeOp : op.getBody().getOps()) { + auto decompositionFunc = dyn_cast_if_present( + symbolTable.lookupSymbolIn(op->getParentOfType(), + compositeOp.getDecompositionAttr())); + if (!decompositionFunc) + return emitError(compositeOp.getLoc()) + << "failed to lookup stablehlo.composite decomposition " + "function: " + << compositeOp.getDecompositionAttr(); + rewriter.moveOpAfter(decompositionFunc, func->getOperation()); + } + + // Move region op operations to the func body. + Operation *regionYieldOp = op.getYield(); + rewriter.inlineRegionBefore(op.getRegion(), func->getFunctionBody(), + func->getFunctionBody().end()); + rewriter.setInsertionPoint(regionYieldOp); + rewriter.replaceOpWithNewOp(regionYieldOp, + regionYieldOp->getOperands()); + + // replace the original region results. + rewriter.replaceOp(op, callOp); + return success(); } /// Create outlined functions for each `scf.execute_region` operation within diff --git a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h index 807d3dc23..6bdf72ad6 100644 --- a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h +++ b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h @@ -215,6 +215,11 @@ static inline bool mtrtRuntimeClientIsNull(MTRT_RuntimeClient client) { return !client.ptr; } +/// Returns null client. +static inline MTRT_RuntimeClient mtrtRuntimeClientGetNull() { + return MTRT_RuntimeClient{nullptr}; +} + /// Creates a `MTRT_RuntimeClient`. Client must be alive for the lifetime of the /// program execution. /// The `stream` passed to the client is used by all underlying CUDA methods @@ -308,6 +313,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) { return !value.ptr; } +// Returns whether the RuntimeValue is MemRef. +MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value); + +// Returns whether the RuntimeValue is Scalar. +MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value); + /// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue. MLIR_CAPI_EXPORTED MTRT_RuntimeValue mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref); @@ -391,16 +402,25 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) { return !session.ptr; } -/// Using `session`, execute the pubic function with the specified name. -/// The `inArgs` and `outArgs` are arrays for input arguments and destination -/// arguments, respectively. Input arguments may be MemRefs or scalars, but -/// destination arguments must be MemRefs. +/// Using `session`, execute the public function with the specified name. +/// The `inArgs`, `outArgs`, and `results` are arrays for input arguments, +/// output arguments, and return values, respectively. Arguments and results +/// can be MemRefs, scalars, or other supported types. Both `outArgs` and +/// `results` can be used simultaneously, allowing for functions that both +/// modify arguments and return values. /// A stream may optionally be specified, otherwise pass the result of /// `mtrtStreamGetNull()`. +/// +/// The `results` array must point to an array with at least the number of +/// elements returned by mtrtRuntimeSessionGetNumResults for the given function. MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction( MTRT_RuntimeSession session, MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs, - const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream); + const MTRT_RuntimeValue *outArgs, size_t numOutArgs, + MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client); + +MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNumResults( + MTRT_RuntimeSession session, MTRT_StringView name, int64_t *numResults); //===----------------------------------------------------------------------===// // DLPack diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h index 88c616bc7..a58b1d022 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h @@ -104,13 +104,6 @@ executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name, std::optional stream = {}, std::optional client = {}); -// Parses the results of a function call, handling both scalar and MemRef return -// types -StatusOr>> -parseResults(const sol::protected_function_result &pfr, - const FunctionSignatureView &sig, - std::optional client); - } // namespace mlirtrt::runtime #endif // MLIR_TENSORRT_RUNTIME_BACKEND_LUA_LUARUNTIME_H diff --git a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp index f6320dce4..ab325a3ef 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp @@ -675,6 +675,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) { return wrap(static_cast(x)); } +bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) { + RuntimeValue *x = unwrap(value); + return x->getKind() == RuntimeValue::Kind::MemRef; +} + +bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) { + RuntimeValue *x = unwrap(value); + return x->getKind() == RuntimeValue::Kind::Scalar; +} + //===----------------------------------------------------------------------===// // MTRT_RuntimeSessionOptions //===----------------------------------------------------------------------===// @@ -721,7 +731,8 @@ MTRT_Status mtrtRuntimeSessionDestroy(MTRT_RuntimeSession session) { MTRT_Status mtrtRuntimeSessionExecuteFunction( MTRT_RuntimeSession session, MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs, - const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream) { + const MTRT_RuntimeValue *outArgs, size_t numOutArgs, + MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client) { LuaRuntimeSession *cppSession = static_cast(unwrap(session)); @@ -731,19 +742,36 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction( llvm::SmallVector outArgValues = llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs), [](MTRT_RuntimeValue arg) { return unwrap(arg); }); - - StatusOr>> result = + StatusOr>> resultValues = executeFunctionWithLuaBackend( *cppSession, std::string_view(name.data, name.length), inArgValues, outArgValues, !mtrtStreamIsNull(stream) ? std::optional(unwrap(stream)->getRawStream()) - : std::nullopt); - if (!result.isOk()) - return wrap(result.getStatus()); + : std::nullopt, + !mtrtRuntimeClientIsNull(client) ? std::optional(unwrap(client)) + : std::nullopt); + if (!resultValues.isOk()) + return wrap(resultValues.getStatus()); + + for (size_t i = 0; i < resultValues->size(); ++i) + results[i] = wrap((*resultValues)[i].release()); return mtrtStatusGetOk(); } + +MTRT_Status mtrtRuntimeSessionGetNumResults(MTRT_RuntimeSession session, + MTRT_StringView name, + int64_t *numResults) { + LuaRuntimeSession *cppSession = + static_cast(unwrap(session)); + *numResults = cppSession->getExecutable() + .getFunction(std::string_view(name.data, name.length)) + .getSignature() + .getNumResults(); + return mtrtStatusGetOk(); +} + //===----------------------------------------------------------------------===// // MTRT_RuntimeClient //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp b/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp index 90e4c9681..1670e1b56 100644 --- a/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp +++ b/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp @@ -24,11 +24,13 @@ #include "mlir-executor/Conversion/ConvertToExecutorCommon.h" #include "mlir-executor/Conversion/Passes.h" #include "mlir-executor/Executor/IR/Executor.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" @@ -548,6 +550,21 @@ void executor::populateMemRefToExecutorPatterns( } namespace { + +class RemoveNoOpClonePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(bufferization::CloneOp op, + PatternRewriter &rewriter) const override { + if (op.getInput().getType() == op.getOutput().getType()) { + rewriter.replaceOp(op, op.getInput()); + return success(); + } + return failure(); + } +}; + /// Pass to convert `memref` to `executor` dialect operrations. class ConvertMemRefToExecutorPass : public mlir::executor::impl::ConvertMemRefToExecutorPassBase< @@ -579,6 +596,10 @@ class ConvertMemRefToExecutorPass RewritePatternSet patterns(ctx); executor::populateMemRefToExecutorPatterns( patterns, typeConverter, allowUncheckedMemrefCastConversion); + + // Remove unrealized cast and redundant clone operations. + patterns.add(ctx); + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); diff --git a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp index 5ff748b5f..37f3236b5 100644 --- a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp @@ -367,6 +367,7 @@ RuntimeSession::RuntimeSession(RuntimeSessionOptions options, //===----------------------------------------------------------------------===// AllocTracker::~AllocTracker() { + MTRT_DBGF("Destroying alloc tracer 0x%p", static_cast(this)); MTRT_DBGF("checking %u allocations", map.size()); llvm::SmallVector ptrsToFree; ptrsToFree.reserve(map.size()); @@ -410,7 +411,7 @@ void AllocTracker::incrementExternalCount(uintptr_t ptr) { llvm::formatv("Untracked pointer {0}", ptr).str().c_str()); std::unique_ptr const &metadata = map.at(ptr); int32_t ref = ++metadata->externalReferenceCount; - MTRT_DBG("Incremented external reference for pointer %d to %d", ptr, ref); + MTRT_DBGF("Incremented external reference for 0x%lu to %d", ptr, ref); } void AllocTracker::decrementExternalCount(uintptr_t ptr) { @@ -422,11 +423,12 @@ void AllocTracker::decrementExternalCount(uintptr_t ptr) { llvm::formatv("External reference count cannot be negative: {0}", ref) .str() .c_str()); - MTRT_DBG("Decremented external reference for pointer %d to %d", ptr, ref); + MTRT_DBGF("Decremented external reference for pointer 0x%lu to %d", ptr, ref); if (ref == 0 && metadata->releasedInternally) { - MTRT_DBG("External reference to an internally released pointer %d is 0, " - "try deallocating pointer memory of size %lu", - ptr, ref, metadata->info.size); + MTRT_DBGF( + "External reference to an internally released pointer 0x%lu is 0, " + "try deallocating pointer memory of size %lu", + ptr, metadata->info.size); Status s = safeDeallocate(*this, metadata->info.ptr); if (!s.isOk()) MTRT_DBGF("error while deallocating dangling memory: %s", @@ -452,9 +454,11 @@ void AllocTracker::track(PointerInfo info) { assert((!contains(info.ptr) || get(info.ptr).isExternallyManaged()) && "an internally managed pointer should not already be tracked"); } - MTRT_DBGF("AllocTracker is now tracking 0x%lx size=%lu space=%s ownership=%s", - info.ptr, info.size, runtime::impl::EnumNamePointerType(info.type), - runtime::impl::EnumNamePointerOwner(info.owner)); + MTRT_DBGF( + "AllocTracker {%p} is now tracking 0x%lx size=%lu space=%s ownership=%s", + static_cast(this), info.ptr, info.size, + runtime::impl::EnumNamePointerType(info.type), + runtime::impl::EnumNamePointerOwner(info.owner)); auto value = std::make_unique(); value->externalReferenceCount.store(0); value->releasedInternally = false; @@ -574,7 +578,7 @@ mlirtrt::Status runtime::safeDeallocate(AllocTracker &tracker, uintptr_t ptr, PointerInfo obj = tracker.get(ptr); if (obj.owner == PointerOwner::external) { - MTRT_DBGF("Untracking externally managed pointer 0x%lx", ptr); + MTRT_DBGF("Untracking externally managed 0x%lu", ptr); tracker.untrack(obj.ptr); return mlirtrt::Status::getOk(); } diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp index aea894610..5bc2eb2ed 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp @@ -567,10 +567,10 @@ getScalarValue(const sol::protected_function_result &pfr, int index, // Parses the results of a function call, handling both scalar and MemRef return // types -StatusOr>> -runtime::parseResults(const sol::protected_function_result &pfr, - const FunctionSignatureView &sig, - std::optional client) { +static StatusOr>> parseResults( + const sol::protected_function_result &pfr, const FunctionSignatureView &sig, + AllocTracker &sessionAllocTracker, std::optional client, + int32_t device, std::optional stream) { llvm::SmallVector> results; for (unsigned i = 0; i < sig.getNumResults(); ++i) { @@ -609,17 +609,49 @@ runtime::parseResults(const sol::protected_function_result &pfr, return getInvalidArgStatus("Runtime client cannot be nullptr"); // Create MemRefValue from extracted data - auto memref = (*client)->createExternalMemRef( - resultView.getAddressSpace(), resultView.getElementType().getBitWidth(), - allocPtr, offset, shape, strides, (*client)->getDevices()[0].get(), - resultView.getElementType()); - - if (!memref.isOk()) - return memref.getStatus(); + MTRT_DBGF("Create a memref value wrapping internally allocted ptr %p", + reinterpret_cast(allocPtr)); + + // Ensure that session registered alloction tracker tracks this pointer + // externally. + sessionAllocTracker.incrementExternalCount(allocPtr); + MTRT_DBGF("Session alloc tracker 0x%p", + static_cast(&sessionAllocTracker)); + + // Increment external reference count this since the ptr get returned. This + // pointer must be managed externally. + auto &ClientAlloctracker = (*client)->getAllocTracker(); + MTRT_DBGF("client 0x%p, Alloc tracker 0x%p", static_cast(*client), + static_cast(&ClientAlloctracker)); + + std::unique_ptr memref; + if (allocPtr) { + // Create the descriptor. + memref = std::move(*(MemRefValue::create( + *client, resultView.getAddressSpace(), + resultView.getElementType().getBitWidth(), allocPtr, offset, shape, + strides, (*client)->getDevices()[0].get(), + resultView.getElementType()))); + } else { + auto devicePtr = Device::create(device); + memref = std::move(*((*client)->allocateMemRef( + PointerType::device, resultView.getElementType().getBitWidth(), shape, + strides, devicePtr->get(), stream, resultView.getElementType()))); + ClientAlloctracker.track( + (memref)->getPointerInfo(PointerOwner::external)); + ClientAlloctracker.incrementExternalCount(memref->getMemory()); + + // Create a Pointer Info: + auto nullPointerInfo = (memref)->getPointerInfo(PointerOwner::external); + nullPointerInfo.ptr = allocPtr; + ClientAlloctracker.track(nullPointerInfo); + ClientAlloctracker.incrementExternalCount(allocPtr); + } + ClientAlloctracker.track((memref)->getPointerInfo(PointerOwner::external)); + ClientAlloctracker.incrementExternalCount(memref->getMemory()); - results.push_back(std::move(*memref)); + results.push_back(std::move(memref)); } - return results; } @@ -715,5 +747,7 @@ runtime::executeFunctionWithLuaBackend( "\": ", err.what()); } - return parseResults(pfr, sig, client); + int32_t device = session.getOptions().getDeviceId(); + + return parseResults(pfr, sig, tracker, client, device, stream); } diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp index b24355eb3..f845666b1 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp @@ -141,9 +141,11 @@ class OutputAllocatorImpl : public nvinfer1::IOutputAllocator { size = std::max(size, static_cast(1)); if (size > mOutputSize) { size = roundUp(size, alignment); - if (mOutputPtr) + if (mOutputPtr) { + MTRT_DBGF("OutputAllocator deallocating at %lu", mOutputPtr); mlirtrt::runtime::safeDeallocate(*mTracker, mOutputPtr, CudaStreamPtr(stream)); + } mOutputPtr = 0; mOutputSize = 0; StatusOr memory = @@ -152,6 +154,8 @@ class OutputAllocatorImpl : public nvinfer1::IOutputAllocator { if (memory.isOk()) { mOutputPtr = (*memory).ptr; mOutputSize = memory->size; + MTRT_DBGF("OutputAllocator allocating %lu bytes at %lu", mOutputSize, + mOutputPtr); } return reinterpret_cast(mOutputPtr); } diff --git a/mlir-tensorrt/executor/lib/Support/Allocators.cpp b/mlir-tensorrt/executor/lib/Support/Allocators.cpp index ce7310ad7..f65103972 100644 --- a/mlir-tensorrt/executor/lib/Support/Allocators.cpp +++ b/mlir-tensorrt/executor/lib/Support/Allocators.cpp @@ -296,4 +296,4 @@ Status PinnedMemoryAllocator::freeAsync(uintptr_t ptr, CudaStream stream) { return getInternalErrorStatus( "MLIR-Executor was not built with CUDA enabled"); #endif -} \ No newline at end of file +} diff --git a/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp b/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp index 9e8a9e50a..0c20d7da4 100644 --- a/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp +++ b/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp @@ -53,22 +53,31 @@ class ExecutorBufferizationTestPass } } }; + +struct PlanBufferizationPipelineCliOpts + : public PassPipelineOptions { + Option enableNonDPSReturns{ + *this, "enable-non-dps-returns", + llvm::cl::desc("allow backend clusters to directly allocate outputs"), + llvm::cl::init(false)}; +}; + } // namespace namespace mlir::executor { void registerTestExecutorBufferizePass() { PassRegistration(); - - PassPipelineRegistration<> executorBufferizationPipeline( - "test-executor-bufferization-pipeline", - "Run one-shot-bufferization and buffer deallocation pipelines", - [](OpPassManager &pm) { - pm.addPass(std::make_unique()); - pm.addPass(bufferization::createDropEquivalentBufferResultsPass()); - bufferization::BufferDeallocationPipelineOptions deallocOptions{}; - bufferization::buildBufferDeallocationPipeline(pm, deallocOptions); - pm.addPass(createCSEPass()); - pm.addPass(createCanonicalizerPass()); - }); + PassPipelineRegistration + executorBufferizationPipeline( + "test-executor-bufferization-pipeline", + "Run one-shot-bufferization and buffer deallocation pipelines", + [](OpPassManager &pm, const PlanBufferizationPipelineCliOpts &opts) { + pm.addPass(std::make_unique()); + pm.addPass(bufferization::createDropEquivalentBufferResultsPass()); + bufferization::BufferDeallocationPipelineOptions deallocOptions{}; + bufferization::buildBufferDeallocationPipeline(pm, deallocOptions); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + }); } } // namespace mlir::executor diff --git a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp index ddb0c74e2..1654549de 100644 --- a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp +++ b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp @@ -600,6 +600,15 @@ static MTRT_RuntimeValue convertArgType(py::object obj) { throw std::runtime_error("argument must be MemRef or scalar"); } +/// Convert Runtime value to PyMemRefValue or PyScalarValue object. +static py::object convertGenericArgToPyObject(MTRT_RuntimeValue value) { + if (mtrtRuntimeValueIsMemRef(value)) + return py::cast(mtrtRuntimeValueDynCastToMemRef(value)); + if (mtrtRuntimeValueIsScalar(value)) + return py::cast(mtrtRuntimeValueDynCastToScalar(value)); + return py::none(); +} + //===----------------------------------------------------------------------===// // Declare the bindings. //===----------------------------------------------------------------------===// @@ -950,22 +959,43 @@ PYBIND11_MODULE(_api, m) { .def( "execute_function", [](PyRuntimeSession &self, std::string name, - std::vector inArgs, std::vector outArgs, - std::optional stream) { + std::vector inArgs, + std::optional> outArgs, + std::optional stream, PyRuntimeClient &client) { MTRT_StringView nameRef{name.data(), name.size()}; + int64_t numResults; + MTRT_Status s = + mtrtRuntimeSessionGetNumResults(self, nameRef, &numResults); + THROW_IF_MTRT_ERROR(s); + auto inArgsGeneric = llvm::map_to_vector(inArgs, convertArgType); - auto outArgsGeneric = llvm::map_to_vector(outArgs, convertArgType); + auto outArgsGeneric = + outArgs ? llvm::map_to_vector(*outArgs, convertArgType) + : llvm::SmallVector{}; + + std::vector resultsGeneric(numResults); - MTRT_Status s = mtrtRuntimeSessionExecuteFunction( + s = mtrtRuntimeSessionExecuteFunction( self, nameRef, inArgsGeneric.data(), inArgsGeneric.size(), outArgsGeneric.data(), outArgsGeneric.size(), - stream ? *stream : mtrtStreamGetNull()); + resultsGeneric.data(), stream ? *stream : mtrtStreamGetNull(), + client); THROW_IF_MTRT_ERROR(s); - }, - py::arg("name"), py::arg("in_args"), py::arg("out_args"), - py::arg("stream") = py::none()); + std::vector resultPyObject; + if (numResults > 0) { + for (const auto &arg : resultsGeneric) + resultPyObject.push_back(convertGenericArgToPyObject(arg)); + } + + return resultPyObject; + }, + py::arg("name"), py::arg("in_args"), py::arg("out_args") = py::none(), + py::arg("stream") = py::none(), py::arg("client"), + "Execute a function given input and optional output arguments. " + "Return optional results as a Python object if output arguments are " + "not present."); py::class_(m, "GlobalDebug", py::module_local()) .def_property_static("flag", &PyGlobalDebugFlag::get, &PyGlobalDebugFlag::set, "LLVM-wide debug flag") @@ -977,4 +1007,4 @@ PYBIND11_MODULE(_api, m) { py::overload_cast &>( &PyGlobalDebugFlag::set_types), "Sets specific debug types to be produced by LLVM"); -} \ No newline at end of file +} diff --git a/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py b/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py index 480ce74d4..9535aef79 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py +++ b/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py @@ -36,7 +36,11 @@ def test_stablehlo_add( session = runtime.RuntimeSession(session_options, exe) session.execute_function( - "main", in_args=test.in_args, out_args=test.out_args, stream=stream + "main", + in_args=test.in_args, + out_args=test.out_args, + stream=stream, + client=runtime_client, ) output = [ ( diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py b/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py index 415767d70..ee3c784f8 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py +++ b/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py @@ -73,7 +73,11 @@ def execute(self, arg: runtime.RuntimeValue): session = runtime.RuntimeSession(self.session_options, self.exe) try: session.execute_function( - "main", in_args=[arg], out_args=[arg], stream=self.stream + "main", + in_args=[arg], + out_args=[arg], + stream=self.stream, + client=self.client, ) print("Test passed succesfully") except runtime.MTRTException as e: diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py b/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py index e4bb3cba5..c841a334e 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py +++ b/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py @@ -47,7 +47,9 @@ def test_serialize(ASM): device=devices[0], stream=stream, ) - session0.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream) + session0.execute_function( + "main", in_args=[arg0], out_args=[arg1], stream=stream, client=client + ) output0 = np.asarray(client.copy_to_host(arg1, stream=stream)) stream.sync() @@ -57,7 +59,9 @@ def test_serialize(ASM): exe_reconstructed = compiler.Executable(serialized_exe) session1 = runtime.RuntimeSession(session_options, exe_reconstructed) - session1.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream) + session1.execute_function( + "main", in_args=[arg0], out_args=[arg1], stream=stream, client=client + ) output1 = np.asarray(client.copy_to_host(arg1, stream=stream)) stream.sync() diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py index 2c95a3081..4e9b9ceff 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py +++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py @@ -14,17 +14,23 @@ """ -def stablehlo_add(): +def stablehlo_add(use_non_dps=False, debug=False): # Build/parse the main function. with ir.Context() as context: m = ir.Module.parse(ASM) # Use the compiler API to compile to executable. client = compiler.CompilerClient(context) - opts = compiler.StableHLOToExecutableOptions( - client, - ["--tensorrt-builder-opt-level=3", "--tensorrt-strongly-typed=false"], - ) + c_opts = [ + "--tensorrt-builder-opt-level=3", + "--tensorrt-strongly-typed=false", + ] + if use_non_dps: + c_opts.append("--enable-non-dps-returns") + if debug: + c_opts.append("--debug=true") + c_opts.append(f"--mlir-print-ir-tree-dir=mlir-dumps-add-no-clone") + opts = compiler.StableHLOToExecutableOptions(client, c_opts) exe = compiler.compiler_stablehlo_to_executable(client, m.operation, opts) # The RuntimeClient can and should persist across multiple Executables, RuntimeSessions, etc. @@ -44,36 +50,53 @@ def stablehlo_add(): device=devices[0], stream=stream, ) - arg1 = client.create_memref( - np.zeros(shape=(2, 3, 4), dtype=np.float32).data, - device=devices[0], - stream=stream, - ) - session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream) - data = np.asarray(client.copy_to_host(arg1, stream=stream)) + result = None + if use_non_dps: + results = session.execute_function( + "main", in_args=[arg0], stream=stream, client=client + ) + result = results[0] + else: + result = client.create_memref( + np.zeros(shape=(2, 3, 4), dtype=np.float32).data, + device=devices[0], + stream=stream, + ) + session.execute_function( + "main", in_args=[arg0], out_args=[result], stream=stream, client=client + ) + + data = np.asarray(client.copy_to_host(result, stream=stream)) stream.sync() print(data) - # Run execution a bunch more times asynchronously so that it calculates - # `x * 2**num_iter`. - num_iter = 5 - start_time = time.time() - for _ in range(0, num_iter): - session.execute_function("main", in_args=[arg0], out_args=[arg0], stream=stream) - data = np.asarray(client.copy_to_host(arg1, stream=stream)) - stream.sync() - end_time = time.time() - elapsed = end_time - start_time + if not use_non_dps: + # Run execution a bunch more times asynchronously so that it calculates + # `x * 2**num_iter`. + num_iter = 5 + start_time = time.time() + for _ in range(0, num_iter): + session.execute_function( + "main", in_args=[arg0], out_args=[arg0], stream=stream, client=client + ) + data = np.asarray(client.copy_to_host(arg0, stream=stream)) + stream.sync() + end_time = time.time() + elapsed = end_time - start_time - print(np.asarray(client.copy_to_host(arg0))) - print(f"1000 iterations avg { (elapsed/num_iter)/1000.0} msec per iteration") + print(np.asarray(client.copy_to_host(arg0))) + print(f"1000 iterations avg { (elapsed/num_iter)/1000.0} msec per iteration") if __name__ == "__main__": + print("DPS style execution:") stablehlo_add() + print("Non DPS style execution:") + stablehlo_add(use_non_dps=True) +# CHECK-LABEL: DPS style execution: # CHECK: [ 0. 2. 4. 6.] # CHECK-NEXT: [ 8. 10. 12. 14.] # CHECK-NEXT: [16. 18. 20. 22.]] @@ -88,3 +111,11 @@ def stablehlo_add(): # CHECK-NEXT: [384. 416. 448. 480.] # CHECK-NEXT: [512. 544. 576. 608.] # CHECK-NEXT: [640. 672. 704. 736.] +# CHECK-LABEL: DPS style execution: +# CHECK: [ 0. 2. 4. 6.] +# CHECK-NEXT: [ 8. 10. 12. 14.] +# CHECK-NEXT: [16. 18. 20. 22.]] +# CHECK-NEXT: +# CHECK-NEXT: [24. 26. 28. 30.] +# CHECK-NEXT: [32. 34. 36. 38.] +# CHECK-NEXT: [40. 42. 44. 46.]]] diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py index 35515e054..364156667 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py +++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py @@ -77,7 +77,10 @@ def infer_output_shape(client, session, exe, input_shape): outs = [client.create_memref(out_0, shape=shape, dtype=runtime.ScalarTypeCode.i64)] session.execute_function( - exe.get_signature("main").get_shape_func_name(), in_args=ins, out_args=outs + exe.get_signature("main").get_shape_func_name(), + in_args=ins, + out_args=outs, + client=client, ) # Copy output shape from device to host. Also, convert to int32 type since shape calculation uses int64 type. @@ -86,21 +89,23 @@ def infer_output_shape(client, session, exe, input_shape): return output_shape -def test_program(program: str, input_shape: Iterable[int], debug: bool = True): +def test_program( + program: str, input_shape: Iterable[int], use_non_dps=False, debug: bool = False +): # Build/parse the main function. with ir.Context() as context: m = ir.Module.parse(program) # Use the compiler API to compile to executable. client = compiler.CompilerClient(context) - opts = compiler.StableHLOToExecutableOptions( - client, - [ - "--tensorrt-builder-opt-level=3", - "--tensorrt-strongly-typed=false", - "--entrypoint=main", - ], - ) + c_opts = [ + "--tensorrt-builder-opt-level=3", + "--tensorrt-strongly-typed=false", + "--entrypoint=main", + ] + if use_non_dps: + c_opts.append("--enable-non-dps-returns") + opts = compiler.StableHLOToExecutableOptions(client, c_opts) if debug: opts.set_debug_options(False, [], "tmp") exe = compiler.compiler_stablehlo_to_executable(client, m.operation, opts) @@ -126,29 +131,66 @@ def test_program(program: str, input_shape: Iterable[int], debug: bool = True): np.ones(input_shape, dtype=np.float32).data, device=devices[0], stream=stream ) - output_shape = infer_output_shape(client, session, exe, input_shape) - - arg2 = client.create_memref( - np.zeros(output_shape, dtype=np.float32).data, - device=devices[0], - stream=stream, - ) - - session.execute_function( - "main", in_args=[arg0, arg1], out_args=[arg2], stream=stream - ) - data = np.asarray(client.copy_to_host(arg2, stream=stream)) + result = None + if use_non_dps: + results = session.execute_function( + "main", in_args=[arg0, arg1], stream=stream, client=client + ) + result = results[0] + else: + output_shape = infer_output_shape(client, session, exe, input_shape) + result = client.create_memref( + np.zeros(output_shape, dtype=np.float32).data, + device=devices[0], + stream=stream, + ) + session.execute_function( + "main", + in_args=[arg0, arg1], + out_args=[result], + stream=stream, + client=client, + ) + data = np.asarray(client.copy_to_host(result, stream=stream)) stream.sync() print(data) if __name__ == "__main__": + print("DPS style execution:") print("Test (3, ?, 2)") test_program(program1, (3, 4, 2)) print("Test (?, 2)") test_program(program2, (4, 2)) + print("Non DPS style execution:") + print("Test (3, ?, 2)") + test_program(program1, (3, 4, 2), use_non_dps=True) + print("Test (?, 2)") + test_program(program2, (4, 2), use_non_dps=True) + +# CHECK-LABEL: DPS style execution: +# CHECK-LABEL: Test (3, ?, 2) +# CHECK: [{{\[}}[2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.]] +# CHECK: {{\[}}[2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.]] +# CHECK: {{\[}}[2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.]]] + +# CHECK-LABEL: Test (?, 2) +# CHECK: {{\[}}[2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.]] +# CHECK-LABEL: Non DPS style execution: # CHECK-LABEL: Test (3, ?, 2) # CHECK: [{{\[}}[2. 2.] # CHECK: [2. 2.] diff --git a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py index 68ca684e8..4d379c3e6 100644 --- a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py +++ b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py @@ -49,7 +49,9 @@ def stablehlo_add(): device=devices[0], stream=stream, ) - session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream) + session.execute_function( + "main", in_args=[arg0], out_args=[arg1], stream=stream, client=client + ) data = np.asarray(client.copy_to_host(arg1, stream=stream)) stream.sync()