Commit

Add plan dialect changes
jhalakpatel committed Oct 14, 2024
1 parent 0656ede commit 3844712
Showing 19 changed files with 577 additions and 291 deletions.
@@ -34,6 +34,9 @@ namespace mlirtrt::compiler {
/// DebugOptions are options that are common to different compiler API
/// interfaces.
struct DebugOptions {
/// Whether to dump the pass pipeline in textual form before running it.
bool dumpTextualPipeline = false;

/// A directory path where the IR will be dumped during compilation
/// using the `mlir-print-ir-tree-dir` mechanism.
std::string dumpIRPath = "";
@@ -49,6 +52,7 @@ struct DebugOptions {
mlir::SmallVector<std::string> llvmDebugTypes = {};

void addToOptions(mlir::OptionsContext &context) {
context.addOption("dump-textual-pipeline", dumpTextualPipeline);
context.addOption("mlir-print-ir-tree-dir", dumpIRPath, llvm::cl::init(""));
context.addOption("debug", enableLLVMDebugFlag);
context.addList<std::string>("debug-only", llvmDebugTypes,
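For reference, a minimal sketch (editorial, not part of this diff) of how the new flag is meant to be consumed; it mirrors the guard this commit adds in StableHloToExecutable.cpp below, and the helper function name here is hypothetical:

// Illustrative sketch: gate the textual dump of a pass pipeline on the new
// DebugOptions flag, as the compile task below now does.
#include "llvm/Support/Debug.h"
#include "mlir/Pass/PassManager.h"

static void maybeDumpPipeline(mlir::PassManager &pm, bool dumpTextualPipeline) {
  // Print the pipeline to the debug stream only when explicitly requested.
  if (dumpTextualPipeline)
    pm.printAsTextualPipeline(llvm::dbgs());
}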
@@ -128,6 +128,9 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
/// Whether to disallow host tensors in TensorRT clusters.
bool disallowHostTensorsInTensorRTClusters = false;

/// Whether to use a non-DPS (destination-passing style) calling convention.
bool useNonDPSCallConv = false;

/// Entrypoint function name.
std::string entrypoint = "main";

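For readers new to the term: DPS is destination-passing style, where the caller pre-allocates result buffers and passes them in as extra operands; the non-DPS convention enabled by this flag lets the callee allocate results itself (here, TensorRT via its output allocator). A plain C++ analogy, purely illustrative and not code from this repository:

#include <cstddef>
#include <vector>

// DPS-style: the caller owns and supplies the destination buffer.
void doubleDPS(const std::vector<float> &in, std::vector<float> &out) {
  out.resize(in.size());
  for (std::size_t i = 0; i < in.size(); ++i)
    out[i] = 2.0f * in[i];
}

// Non-DPS style: the callee allocates and returns its own result.
std::vector<float> doubleNonDPS(const std::vector<float> &in) {
  std::vector<float> out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i)
    out[i] = 2.0f * in[i];
  return out;
}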
@@ -267,6 +267,72 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
}];
}


//===----------------------------------------------------------------------===//
// InlineClosedGroupNonDPSOp
//===----------------------------------------------------------------------===//

def Plan_InlineClosedGroupNonDPSOp : Plan_GroupOpBase<"inline_closed_group_non_dps", [
IsolatedFromAbove,
SingleBlockImplicitTerminator<"plan::YieldOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>,
DeclareOpInterfaceMethods<OpAsmOpInterface,
["getAsmBlockArgumentNames"]>
]> {
let description = [{
This op is the non-DPS variant of `plan.inline_closed_group`: it captures
only `inputs` (there are no `outs` destination operands) and returns its
results directly, leaving result allocation to the backend.

// TODO: Expand the documentation for the non-DPS inline closed group op.
}];
let arguments = (ins Variadic<AnyTypeOf<[AnyRankedTensor, AnySignlessIntegerOrIndex]>>:$inputs,
BoundsAttrArray:$input_attrs,
AnyAttr:$target);

let results = (outs Variadic<AnyTypeOf<[AnyRankedTensor]>>:$results);

let assemblyFormat = [{
`target` `(` $target `)` `\n`
`inputs` `(` ( $inputs^ `:` type($inputs) `)` ) : ( `)` ) ? `\n`
`in_attrs` $input_attrs `\n`
attr-dict-with-keyword `->` type($results)
$body
}];

let hasVerifier = 1;

let skipDefaultBuilders = 1;

let builders = [
OpBuilder<(ins "TypeRange":$results,
"Attribute":$target,
"ValueRange":$inputs,
CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs)>,
];

let extraClassDeclaration = baseExtraClassDeclaration # [{

/// Returns true if the `i-th` input argument has a tensor type.
bool argHasTensorType(unsigned inputIdx) {
assert(inputIdx < getInputs().size() && "input index out-of-bounds");
return isa<RankedTensorType>(getInputs()[inputIdx].getType());
}

/// Returns the i-th input argument's bounds attribute.
BoundsAttr getInputBoundsAttr(unsigned inputIdx) {
assert(inputIdx < getInputs().size() && "input index out-of-bounds");
return cast<BoundsAttr>(getInputAttrs()[inputIdx]);
}

/// Populate the `input_attrs` from an array of BoundsAttrs.
void setInputAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
setInputAttrsAttr(::mlir::ArrayAttr::get(
getOperation()->getContext(),
ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
));
}

}];
}
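A hedged sketch of how the builder declared above could be invoked from a rewrite pattern; every name below (`rewriter`, `loc`, `resultTypes`, `targetAttr`, `clusterInputs`, `inputBounds`) is a placeholder assumed to be in scope, not something this commit defines:

// Illustrative only: create the non-DPS closed group op via the new builder.
auto groupOp = rewriter.create<plan::InlineClosedGroupNonDPSOp>(
    loc,
    /*results=*/resultTypes,      // TypeRange of values yielded by the region
    /*target=*/targetAttr,        // cluster target attribute
    /*inputs=*/clusterInputs,     // ValueRange of captured inputs
    /*input_attrs=*/inputBounds); // ArrayRef<BoundsAttr>, one entry per input
// The builder adds a single block whose arguments mirror `clusterInputs`, so
// the body can be populated next and terminated with `plan::YieldOp`.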

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
@@ -276,7 +342,7 @@ def Plan_YieldOp : Plan_Op<"yield", [
Terminator,
ReturnLike,
ParentOneOf<["plan::InlineGroupOp",
"plan::InlineClosedGroupOp"]>]> {
"plan::InlineClosedGroupOp", "plan::InlineClosedGroupNonDPSOp"]>]> {

let arguments = (ins Variadic<AnyType>:$results);

@@ -69,7 +69,8 @@ executorOneShotModuleBufferize(ModuleOp targetOp,
const ExecutorBufferizationOptions &options);

/// Build a pipeline (targeting ModuleOp) for bufferization.
void buildPlanBufferizationPipeline(OpPassManager &pm);
void buildPlanBufferizationPipeline(
OpPassManager &pm, const plan::PlanAllocTensorsPassOptions &options);

/// Build a post-bufferization pipeline that performs optimizations on memrefs.
void buildPlanBufferOptimizationPipeline(OpPassManager &pm);
@@ -248,6 +248,9 @@ def StablehloClusteringPass : Pass<"stablehlo-clustering", "::mlir::ModuleOp"> {
Option<"entrypoint", "entrypoint", "std::string", "\"\"",
"the name of the entrypoint function; if empty then the clustering runs"
" on all functions">,
Option<"useNonDPSCallConv",
"use-non-dps-call-conv", "bool", "false",
"allow tensorrt based output allocations using output allocator">,
Option<"disallowHostTensorsInTensorRTClusters",
"disallow-host-tensors-in-tensorrt-clusters", "bool", "false",
"don't cluster host tensors in TensorRT clusters">,
@@ -332,7 +335,10 @@ def CreateClosedRegionsPass : Pass<"plan-create-closed-regions", "::mlir::Module
Option<"testPreWalkOrder", "test-pre-walk-order", "bool", "false",
"(used only in testing) specifies to outline regions by walking in "
" pre-order; used for verifying results are not sensitive "
"to traversal order">
"to traversal order">,
Option<"useNonDPSCallConv", "use-non-dps-call-conv", "bool",
/*default=*/"false",
"Allow TensorRT-based output allocations using output allocator">
];

let dependentDialects = [
Expand Down Expand Up @@ -428,6 +434,13 @@ def PlanAllocTensorsPass : Pass<"plan-alloc-tensors",
"::mlir::bufferization::BufferizationDialect",
"::mlir::plan::PlanDialect"
];

let options = [
Option<"useNonDPSCallConv", "use-non-dps-call-conv", "bool",
/*default=*/"false",
"Allow TensorRT-based output allocations using output allocator">
];

}
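Because the option is registered through ODS, it can also be set from a textual pipeline string; a small editorial sketch follows, with the pass and option names taken from the definition above and `pm` assumed to be anchored on the pass's root operation:

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"

// Illustrative only: schedule plan-alloc-tensors with the non-DPS calling
// convention enabled, using the textual pipeline syntax.
static mlir::LogicalResult addAllocTensors(mlir::OpPassManager &pm) {
  return mlir::parsePassPipeline(
      "plan-alloc-tensors{use-non-dps-call-conv=true}", pm);
}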

//===----------------------------------------------------------------------===//
24 changes: 18 additions & 6 deletions mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
@@ -222,6 +222,9 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions(
disallowHostTensorsInTensorRTClusters, llvm::cl::init(false),
llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
"calculations (but they can still be inputs)"));
addOption("use-non-dps-call-conv", useNonDPSCallConv,
llvm::cl::init(false),
llvm::cl::desc("allow tensorrt based output allocations using output allocator"));
addOption("executor-index-bitwidth", executorIndexBitwidth,
llvm::cl::init(64));
addOption("device-compute-capability", deviceComputeCapability,
@@ -303,6 +306,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
plan::StablehloClusteringPassOptions clusteringOpts{};
clusteringOpts.disallowHostTensorsInTensorRTClusters =
opts.disallowHostTensorsInTensorRTClusters;
clusteringOpts.useNonDPSCallConv = opts.useNonDPSCallConv;
clusteringOpts.entrypoint = opts.entrypoint;
plan::buildPlanSegmentationPipeline(pm, clusteringOpts);

@@ -336,7 +340,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline(

// Perform bufferization.
pm.addPass(createMemRefCastEliminationPass());
pm.addPass(plan::createPlanAllocTensorsPass());
plan::PlanAllocTensorsPassOptions allocTensorsOpts{};
allocTensorsOpts.useNonDPSCallConv = opts.useNonDPSCallConv;
pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts));
pm.addPass(plan::createPlanBufferizePass());
pm.addPass(createMemRefCastEliminationPass());
pm.addPass(createCanonicalizerPass());
@@ -485,13 +491,14 @@ StableHloToExecutableTask::compileStableHLOToExecutable(
runner = pm.get();
}

runner->printAsTextualPipeline(llvm::dbgs());
if (options.debugOptions.dumpTextualPipeline)
runner->printAsTextualPipeline(llvm::dbgs());

// Setup pass manager
// if (failed(runner->run(module)))
// return getInternalErrorStatus(
// "failed to run compilation on module with symbol name: {0}",
// module.getName() ? *module.getName() : "no-symbol-name");
if (failed(runner->run(module)))
return getInternalErrorStatus(
"failed to run compilation on module with symbol name: {0}",
module.getName() ? *module.getName() : "no-symbol-name");

// Translate to Runtime Executable
FailureOr<std::unique_ptr<runtime::ExecutableStorage>> exeStorage =
@@ -524,6 +531,10 @@ struct ClusteringPipelineCliOpts
*this, "device-compute-capability",
llvm::cl::desc("target device compute capability (SM version)"),
llvm::cl::init(60)};
Option<bool> useNonDPSCallConv{
*this, "use-non-dps-call-conv",
llvm::cl::desc("allow tensorrt based output allocations using output allocator"),
llvm::cl::init(false)};
Option<int64_t> deviceMaxSharedMemoryPerBlockKb{
*this, "device-max-smem-per-block",
llvm::cl::desc("max shared memory per block (in kilobytes)"),
@@ -551,6 +562,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
opts.deviceComputeCapability = cliOpts.deviceComputeCapability;
opts.deviceMaxSharedMemoryPerBlockKb =
cliOpts.deviceMaxSharedMemoryPerBlockKb;
opts.useNonDPSCallConv = cliOpts.useNonDPSCallConv;
opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost;
opts.entrypoint = cliOpts.entrypoint;
return opts;
@@ -263,7 +263,7 @@ struct ConvertEnqueueAllocToCall
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

// Function name for the enqueue alloc operation
std::string funcName = "_trtrt_alloc_enqueue";
std::string funcName = "_trtrt_enqueue_alloc";

// Create new operands for the call op
SmallVector<Value> newOperands = {adaptor.getExecutionContext(),
@@ -429,8 +429,10 @@ struct ConvertEnqueueAllocToCall
resultRange.append(shapes.begin(), shapes.end());
resultRange.append(strides.begin(), strides.end());

Value result = b.create<executor::CreateTableOp>(executor::TableType::get(
b.getContext(), llvm::to_vector(TypeRange(resultRange))));
Value result = b.create<executor::CreateTableOp>(
executor::TableType::get(b.getContext(),
llvm::to_vector(TypeRange(resultRange))),
resultRange);

results.push_back(result);
}
69 changes: 69 additions & 0 deletions mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp
@@ -465,6 +465,75 @@ void InlineClosedGroupOp::build(OpBuilder &b, OperationState &state,
state.addTypes(TypeRange(outs));
}

//===----------------------------------------------------------------------===//
// InlineClosedGroupNonDPSOp
//===----------------------------------------------------------------------===//

LogicalResult InlineClosedGroupNonDPSOp::verify() {
SmallVector<BoundsAttr> inputAttrs =
llvm::to_vector(getInputAttrs().getAsRange<BoundsAttr>());
if (inputAttrs.size() != getInputs().size())
return emitOpError("expected number of inputs (")
<< getInputs().size()
<< " to equal the number of input_attrs BoundsAttrs ("
<< inputAttrs.size() << ")";

for (auto [idx, type] : llvm::enumerate(TypeRange(getInputs()))) {
BoundsAttr boundsAttr = inputAttrs[idx];
if (failed(verifyBoundsAttr("input argument", idx, type, boundsAttr,
[&]() { return emitOpError(); })))
return failure();
}

return success();
}

void InlineClosedGroupNonDPSOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is the InlineClosedGroupNonDPSOp, branch into the body.
if (point.isParent()) {
regions.push_back(RegionSuccessor(&getBody(), getBody().getArguments()));
return;
}

// Otherwise, the region branches back to the parent operation.
regions.push_back(RegionSuccessor(getResults()));
}

OperandRange
InlineClosedGroupNonDPSOp::getEntrySuccessorOperands(RegionBranchPoint point) {
return getOperands();
}

void InlineClosedGroupNonDPSOp::getAsmBlockArgumentNames(
Region &region, OpAsmSetValueNameFn setNameFn) {
assert(region.front().getNumArguments() == getInputs().size() &&
"expected one block arg for each input argument");
for (BlockArgument arg : region.front().getArguments()) {
setNameFn(arg, "in");
}
}

void InlineClosedGroupNonDPSOp::build(OpBuilder &b, OperationState &state, TypeRange resultTypes,
Attribute target, ValueRange inputs,
ArrayRef<BoundsAttr> input_attrs) {
state.addTypes(resultTypes);
state.addOperands(inputs);
state.getOrAddProperties<Properties>().target = target;
state.getOrAddProperties<Properties>().setInputAttrs(b.getArrayAttr(
SmallVector<Attribute>(input_attrs.begin(), input_attrs.end())));
Region *body = state.addRegion();
auto getLocs = [](ValueRange r) {
SmallVector<Location> locs;
locs.reserve(r.size());
for (Value v : r)
locs.push_back(v.getLoc());
return locs;
};
(void)body->emplaceBlock();
body->addArguments(TypeRange(inputs), getLocs(inputs));
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
@@ -856,11 +856,14 @@ class AllocTensorsPass
}
}

// First rewrite public functions to conform to DPS style.
IRRewriter rewriter(ctx);
if (failed(rewriteNotPrivateFuncsToDPS(rewriter, op))) {
op->emitError("Failed to convert non-private functions to DPS");
return signalPassFailure();

if (!useNonDPSCallConv) {
// First rewrite public functions to conform to DPS style.
if (failed(rewriteNotPrivateFuncsToDPS(rewriter, op))) {
op->emitError("Failed to convert non-private functions to DPS");
return signalPassFailure();
}
}

// Rewrite SCF for and while loop bodies for better bufferization results,