Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable end to end non-DPS testing #289

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions(
disallowHostTensorsInTensorRTClusters, llvm::cl::init(false),
llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
"calculations (but they can still be inputs)"));
addOption(
"enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false),
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"));
addOption("executor-index-bitwidth", executorIndexBitwidth,
llvm::cl::init(64));
addOption("device-compute-capability", deviceComputeCapability,
Expand Down Expand Up @@ -307,6 +311,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
plan::StablehloClusteringPassOptions clusteringOpts{};
clusteringOpts.disallowHostTensorsInTensorRTClusters =
opts.disallowHostTensorsInTensorRTClusters;
clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
clusteringOpts.entrypoint = opts.entrypoint;
plan::buildPlanSegmentationPipeline(pm, clusteringOpts);

Expand Down Expand Up @@ -340,7 +345,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline(

// Perform bufferization.
pm.addPass(createMemRefCastEliminationPass());
pm.addPass(plan::createPlanAllocTensorsPass());
plan::PlanAllocTensorsPassOptions allocTensorsOpts{};
allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts));
pm.addPass(plan::createPlanBufferizePass());
pm.addPass(createMemRefCastEliminationPass());
pm.addPass(createCanonicalizerPass());
Expand Down Expand Up @@ -529,6 +536,11 @@ struct ClusteringPipelineCliOpts
*this, "device-compute-capability",
llvm::cl::desc("target device compute capability (SM version)"),
llvm::cl::init(60)};
Option<bool> enableNonDPSReturns{
*this, "enable-non-dps-returns",
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"),
llvm::cl::init(false)};
Option<int64_t> deviceMaxSharedMemoryPerBlockKb{
*this, "device-max-smem-per-block",
llvm::cl::desc("max shared memory per block (in kilobytes)"),
Expand Down Expand Up @@ -556,6 +568,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
opts.deviceComputeCapability = cliOpts.deviceComputeCapability;
opts.deviceMaxSharedMemoryPerBlockKb =
cliOpts.deviceMaxSharedMemoryPerBlockKb;
opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns;
opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost;
opts.entrypoint = cliOpts.entrypoint;
return opts;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,20 +379,22 @@ struct ConvertEnqueueAllocToCall

// Create output memrefs from output descriptors
SmallVector<Value> results;
// Initialize output descriptor offset to skip number of results.
// `outputDescOffset` is used to retrieve rank, ptr, shapes, and strides per
// result.
outputDescOffset = 1;
for (unsigned i = 0; i < op->getNumResults(); ++i) {
unsigned rank = cast<MemRefType>(op->getResult(i).getType()).getRank();
unsigned offset =
1 +
i * (2 * rank + 2); // num res, (i * (rank, ptr, [shape], [stride]))

Value rankOffset = b.create<executor::GetOffsetOp>(
b.getI64Type(), structType,
ArrayRef<OpFoldResult>{this->createIndexConstant(b, 0),
rewriter.getI64IntegerAttr(offset++)});
ArrayRef<OpFoldResult>{
this->createIndexConstant(b, 0),
rewriter.getI64IntegerAttr(outputDescOffset++)});
Value devicePtrOffset = b.create<executor::GetOffsetOp>(
b.getI64Type(), structType,
ArrayRef<OpFoldResult>{this->createIndexConstant(b, 0),
rewriter.getI64IntegerAttr(offset++)});
ArrayRef<OpFoldResult>{
this->createIndexConstant(b, 0),
rewriter.getI64IntegerAttr(outputDescOffset++)});

[[maybe_unused]] Value rankValue = b.create<executor::LoadOp>(
b.getI64Type(), outputDescriptors, rankOffset);
Expand All @@ -406,8 +408,9 @@ struct ConvertEnqueueAllocToCall
for (unsigned r = 0; r < rank; ++r) {
Value shapeOffset = b.create<executor::GetOffsetOp>(
b.getI64Type(), structType,
ArrayRef<OpFoldResult>{this->createIndexConstant(b, 0),
rewriter.getI64IntegerAttr(offset++)});
ArrayRef<OpFoldResult>{
this->createIndexConstant(b, 0),
rewriter.getI64IntegerAttr(outputDescOffset++)});
Value shape = b.create<executor::LoadOp>(
b.getI64Type(), outputDescriptors, shapeOffset);
shapes.push_back(shape);
Expand All @@ -416,8 +419,9 @@ struct ConvertEnqueueAllocToCall
for (unsigned r = 0; r < rank; ++r) {
Value strideOffset = b.create<executor::GetOffsetOp>(
b.getI64Type(), structType,
ArrayRef<OpFoldResult>{this->createIndexConstant(b, 0),
rewriter.getI64IntegerAttr(offset++)});
ArrayRef<OpFoldResult>{
this->createIndexConstant(b, 0),
rewriter.getI64IntegerAttr(outputDescOffset++)});
Value shape = b.create<executor::LoadOp>(
b.getI64Type(), outputDescriptors, strideOffset);
shapes.push_back(shape);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanTransforms
MLIRTensorRTStablehloScalarToArith
MLIRTensorRTStablehloToTensorRT
MLIRTensorRTTensorRTRuntimeDialect
MLIRBufferizationToMemRef
MLIRTransforms
StablehloOps
)
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern<plan::WithValuesOp> {
} // namespace

/// Get a map from `tensorrt.func` functions to associated `tensorrt.call`
/// operations.
static llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>>
/// and `tensorrt.call_alloc` operations.
static llvm::DenseMap<func::FuncOp, SmallVector<Operation *>>
getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>> map;
op->walk([&](tensorrt::CallOp callOp) {
func::FuncOp func = callOp.getFuncCallee(collection);
if (map.contains(func)) {
map[func].push_back(callOp);
llvm::DenseMap<func::FuncOp, SmallVector<Operation *>> map;
op->walk([&](Operation *callOp) {
if (!isa<tensorrt::CallOp, tensorrt::CallAllocOp>(callOp))
return;

func::FuncOp func;
if (auto call = dyn_cast<tensorrt::CallOp>(callOp)) {
func = call.getFuncCallee(collection);
} else {
auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp);
func = callAlloc.getFuncCallee(collection);
}
map.insert(std::make_pair(func, SmallVector<tensorrt::CallOp>{callOp}));

if (map.count(func))
map[func].push_back(callOp);
else
map.insert({func, SmallVector<Operation *>{callOp}});
});
return map;
}
Expand All @@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
/// `tensorrt.call` operations.
static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
ModuleOp op, func::FuncOp funcOp,
ArrayRef<tensorrt::CallOp> callOps) {
ArrayRef<Operation *> callOps) {
llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0);
for (BlockArgument arg : funcOp.getArguments()) {
if (arg.use_empty())
Expand All @@ -99,10 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
funcOp.eraseArgument(i);

// Update the call ops.
for (tensorrt::CallOp callOp : callOps)
callOp.getInputsMutable().erase(i);
for (Operation *callOp : callOps) {
if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
call.getInputsMutable().erase(i);
else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
callAlloc.getInputsMutable().erase(i);
else
return emitError(funcOp->getLoc())
<< "Unexpected operation type in callOps";
}
}

return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,28 +156,32 @@ getTensorRTShapeProfile(plan::BoundsAttr attr, Value v) {
return getProfileAttr(attr.getMinShape(), attr.getMaxShape());
}

static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
plan::InlineClosedGroupOp op) {
tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));
FailureOr<FunctionOpInterface> func = createOutlinedFunc(
rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
"cluster.tensorrt", TypeRange(op.getInputs()),
op.getYield()->getOperandTypes());
if (failed(func))
return failure();
assert(func->getFunctionBody().getBlocks().size() == 1 &&
"expected body with one block");
func->setPublic();

rewriter.setInsertionPoint(op);
auto callOp = rewriter.create<tensorrt::CallOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(), op.getOuts(),
SymbolRefAttr::get(trtModuleOp.getNameAttr(),
{FlatSymbolRefAttr::get(*func)}));
template <typename OpType>
static auto createCallOp(RewriterBase &rewriter, OpType op,
tensorrt::TensorRTModuleOp trtModuleOp,
FunctionOpInterface func) {
static_assert(
std::is_same_v<OpType, plan::InlineClosedGroupOp> ||
std::is_same_v<OpType, plan::InlineClosedAllocGroupOp>,
"OpType must be either InlineClosedGroupOp or InlineClosedAllocGroupOp");
if constexpr (std::is_same_v<OpType, plan::InlineClosedGroupOp>)
return rewriter.create<tensorrt::CallOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(), op.getOuts(),
SymbolRefAttr::get(trtModuleOp.getNameAttr(),
{FlatSymbolRefAttr::get(func)}));
else if constexpr (std::is_same_v<OpType, plan::InlineClosedAllocGroupOp>)
return rewriter.create<tensorrt::CallAllocOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
SymbolRefAttr::get(trtModuleOp.getNameAttr(),
{FlatSymbolRefAttr::get(func)}));
}

template <typename OpType>
static LogicalResult populateFunctionAttributes(RewriterBase &rewriter,
OpType op,
FunctionOpInterface *func) {
// Populate the function arguments attributes.
for (unsigned i = 0; i < (*func).getNumArguments(); i++) {
for (unsigned i = 0; i < func->getNumArguments(); i++) {
BoundsAttr srcAttr = cast<BoundsAttr>(op.getInputAttrs()[i]);
// We may have scalar (index|signless int)-typed values since we haven't
// eliminated `plan.(with_shape|with_values)` ops yet.
Expand All @@ -202,30 +206,57 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
func->setArgAttr(i, mlir::getHostTensorArgAttrName(),
rewriter.getUnitAttr());
}
// Populate the function result attributes.
for (unsigned i = 0; i < (*func).getNumResults(); i++) {
BoundsAttr srcAttr = cast<BoundsAttr>(op.getResAttrs()[i]);
if (srcAttr.isNone())
continue;
FailureOr<tensorrt::ShapeProfileAttr> boundsAttr =
getTensorRTShapeProfile(srcAttr, op.getResults()[i]);
if (failed(boundsAttr))
return op->emitOpError("failed to create TensorRT shape profile "
"attribute from Plan BoundsAttr for result #")
<< i << " (" << srcAttr << ")";
if (srcAttr.isShapeBound()) {
// Populate the function result attributes for DPS call op.
if constexpr (std::is_same_v<OpType, plan::InlineClosedGroupOp>) {
for (unsigned i = 0; i < func->getNumResults(); i++) {
BoundsAttr srcAttr = cast<BoundsAttr>(op.getResAttrs()[i]);
if (srcAttr.isNone())
continue;
FailureOr<tensorrt::ShapeProfileAttr> boundsAttr =
getTensorRTShapeProfile(srcAttr, op.getResults()[i]);
if (failed(boundsAttr))
return op->emitOpError("failed to create TensorRT shape profile "
"attribute from Plan BoundsAttr for result #")
<< i << " (" << srcAttr << ")";
if (srcAttr.isShapeBound()) {
func->setResultAttr(
i, tensorrt::TensorRTDialect::getShapeProfileArgAttrName(),
*boundsAttr);
continue;
}
assert(srcAttr.isValueBound() && "expected value bound or shape bound");
func->setResultAttr(
i, tensorrt::TensorRTDialect::getShapeProfileArgAttrName(),
i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(),
*boundsAttr);
continue;
func->setResultAttr(i, mlir::getHostTensorArgAttrName(),
rewriter.getUnitAttr());
}
assert(srcAttr.isValueBound() && "expected value bound or shape bound");
func->setResultAttr(
i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(),
*boundsAttr);
func->setResultAttr(i, mlir::getHostTensorArgAttrName(),
rewriter.getUnitAttr());
}
return success();
}

template <typename OpType>
static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, OpType op) {
tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));

FailureOr<FunctionOpInterface> func = createOutlinedFunc(
rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
"cluster.tensorrt", TypeRange(op.getInputs()),
op.getYield()->getOperandTypes());

if (failed(func))
return failure();

assert(func->getFunctionBody().getBlocks().size() == 1 &&
"expected body with one block");
func->setPublic();

rewriter.setInsertionPoint(op);
auto callOp = createCallOp<OpType>(rewriter, op, trtModuleOp, *func);

if (failed(populateFunctionAttributes(rewriter, op, &(*func))))
return failure();

// Populate the function entry block.
rewriter.eraseBlock(&func->getFunctionBody().front());
Expand All @@ -234,14 +265,14 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
// ops to the `tensorrt.module` op. This is needed since `tensorrt.module` op
// has its own symbol table.
SymbolTableCollection symbolTable;
for (auto compositeOp : op.getBody().getOps<stablehlo::CompositeOp>()) {
for (auto compositeOp :
op.getBody().template getOps<stablehlo::CompositeOp>()) {
auto decompositionFunc = dyn_cast_if_present<func::FuncOp>(
symbolTable.lookupSymbolIn(op->getParentOfType<ModuleOp>(),
symbolTable.lookupSymbolIn(op->template getParentOfType<ModuleOp>(),
compositeOp.getDecompositionAttr()));
if (!decompositionFunc)
return emitError(compositeOp.getLoc())
<< "failed to lookup stablehlo.composite decomposition "
"function: "
<< "failed to lookup stablehlo.composite decomposition function: "
<< compositeOp.getDecompositionAttr();
rewriter.moveOpAfter(decompositionFunc, func->getOperation());
}
Expand All @@ -254,24 +285,20 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
rewriter.replaceOpWithNewOp<func::ReturnOp>(regionYieldOp,
regionYieldOp->getOperands());

// Erase the DPS arugments, which now should be unused.
if (llvm::any_of(func->getArguments().take_back(op.getOuts().size()),
[](BlockArgument arg) { return !arg.use_empty(); }))
return failure();
func->getFunctionBody().front().eraseArguments(op.getInputs().size(),
op.getOuts().size());
if constexpr (std::is_same_v<OpType, plan::InlineClosedGroupOp>) {
// Erase the DPS arugments, which now should be unused.
if (llvm::any_of(func->getArguments().take_back(op.getOuts().size()),
[](BlockArgument arg) { return !arg.use_empty(); }))
return failure();
func->getFunctionBody().front().eraseArguments(op.getInputs().size(),
op.getOuts().size());
}

// replace the original region results.
rewriter.replaceOp(op, callOp);
return success();
}

static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
plan::InlineClosedAllocGroupOp op) {
return op.emitError("outlinining inline closed alloc group ops to tensorrt "
"dialect is not yet implemented");
}

/// Create outlined functions for each `scf.execute_region` operation within
/// `region`.
static FailureOr<SmallVector<FunctionOpInterface>>
Expand Down Expand Up @@ -302,12 +329,14 @@ createFunctionsFromRegions(RewriterBase &rewriter, Region &region,
}

if (auto group = dyn_cast<plan::InlineClosedGroupOp>(op)) {
if (failed(outlineTensorRTRegion(rewriter, group)))
if (failed(outlineTensorRTRegion<plan::InlineClosedGroupOp>(rewriter,
group)))
return WalkResult::interrupt();
return WalkResult::advance();
}
if (auto allocGroup = dyn_cast<plan::InlineClosedAllocGroupOp>(op)) {
if (failed(outlineTensorRTRegion(rewriter, allocGroup)))
if (failed(outlineTensorRTRegion<plan::InlineClosedAllocGroupOp>(
rewriter, allocGroup)))
return WalkResult::interrupt();
return WalkResult::advance();
}
Expand Down
Loading
Loading