Enable end-to-end non-DPS testing
Implement Python binding changes to allow the execute function to return
multiple results. Update tests to use the non-DPS-style calling convention.

Also, enable end-to-end lowering by supporting conversion of the closed alloc
group op to the tensorrt dialect.

Miscellaneous fixes:
1. Add missing handling of `CallAllocOp` in the EliminateShapeOps pass.
2. Skip function arguments that are not ranked tensor types while collecting
   host tensor arguments.
3. Temporarily add a pass to remove clone operations in the MemRefToExecutor
   dialect conversion.
4. Relax memref creation for empty shape tensors.
5. Fix the lifetime of memrefs returned from Lua function results. This
   requires the session allocator to track returned memrefs.

Also addressed:
- Fix incorrect indexing into output memref results.
- Return an error status instead of failing silently during TensorRT weight
  conversion.
- Address review comments.
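For illustration, the difference between the two calling conventions as a
minimal sketch (the `Tensor` type and function names are stand-ins, not the
actual runtime API):

    #include <vector>

    struct Tensor {};  // stand-in for the runtime's tensor/memref type

    // DPS: the caller allocates every output and passes it in; the callee
    // only fills the provided buffers, so output shapes must be known
    // before the call.
    void executeDPS(const Tensor &input, Tensor &out) {}

    // Non-DPS: the callee allocates its outputs (here, via TensorRT's
    // output allocator) and returns them, which also supports result
    // shapes that are only known at run time.
    std::vector<Tensor> executeNonDPS(const Tensor &input) {
      return {Tensor{}, Tensor{}};  // multiple results returned directly
    }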
jhalakpatel committed Nov 9, 2024
1 parent 59d09a9 commit eed9be4
Showing 22 changed files with 888 additions and 216 deletions.
15 changes: 14 additions & 1 deletion mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
@@ -223,6 +223,10 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions(
       disallowHostTensorsInTensorRTClusters, llvm::cl::init(false),
       llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
                      "calculations (but they can still be inputs)"));
+  addOption(
+      "enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false),
+      llvm::cl::desc(
+          "allow tensorrt based output allocations using output allocator"));
   addOption("executor-index-bitwidth", executorIndexBitwidth,
             llvm::cl::init(64));
   addOption("device-compute-capability", deviceComputeCapability,
@@ -307,6 +311,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
   plan::StablehloClusteringPassOptions clusteringOpts{};
   clusteringOpts.disallowHostTensorsInTensorRTClusters =
       opts.disallowHostTensorsInTensorRTClusters;
+  clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
   clusteringOpts.entrypoint = opts.entrypoint;
   plan::buildPlanSegmentationPipeline(pm, clusteringOpts);

@@ -340,7 +345,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline(
 
   // Perform bufferization.
   pm.addPass(createMemRefCastEliminationPass());
-  pm.addPass(plan::createPlanAllocTensorsPass());
+  plan::PlanAllocTensorsPassOptions allocTensorsOpts{};
+  allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
+  pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts));
   pm.addPass(plan::createPlanBufferizePass());
   pm.addPass(createMemRefCastEliminationPass());
   pm.addPass(createCanonicalizerPass());
@@ -529,6 +536,11 @@ struct ClusteringPipelineCliOpts
       *this, "device-compute-capability",
       llvm::cl::desc("target device compute capability (SM version)"),
       llvm::cl::init(60)};
+  Option<bool> enableNonDPSReturns{
+      *this, "enable-non-dps-returns",
+      llvm::cl::desc(
+          "allow tensorrt based output allocations using output allocator"),
+      llvm::cl::init(false)};
   Option<int64_t> deviceMaxSharedMemoryPerBlockKb{
       *this, "device-max-smem-per-block",
       llvm::cl::desc("max shared memory per block (in kilobytes)"),
@@ -556,6 +568,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
   opts.deviceComputeCapability = cliOpts.deviceComputeCapability;
   opts.deviceMaxSharedMemoryPerBlockKb =
       cliOpts.deviceMaxSharedMemoryPerBlockKb;
+  opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns;
   opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost;
   opts.entrypoint = cliOpts.entrypoint;
   return opts;
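As the hunks above show, the new flag threads straight from the CLI surface
into the clustering and alloc-tensors pass options. A sketch of enabling it
programmatically; the field names come from the diff above, but the helper
and the default construction of the options struct are assumptions:

    // Hypothetical helper, illustrative only.
    StableHLOToExecutableOptions makeNonDPSOptions() {
      StableHLOToExecutableOptions opts;
      opts.enableNonDPSReturns = true;  // TensorRT allocates outputs itself
      opts.entrypoint = "main";         // hypothetical entrypoint value
      return opts;
    }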
@@ -379,20 +379,22 @@ struct ConvertEnqueueAllocToCall

     // Create output memrefs from output descriptors
     SmallVector<Value> results;
+    // Initialize output descriptor offset to skip number of results.
+    // `outputDescOffset` is used to retrieve rank, ptr, shapes, and strides per
+    // result.
+    outputDescOffset = 1;
     for (unsigned i = 0; i < op->getNumResults(); ++i) {
       unsigned rank = cast<MemRefType>(op->getResult(i).getType()).getRank();
-      unsigned offset =
-          1 +
-          i * (2 * rank + 2); // num res, (i * (rank, ptr, [shape], [stride]))
-
       Value rankOffset = b.create<executor::GetOffsetOp>(
           b.getI64Type(), structType,
-          ArrayRef<OpFoldResult>{this->createIndexConstant(b, 0),
-                                 rewriter.getI64IntegerAttr(offset++)});
+          ArrayRef<OpFoldResult>{
+              this->createIndexConstant(b, 0),
+              rewriter.getI64IntegerAttr(outputDescOffset++)});
       Value devicePtrOffset = b.create<executor::GetOffsetOp>(
           b.getI64Type(), structType,
-          ArrayRef<OpFoldResult>{this->createIndexConstant(b, 0),
-                                 rewriter.getI64IntegerAttr(offset++)});
+          ArrayRef<OpFoldResult>{
+              this->createIndexConstant(b, 0),
+              rewriter.getI64IntegerAttr(outputDescOffset++)});
 
       [[maybe_unused]] Value rankValue = b.create<executor::LoadOp>(
           b.getI64Type(), outputDescriptors, rankOffset);
@@ -406,8 +408,9 @@ struct ConvertEnqueueAllocToCall
       for (unsigned r = 0; r < rank; ++r) {
         Value shapeOffset = b.create<executor::GetOffsetOp>(
             b.getI64Type(), structType,
-            ArrayRef<OpFoldResult>{this->createIndexConstant(b, 0),
-                                   rewriter.getI64IntegerAttr(offset++)});
+            ArrayRef<OpFoldResult>{
+                this->createIndexConstant(b, 0),
+                rewriter.getI64IntegerAttr(outputDescOffset++)});
         Value shape = b.create<executor::LoadOp>(
             b.getI64Type(), outputDescriptors, shapeOffset);
         shapes.push_back(shape);
@@ -416,8 +419,9 @@ struct ConvertEnqueueAllocToCall
       for (unsigned r = 0; r < rank; ++r) {
         Value strideOffset = b.create<executor::GetOffsetOp>(
             b.getI64Type(), structType,
-            ArrayRef<OpFoldResult>{this->createIndexConstant(b, 0),
-                                   rewriter.getI64IntegerAttr(offset++)});
+            ArrayRef<OpFoldResult>{
+                this->createIndexConstant(b, 0),
+                rewriter.getI64IntegerAttr(outputDescOffset++)});
         Value shape = b.create<executor::LoadOp>(
             b.getI64Type(), outputDescriptors, strideOffset);
         shapes.push_back(shape);
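The switch from the closed-form `offset = 1 + i * (2 * rank + 2)` to a single
running `outputDescOffset` is what fixes the incorrect output indexing noted
in the commit message: the closed form used the current result's rank for all
preceding results, which breaks as soon as results have different ranks. A
plain-C++ sketch of the descriptor layout and the running-cursor walk (the
struct and buffer are illustrative; the real descriptor lives in an executor
table):

    #include <cstdint>
    #include <vector>

    // Flattened layout: [numResults, then per result:
    //                    rank, ptr, shape[0..rank), stride[0..rank)].
    struct MemRefDesc {
      int64_t rank = 0, ptr = 0;
      std::vector<int64_t> shape, stride;
    };

    std::vector<MemRefDesc> readDescriptors(const std::vector<int64_t> &buf) {
      std::vector<MemRefDesc> out;
      size_t cursor = 1;  // skip numResults, like `outputDescOffset = 1`
      for (int64_t i = 0; i < buf[0]; ++i) {
        MemRefDesc d;
        d.rank = buf[cursor++];
        d.ptr = buf[cursor++];
        for (int64_t r = 0; r < d.rank; ++r) d.shape.push_back(buf[cursor++]);
        for (int64_t r = 0; r < d.rank; ++r) d.stride.push_back(buf[cursor++]);
        out.push_back(d);  // cursor advanced by exactly 2 + 2 * d.rank
      }
      return out;
    }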
@@ -37,6 +37,7 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanTransforms
   MLIRTensorRTStablehloScalarToArith
   MLIRTensorRTStablehloToTensorRT
   MLIRTensorRTTensorRTRuntimeDialect
+  MLIRBufferizationToMemRef
   MLIRTransforms
   StablehloOps
 )
@@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern<plan::WithValuesOp> {
 } // namespace
 
 /// Get a map from `tensorrt.func` functions to associated `tensorrt.call`
-/// operations.
-static llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>>
+/// and `tensorrt.call_alloc` operations.
+static llvm::DenseMap<func::FuncOp, SmallVector<Operation *>>
 getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
-  llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>> map;
-  op->walk([&](tensorrt::CallOp callOp) {
-    func::FuncOp func = callOp.getFuncCallee(collection);
-    if (map.contains(func)) {
-      map[func].push_back(callOp);
-      return;
-    }
-    map.insert(std::make_pair(func, SmallVector<tensorrt::CallOp>{callOp}));
+  llvm::DenseMap<func::FuncOp, SmallVector<Operation *>> map;
+  op->walk([&](Operation *callOp) {
+    if (!isa<tensorrt::CallOp, tensorrt::CallAllocOp>(callOp))
+      return;
+
+    func::FuncOp func;
+    if (auto call = dyn_cast<tensorrt::CallOp>(callOp)) {
+      func = call.getFuncCallee(collection);
+    } else {
+      auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp);
+      func = callAlloc.getFuncCallee(collection);
+    }
+
+    if (map.count(func))
+      map[func].push_back(callOp);
+    else
+      map.insert({func, SmallVector<Operation *>{callOp}});
   });
   return map;
 }
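The same two-way dispatch could be written more compactly with
`llvm::TypeSwitch` (from `llvm/ADT/TypeSwitch.h`), and `DenseMap::operator[]`
already default-constructs missing entries, so the count/insert split is
avoidable. An alternative sketch of the walk body, not what the commit does:

    // Inside op->walk([&](Operation *callOp) { ... }):
    func::FuncOp func =
        llvm::TypeSwitch<Operation *, func::FuncOp>(callOp)
            .Case([&](tensorrt::CallOp call) {
              return call.getFuncCallee(collection);
            })
            .Case([&](tensorrt::CallAllocOp callAlloc) {
              return callAlloc.getFuncCallee(collection);
            })
            .Default([](Operation *) { return func::FuncOp(); });
    if (!func)
      return;
    map[func].push_back(callOp);  // operator[] default-constructs the vector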
@@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
 /// `tensorrt.call` operations.
 static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
                                       ModuleOp op, func::FuncOp funcOp,
-                                      ArrayRef<tensorrt::CallOp> callOps) {
+                                      ArrayRef<Operation *> callOps) {
   llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0);
   for (BlockArgument arg : funcOp.getArguments()) {
     if (arg.use_empty())
@@ -99,10 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
     funcOp.eraseArgument(i);
 
     // Update the call ops.
-    for (tensorrt::CallOp callOp : callOps)
-      callOp.getInputsMutable().erase(i);
+    for (Operation *callOp : callOps) {
+      if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
+        call.getInputsMutable().erase(i);
+      else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
+        callAlloc.getInputsMutable().erase(i);
+      else
+        return emitError(funcOp->getLoc())
+               << "Unexpected operation type in callOps";
+    }
   }
 
   return success();
 }
145 changes: 87 additions & 58 deletions mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp
@@ -156,28 +156,32 @@ getTensorRTShapeProfile(plan::BoundsAttr attr, Value v)
   return getProfileAttr(attr.getMinShape(), attr.getMaxShape());
 }
 
-static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
-                                           plan::InlineClosedGroupOp op) {
-  tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
-  auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));
-  FailureOr<FunctionOpInterface> func = createOutlinedFunc(
-      rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
-      "cluster.tensorrt", TypeRange(op.getInputs()),
-      op.getYield()->getOperandTypes());
-  if (failed(func))
-    return failure();
-  assert(func->getFunctionBody().getBlocks().size() == 1 &&
-         "expected body with one block");
-  func->setPublic();
-
-  rewriter.setInsertionPoint(op);
-  auto callOp = rewriter.create<tensorrt::CallOp>(
-      op.getLoc(), op.getResultTypes(), op.getInputs(), op.getOuts(),
-      SymbolRefAttr::get(trtModuleOp.getNameAttr(),
-                         {FlatSymbolRefAttr::get(*func)}));
+template <typename OpType>
+static auto createCallOp(RewriterBase &rewriter, OpType op,
+                         tensorrt::TensorRTModuleOp trtModuleOp,
+                         FunctionOpInterface func) {
+  static_assert(
+      std::is_same_v<OpType, plan::InlineClosedGroupOp> ||
+          std::is_same_v<OpType, plan::InlineClosedAllocGroupOp>,
+      "OpType must be either InlineClosedGroupOp or InlineClosedAllocGroupOp");
+  if constexpr (std::is_same_v<OpType, plan::InlineClosedGroupOp>)
+    return rewriter.create<tensorrt::CallOp>(
+        op.getLoc(), op.getResultTypes(), op.getInputs(), op.getOuts(),
+        SymbolRefAttr::get(trtModuleOp.getNameAttr(),
+                           {FlatSymbolRefAttr::get(func)}));
+  else if constexpr (std::is_same_v<OpType, plan::InlineClosedAllocGroupOp>)
+    return rewriter.create<tensorrt::CallAllocOp>(
+        op.getLoc(), op.getResultTypes(), op.getInputs(),
+        SymbolRefAttr::get(trtModuleOp.getNameAttr(),
+                           {FlatSymbolRefAttr::get(func)}));
+}
 
+template <typename OpType>
+static LogicalResult populateFunctionAttributes(RewriterBase &rewriter,
+                                                OpType op,
+                                                FunctionOpInterface *func) {
   // Populate the function arguments attributes.
-  for (unsigned i = 0; i < (*func).getNumArguments(); i++) {
+  for (unsigned i = 0; i < func->getNumArguments(); i++) {
     BoundsAttr srcAttr = cast<BoundsAttr>(op.getInputAttrs()[i]);
     // We may have scalar (index|signless int)-typed values since we haven't
     // eliminated `plan.(with_shape|with_values)` ops yet.
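`createCallOp` above is `static auto` with `if constexpr` because each
instantiation returns a different op type (`tensorrt::CallOp` vs
`tensorrt::CallAllocOp`); a runtime branch could not do this, since both arms
of a non-template function must agree on a single return type. A
self-contained sketch of the idiom with placeholder types:

    #include <string>
    #include <type_traits>

    struct DpsGroup {};
    struct AllocGroup {};
    struct DpsCall { std::string callee; };
    struct AllocCall { std::string callee; };

    // Each instantiation picks its own return type at compile time.
    template <typename GroupT>
    auto makeCall(const std::string &callee) {
      if constexpr (std::is_same_v<GroupT, DpsGroup>) {
        return DpsCall{callee};
      } else {
        static_assert(std::is_same_v<GroupT, AllocGroup>, "unsupported op");
        return AllocCall{callee};
      }
    }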
@@ -202,30 +206,57 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
     func->setArgAttr(i, mlir::getHostTensorArgAttrName(),
                      rewriter.getUnitAttr());
   }
-  // Populate the function result attributes.
-  for (unsigned i = 0; i < (*func).getNumResults(); i++) {
-    BoundsAttr srcAttr = cast<BoundsAttr>(op.getResAttrs()[i]);
-    if (srcAttr.isNone())
-      continue;
-    FailureOr<tensorrt::ShapeProfileAttr> boundsAttr =
-        getTensorRTShapeProfile(srcAttr, op.getResults()[i]);
-    if (failed(boundsAttr))
-      return op->emitOpError("failed to create TensorRT shape profile "
-                             "attribute from Plan BoundsAttr for result #")
-             << i << " (" << srcAttr << ")";
-    if (srcAttr.isShapeBound()) {
-      func->setResultAttr(
-          i, tensorrt::TensorRTDialect::getShapeProfileArgAttrName(),
-          *boundsAttr);
-      continue;
-    }
-    assert(srcAttr.isValueBound() && "expected value bound or shape bound");
-    func->setResultAttr(
-        i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(),
-        *boundsAttr);
-    func->setResultAttr(i, mlir::getHostTensorArgAttrName(),
-                        rewriter.getUnitAttr());
-  }
+  // Populate the function result attributes for DPS call op.
+  if constexpr (std::is_same_v<OpType, plan::InlineClosedGroupOp>) {
+    for (unsigned i = 0; i < func->getNumResults(); i++) {
+      BoundsAttr srcAttr = cast<BoundsAttr>(op.getResAttrs()[i]);
+      if (srcAttr.isNone())
+        continue;
+      FailureOr<tensorrt::ShapeProfileAttr> boundsAttr =
+          getTensorRTShapeProfile(srcAttr, op.getResults()[i]);
+      if (failed(boundsAttr))
+        return op->emitOpError("failed to create TensorRT shape profile "
+                               "attribute from Plan BoundsAttr for result #")
+               << i << " (" << srcAttr << ")";
+      if (srcAttr.isShapeBound()) {
+        func->setResultAttr(
+            i, tensorrt::TensorRTDialect::getShapeProfileArgAttrName(),
+            *boundsAttr);
+        continue;
+      }
+      assert(srcAttr.isValueBound() && "expected value bound or shape bound");
+      func->setResultAttr(
+          i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(),
+          *boundsAttr);
+      func->setResultAttr(i, mlir::getHostTensorArgAttrName(),
+                          rewriter.getUnitAttr());
+    }
+  }
+  return success();
+}
 
+template <typename OpType>
+static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, OpType op) {
+  tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
+  auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));
+
+  FailureOr<FunctionOpInterface> func = createOutlinedFunc(
+      rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
+      "cluster.tensorrt", TypeRange(op.getInputs()),
+      op.getYield()->getOperandTypes());
+
+  if (failed(func))
+    return failure();
+
+  assert(func->getFunctionBody().getBlocks().size() == 1 &&
+         "expected body with one block");
+  func->setPublic();
+
+  rewriter.setInsertionPoint(op);
+  auto callOp = createCallOp<OpType>(rewriter, op, trtModuleOp, *func);
+
+  if (failed(populateFunctionAttributes(rewriter, op, &(*func))))
+    return failure();
+
   // Populate the function entry block.
   rewriter.eraseBlock(&func->getFunctionBody().front());
@@ -234,14 +265,14 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
   // ops to the `tensorrt.module` op. This is needed since `tensorrt.module` op
   // has its own symbol table.
   SymbolTableCollection symbolTable;
-  for (auto compositeOp : op.getBody().getOps<stablehlo::CompositeOp>()) {
+  for (auto compositeOp :
+       op.getBody().template getOps<stablehlo::CompositeOp>()) {
     auto decompositionFunc = dyn_cast_if_present<func::FuncOp>(
-        symbolTable.lookupSymbolIn(op->getParentOfType<ModuleOp>(),
+        symbolTable.lookupSymbolIn(op->template getParentOfType<ModuleOp>(),
                                    compositeOp.getDecompositionAttr()));
     if (!decompositionFunc)
       return emitError(compositeOp.getLoc())
-             << "failed to lookup stablehlo.composite decomposition "
-                "function: "
+             << "failed to lookup stablehlo.composite decomposition function: "
              << compositeOp.getDecompositionAttr();
     rewriter.moveOpAfter(decompositionFunc, func->getOperation());
   }
@@ -254,24 +285,20 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
   rewriter.replaceOpWithNewOp<func::ReturnOp>(regionYieldOp,
                                               regionYieldOp->getOperands());
 
-  // Erase the DPS arugments, which now should be unused.
-  if (llvm::any_of(func->getArguments().take_back(op.getOuts().size()),
-                   [](BlockArgument arg) { return !arg.use_empty(); }))
-    return failure();
-  func->getFunctionBody().front().eraseArguments(op.getInputs().size(),
-                                                 op.getOuts().size());
+  if constexpr (std::is_same_v<OpType, plan::InlineClosedGroupOp>) {
+    // Erase the DPS arguments, which should now be unused.
+    if (llvm::any_of(func->getArguments().take_back(op.getOuts().size()),
+                     [](BlockArgument arg) { return !arg.use_empty(); }))
+      return failure();
+    func->getFunctionBody().front().eraseArguments(op.getInputs().size(),
+                                                   op.getOuts().size());
+  }
 
   // Replace the original region results.
   rewriter.replaceOp(op, callOp);
   return success();
 }
 
-static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
-                                           plan::InlineClosedAllocGroupOp op) {
-  return op.emitError("outlinining inline closed alloc group ops to tensorrt "
-                      "dialect is not yet implemented");
-}
-
 /// Create outlined functions for each `scf.execute_region` operation within
 /// `region`.
 static FailureOr<SmallVector<FunctionOpInterface>>
@@ -302,12 +329,14 @@ createFunctionsFromRegions(RewriterBase &rewriter, Region &region,
   }
 
   if (auto group = dyn_cast<plan::InlineClosedGroupOp>(op)) {
-    if (failed(outlineTensorRTRegion(rewriter, group)))
+    if (failed(outlineTensorRTRegion<plan::InlineClosedGroupOp>(rewriter,
+                                                                group)))
       return WalkResult::interrupt();
     return WalkResult::advance();
   }
   if (auto allocGroup = dyn_cast<plan::InlineClosedAllocGroupOp>(op)) {
-    if (failed(outlineTensorRTRegion(rewriter, allocGroup)))
+    if (failed(outlineTensorRTRegion<plan::InlineClosedAllocGroupOp>(
+            rewriter, allocGroup)))
       return WalkResult::interrupt();
     return WalkResult::advance();
   }