
Commit c51f4f1

Enable end-to-end non-DPS testing
Implement Python binding changes to allow the execute function to return multiple results. Update tests to use the non-DPS style calling convention. Also enable end-to-end lowering by supporting conversion of the closed alloc group op to the tensorrt dialect.

Miscellaneous fixes:
1. Add missing handling of `CallAllocOp` in the EliminateShapeOps pass.
2. Skip non-ranked-tensor-type function arguments while collecting host tensor arguments.
3. Temporarily add a pattern to remove no-op clone operations in the MemRefToExecutor dialect conversion.
4. Relax memref creation for empty shape tensors.
5. Fix the lifetime of memrefs returned from Lua function results; this required the session allocator to track returned memrefs.
1 parent 3b7a2af commit c51f4f1
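
For context on the new calling convention: with non-DPS ("alloc") returns, the caller no longer passes destination buffers; the runtime allocates the results and hands ownership back. Below is a minimal sketch of the calling sequence against the revised C API (declared in the Runtime.h diff below); session setup, input marshalling, and status checking are elided, and the surrounding driver code is assumed, not part of this commit:

    #include "mlir-executor-c/Runtime/Runtime.h"
    #include <vector>

    // Sketch: execute a function without destination (DPS) out-args and let
    // the runtime allocate the results. `session`, `name`, `inArgs`, and
    // `numInArgs` are assumed to be set up elsewhere.
    void runNonDPS(MTRT_RuntimeSession session, MTRT_StringView name,
                   const MTRT_RuntimeValue *inArgs, size_t numInArgs) {
      // Size the results array from the function signature.
      int64_t numResults = 0;
      mtrtRuntimeSessionGetNumResults(session, name, &numResults);

      // No outArgs: the session's allocator produces and tracks the results.
      std::vector<MTRT_RuntimeValue> results(numResults);
      mtrtRuntimeSessionExecuteFunction(
          session, name, inArgs, numInArgs,
          /*outArgs=*/nullptr, /*numOutArgs=*/0, results.data(),
          mtrtStreamGetNull(), mtrtRuntimeClientGetNull());
    }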

21 files changed: +623 −100 lines

mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp

Lines changed: 14 additions & 1 deletion
@@ -223,6 +223,10 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions(
       disallowHostTensorsInTensorRTClusters, llvm::cl::init(false),
       llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
                      "calculations (but they can still be inputs)"));
+  addOption(
+      "enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false),
+      llvm::cl::desc(
+          "allow tensorrt based output allocations using output allocator"));
   addOption("executor-index-bitwidth", executorIndexBitwidth,
             llvm::cl::init(64));
   addOption("device-compute-capability", deviceComputeCapability,

@@ -307,6 +311,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
   plan::StablehloClusteringPassOptions clusteringOpts{};
   clusteringOpts.disallowHostTensorsInTensorRTClusters =
       opts.disallowHostTensorsInTensorRTClusters;
+  clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
   clusteringOpts.entrypoint = opts.entrypoint;
   plan::buildPlanSegmentationPipeline(pm, clusteringOpts);

@@ -340,7 +345,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline(

   // Perform bufferization.
   pm.addPass(createMemRefCastEliminationPass());
-  pm.addPass(plan::createPlanAllocTensorsPass());
+  plan::PlanAllocTensorsPassOptions allocTensorsOpts{};
+  allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
+  pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts));
   pm.addPass(plan::createPlanBufferizePass());
   pm.addPass(createMemRefCastEliminationPass());
   pm.addPass(createCanonicalizerPass());

@@ -529,6 +536,11 @@ struct ClusteringPipelineCliOpts
       *this, "device-compute-capability",
       llvm::cl::desc("target device compute capability (SM version)"),
       llvm::cl::init(60)};
+  Option<bool> enableNonDPSReturns{
+      *this, "enable-non-dps-returns",
+      llvm::cl::desc(
+          "allow tensorrt based output allocations using output allocator"),
+      llvm::cl::init(false)};
   Option<int64_t> deviceMaxSharedMemoryPerBlockKb{
       *this, "device-max-smem-per-block",
       llvm::cl::desc("max shared memory per block (in kilobytes)"),

@@ -556,6 +568,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
   opts.deviceComputeCapability = cliOpts.deviceComputeCapability;
   opts.deviceMaxSharedMemoryPerBlockKb =
       cliOpts.deviceMaxSharedMemoryPerBlockKb;
+  opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns;
   opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost;
   opts.entrypoint = cliOpts.entrypoint;
   return opts;

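A note on how the new option reaches the pipelines: the hunks above thread enableNonDPSReturns from the CLI options struct into both the clustering and post-clustering pipeline builders. A hypothetical driver could opt in programmatically along these lines (a sketch; the exact builder signatures are inferred from the hunk headers above and the surrounding setup is assumed):

    // Sketch (assumed driver code): request allocator-based (non-DPS) returns
    // before building the StableHLO compilation pipelines.
    void buildPipelines(mlir::OpPassManager &pm,
                        StableHLOToExecutableOptions &opts) {
      opts.enableNonDPSReturns = true; // results come from the output allocator
      StableHloToExecutableTask::buildStablehloClusteringPipeline(pm, opts);
      StableHloToExecutableTask::buildPostClusteringPipeline(pm, opts);
    }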
mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp

Lines changed: 2 additions & 0 deletions
@@ -93,6 +93,8 @@ convertCallOp(Operation *op, IRRewriter &rewriter,
   SmallVector<int64_t> hostTensorArgs;
   for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) {
     const TensorKindLattice *kind = solver.lookupState<TensorKindLattice>(arg);
+    if (!isa<RankedTensorType>(arg.getType()))
+      continue;
     RankedTensorType rtt = cast<RankedTensorType>(arg.getType());
     // To be conservative, we only do this if type is i32 and num elements
     // <= 8.

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp

Lines changed: 30 additions & 12 deletions
@@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern<plan::WithValuesOp> {
 } // namespace

 /// Get a map from `tensorrt.func` functions to associated `tensorrt.call`
-/// operations.
-static llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>>
+/// and `tensorrt.call_alloc` operations.
+static llvm::DenseMap<func::FuncOp, SmallVector<Operation *>>
 getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
-  llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>> map;
-  op->walk([&](tensorrt::CallOp callOp) {
-    func::FuncOp func = callOp.getFuncCallee(collection);
-    if (map.contains(func)) {
-      map[func].push_back(callOp);
+  llvm::DenseMap<func::FuncOp, SmallVector<Operation *>> map;
+  op->walk([&](Operation *callOp) {
+    if (!isa<tensorrt::CallOp, tensorrt::CallAllocOp>(callOp))
       return;
-    }
-    map.insert(std::make_pair(func, SmallVector<tensorrt::CallOp>{callOp}));
+
+    func::FuncOp func;
+    if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
+      func = call.getFuncCallee(collection);
+    else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
+      func = callAlloc.getFuncCallee(collection);
+    else
+      return;
+
+    if (map.count(func))
+      map[func].push_back(callOp);
+    else
+      map.insert({func, SmallVector<Operation *>{callOp}});
   });
   return map;
 }

@@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
 /// `tensorrt.call` operations.
 static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
                                       ModuleOp op, func::FuncOp funcOp,
-                                      ArrayRef<tensorrt::CallOp> callOps) {
+                                      ArrayRef<Operation *> callOps) {
   llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0);
   for (BlockArgument arg : funcOp.getArguments()) {
     if (arg.use_empty())

@@ -99,8 +108,17 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
     funcOp.eraseArgument(i);

     // Update the call ops.
-    for (tensorrt::CallOp callOp : callOps)
-      callOp.getInputsMutable().erase(i);
+    for (Operation *callOp : callOps) {
+      if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
+        call.getInputsMutable().erase(i);
+      else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
+        callAlloc.getInputsMutable().erase(i);
+      else {
+        llvm::errs() << "Unexpected operation type in callOps\n";
+        callOp->dump();
+        return failure();
+      }
+    }
   }

   return success();

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp

Lines changed: 76 additions & 2 deletions
@@ -268,8 +268,82 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,

 static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
                                            plan::InlineClosedAllocGroupOp op) {
-  return op.emitError("outlinining inline closed alloc group ops to tensorrt "
-                      "dialect is not yet implemented");
+  tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
+  auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));
+  FailureOr<FunctionOpInterface> func = createOutlinedFunc(
+      rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
+      "cluster.tensorrt", TypeRange(op.getInputs()),
+      op.getYield()->getOperandTypes());
+  if (failed(func))
+    return failure();
+  assert(func->getFunctionBody().getBlocks().size() == 1 &&
+         "expected body with one block");
+  func->setPublic();
+
+  rewriter.setInsertionPoint(op);
+
+  auto callOp = rewriter.create<tensorrt::CallAllocOp>(
+      op.getLoc(), op.getResultTypes(), op.getInputs(),
+      SymbolRefAttr::get(trtModuleOp.getNameAttr(),
+                         {FlatSymbolRefAttr::get(*func)}));
+
+  // Populate the function arguments attributes.
+  for (unsigned i = 0; i < (*func).getNumArguments(); i++) {
+    BoundsAttr srcAttr = cast<BoundsAttr>(op.getInputAttrs()[i]);
+    // We may have scalar (index|signless int)-typed values since we haven't
+    // eliminated `plan.(with_shape|with_values)` ops yet.
+    if (!op.argHasTensorType(i) || srcAttr.isNone())
+      continue;
+    FailureOr<tensorrt::ShapeProfileAttr> boundAttr =
+        getTensorRTShapeProfile(srcAttr, op.getInputs()[i]);
+    if (failed(boundAttr))
+      return op->emitOpError("failed to create TensorRT shape profile "
+                             "attribute from Plan BoundsAttr for argument #")
+             << i << " (" << srcAttr << ")";
+    if (srcAttr.isShapeBound()) {
+      func->setArgAttr(i,
+                       tensorrt::TensorRTDialect::getShapeProfileArgAttrName(),
+                       *boundAttr);
+      continue;
+    }
+    assert(srcAttr.isValueBound() && "expected value bound or shape bound");
+    func->setArgAttr(
+        i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(),
+        *boundAttr);
+    func->setArgAttr(i, mlir::getHostTensorArgAttrName(),
+                     rewriter.getUnitAttr());
+  }
+
+  // Populate the function entry block.
+  rewriter.eraseBlock(&func->getFunctionBody().front());
+
+  // Move private decomposition funcs associated with all `stablehlo.composite`
+  // ops to the `tensorrt.module` op. This is needed since `tensorrt.module` op
+  // has its own symbol table.
+  SymbolTableCollection symbolTable;
+  for (auto compositeOp : op.getBody().getOps<stablehlo::CompositeOp>()) {
+    auto decompositionFunc = dyn_cast_if_present<func::FuncOp>(
+        symbolTable.lookupSymbolIn(op->getParentOfType<ModuleOp>(),
+                                   compositeOp.getDecompositionAttr()));
+    if (!decompositionFunc)
+      return emitError(compositeOp.getLoc())
+             << "failed to lookup stablehlo.composite decomposition "
+                "function: "
+             << compositeOp.getDecompositionAttr();
+    rewriter.moveOpAfter(decompositionFunc, func->getOperation());
+  }
+
+  // Move region op operations to the func body.
+  Operation *regionYieldOp = op.getYield();
+  rewriter.inlineRegionBefore(op.getRegion(), func->getFunctionBody(),
+                              func->getFunctionBody().end());
+  rewriter.setInsertionPoint(regionYieldOp);
+  rewriter.replaceOpWithNewOp<func::ReturnOp>(regionYieldOp,
+                                              regionYieldOp->getOperands());
+
+  // replace the original region results.
+  rewriter.replaceOp(op, callOp);
+  return success();
 }

 /// Create outlined functions for each `scf.execute_region` operation within

mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h

Lines changed: 25 additions & 5 deletions
@@ -215,6 +215,11 @@ static inline bool mtrtRuntimeClientIsNull(MTRT_RuntimeClient client) {
   return !client.ptr;
 }

+/// Returns null client.
+static inline MTRT_RuntimeClient mtrtRuntimeClientGetNull() {
+  return MTRT_RuntimeClient{nullptr};
+}
+
 /// Creates a `MTRT_RuntimeClient`. Client must be alive for the lifetime of the
 /// program execution.
 /// The `stream` passed to the client is used by all underlying CUDA methods

@@ -308,6 +313,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) {
   return !value.ptr;
 }

+// Returns whether the RuntimeValue is MemRef.
+MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value);
+
+// Returns whether the RuntimeValue is Scalar.
+MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value);
+
 /// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue.
 MLIR_CAPI_EXPORTED MTRT_RuntimeValue
 mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref);

@@ -391,16 +402,25 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) {
   return !session.ptr;
 }

-/// Using `session`, execute the pubic function with the specified name.
-/// The `inArgs` and `outArgs` are arrays for input arguments and destination
-/// arguments, respectively. Input arguments may be MemRefs or scalars, but
-/// destination arguments must be MemRefs.
+/// Using `session`, execute the public function with the specified name.
+/// The `inArgs`, `outArgs`, and `results` are arrays for input arguments,
+/// output arguments, and return values, respectively. Arguments and results
+/// can be MemRefs, scalars, or other supported types. Both `outArgs` and
+/// `results` can be used simultaneously, allowing for functions that both
+/// modify arguments and return values.
 /// A stream may optionally be specified, otherwise pass the result of
 /// `mtrtStreamGetNull()`.
+///
+/// The `results` array must point to an array with at least the number of
+/// elements returned by mtrtRuntimeSessionGetNumResults for the given function.
 MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction(
     MTRT_RuntimeSession session, MTRT_StringView name,
     const MTRT_RuntimeValue *inArgs, size_t numInArgs,
-    const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream);
+    const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
+    MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client);
+
+MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNumResults(
+    MTRT_RuntimeSession session, MTRT_StringView name, int64_t *numResults);

 //===----------------------------------------------------------------------===//
 // DLPack

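Taken together, the header now lets a caller size the results array via mtrtRuntimeSessionGetNumResults and classify what comes back. A hedged sketch of consuming returned values (`results` and `numResults` come from a prior execute call; the memref dyn-cast is an assumed dual of the scalar one that appears in the Runtime.cpp diff below):

    // Sketch: classify values returned by mtrtRuntimeSessionExecuteFunction.
    for (int64_t i = 0; i < numResults; ++i) {
      MTRT_RuntimeValue v = results[i];
      if (mtrtRuntimeValueIsMemRef(v)) {
        // e.g. MTRT_MemRefValue m = mtrtRuntimeValueDynCastToMemRef(v); (assumed)
      } else if (mtrtRuntimeValueIsScalar(v)) {
        MTRT_ScalarValue s = mtrtRuntimeValueDynCastToScalar(v);
        // ... consume the scalar ...
      }
    }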
mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h

Lines changed: 0 additions & 7 deletions
@@ -104,13 +104,6 @@ executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name,
                               std::optional<CudaStream> stream = {},
                               std::optional<RuntimeClient *> client = {});

-// Parses the results of a function call, handling both scalar and MemRef return
-// types
-StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
-parseResults(const sol::protected_function_result &pfr,
-             const FunctionSignatureView &sig,
-             std::optional<RuntimeClient *> client);
-
 } // namespace mlirtrt::runtime

 #endif // MLIR_TENSORRT_RUNTIME_BACKEND_LUA_LUARUNTIME_H

mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp

Lines changed: 34 additions & 6 deletions
@@ -675,6 +675,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) {
   return wrap(static_cast<ScalarValue *>(x));
 }

+bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) {
+  RuntimeValue *x = unwrap(value);
+  return x->getKind() == RuntimeValue::Kind::MemRef;
+}
+
+bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) {
+  RuntimeValue *x = unwrap(value);
+  return x->getKind() == RuntimeValue::Kind::Scalar;
+}
+
 //===----------------------------------------------------------------------===//
 // MTRT_RuntimeSessionOptions
 //===----------------------------------------------------------------------===//

@@ -721,7 +731,8 @@ MTRT_Status mtrtRuntimeSessionDestroy(MTRT_RuntimeSession session) {
 MTRT_Status mtrtRuntimeSessionExecuteFunction(
     MTRT_RuntimeSession session, MTRT_StringView name,
     const MTRT_RuntimeValue *inArgs, size_t numInArgs,
-    const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream) {
+    const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
+    MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client) {
   LuaRuntimeSession *cppSession =
       static_cast<LuaRuntimeSession *>(unwrap(session));

@@ -731,19 +742,36 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction(
   llvm::SmallVector<RuntimeValue *> outArgValues =
       llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs),
                           [](MTRT_RuntimeValue arg) { return unwrap(arg); });
-
-  StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> result =
+  StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> resultValues =
       executeFunctionWithLuaBackend(
           *cppSession, std::string_view(name.data, name.length), inArgValues,
           outArgValues,
           !mtrtStreamIsNull(stream)
              ? std::optional(unwrap(stream)->getRawStream())
-              : std::nullopt);
-  if (!result.isOk())
-    return wrap(result.getStatus());
+              : std::nullopt,
+          !mtrtRuntimeClientIsNull(client) ? std::optional(unwrap(client))
+                                           : std::nullopt);
+  if (!resultValues.isOk())
+    return wrap(resultValues.getStatus());
+
+  for (size_t i = 0; i < resultValues->size(); ++i)
+    results[i] = wrap((*resultValues)[i].release());

   return mtrtStatusGetOk();
 }
+
+MTRT_Status mtrtRuntimeSessionGetNumResults(MTRT_RuntimeSession session,
+                                            MTRT_StringView name,
+                                            int64_t *numResults) {
+  LuaRuntimeSession *cppSession =
+      static_cast<LuaRuntimeSession *>(unwrap(session));
+  *numResults = cppSession->getExecutable()
+                    .getFunction(std::string_view(name.data, name.length))
+                    .getSignature()
+                    .getNumResults();
+  return mtrtStatusGetOk();
+}

 //===----------------------------------------------------------------------===//
 // MTRT_RuntimeClient
 //===----------------------------------------------------------------------===//

mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp

Lines changed: 20 additions & 0 deletions
@@ -24,6 +24,7 @@
 #include "mlir-executor/Conversion/ConvertToExecutorCommon.h"
 #include "mlir-executor/Conversion/Passes.h"
 #include "mlir-executor/Executor/IR/Executor.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"

@@ -548,6 +549,21 @@ void executor::populateMemRefToExecutorPatterns(
 }

 namespace {
+
+class RemoveNoOpClonePattern : public OpRewritePattern<bufferization::CloneOp> {
+public:
+  using OpRewritePattern<bufferization::CloneOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(bufferization::CloneOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getInput().getType() == op.getOutput().getType()) {
+      rewriter.replaceOp(op, op.getInput());
+      return success();
+    }
+    return failure();
+  }
+};
+
 /// Pass to convert `memref` to `executor` dialect operrations.
 class ConvertMemRefToExecutorPass
     : public mlir::executor::impl::ConvertMemRefToExecutorPassBase<

@@ -579,6 +595,10 @@ class ConvertMemRefToExecutorPass
     RewritePatternSet patterns(ctx);
     executor::populateMemRefToExecutorPatterns(
         patterns, typeConverter, allowUncheckedMemrefCastConversion);
+
+    // Remove unrealized cast and redundant clone operations.
+    patterns.add<RemoveNoOpClonePattern>(ctx);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();

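As an aside on the temporary clone workaround above: the commit registers the pattern into the same set as the conversion patterns, but the same cleanup could be exercised in isolation with MLIR's greedy rewrite driver. A sketch under that assumption (module and context setup elided):

    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Sketch: run only the no-op clone cleanup over a module, outside the
    // MemRef-to-Executor conversion (names as in the diff above).
    LogicalResult cleanupNoOpClones(ModuleOp module) {
      MLIRContext *ctx = module.getContext();
      RewritePatternSet patterns(ctx);
      patterns.add<RemoveNoOpClonePattern>(ctx);
      return applyPatternsAndFoldGreedily(module, std::move(patterns));
    }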