Commit ef2bddf

Add plan dialect changes
1 parent b232f2b commit ef2bddf

20 files changed: +691, −372 lines

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Options.h

Lines changed: 4 additions & 0 deletions
@@ -34,6 +34,9 @@ namespace mlirtrt::compiler {
 /// DebugOptions are options that are common to different compiler API
 /// interfaces.
 struct DebugOptions {
+  /// Dump textual pipeline passes
+  bool dumpTextualPipeline = false;
+
   /// A directory path where the IR will be dumped during compilation
   /// using the `mlir-print-ir-tree-dir` mechanism.
   std::string dumpIRPath = "";
@@ -49,6 +52,7 @@ struct DebugOptions {
   mlir::SmallVector<std::string> llvmDebugTypes = {};

   void addToOptions(mlir::OptionsContext &context) {
+    context.addOption("dump-textual-pipeline", dumpTextualPipeline);
     context.addOption("mlir-print-ir-tree-dir", dumpIRPath, llvm::cl::init(""));
     context.addOption("debug", enableLLVMDebugFlag);
     context.addList<std::string>("debug-only", llvmDebugTypes,
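
The new flag is consumed later in `StableHloToExecutable.cpp`, where the textual pipeline is printed only when the option is set. A minimal sketch of that guard, assuming an already populated `mlir::PassManager`; the helper name and call site below are illustrative, not part of the commit:

```c++
#include "mlir-tensorrt/Compiler/Options.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/Debug.h"

// Print the pass pipeline only when `dump-textual-pipeline` was requested.
// Mirrors the guard added in StableHloToExecutable.cpp further down.
static void maybeDumpPipeline(const mlirtrt::compiler::DebugOptions &debugOpts,
                              mlir::PassManager &pm) {
  if (debugOpts.dumpTextualPipeline)
    pm.printAsTextualPipeline(llvm::dbgs());
}
```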

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 3 additions & 0 deletions
@@ -128,6 +128,9 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
   /// Whether to disallow host tensors in TensorRT clusters.
   bool disallowHostTensorsInTensorRTClusters = false;

+  /// Whether to use non-DPS style calling convention.
+  bool useNonDPSCallConv = false;
+
   /// Entrypoint function name.
   std::string entrypoint = "main";

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanOps.td

Lines changed: 117 additions & 23 deletions
@@ -131,10 +131,57 @@ def Plan_InlineGroupOp : Plan_GroupOpBase<"inline_group", [
 }

 //===----------------------------------------------------------------------===//
-// InlineClosedGroupOp
+// Plan_InlineClosedGroupBase
 //===----------------------------------------------------------------------===//

-def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
+class Plan_InlineClosedGroupBase<string mnemonic, list<Trait> traits = []> :
+      Plan_GroupOpBase<mnemonic, traits> {
+
+  code baseInlineClosedExtraClassDeclaration = baseExtraClassDeclaration # [{
+    // Common methods for both DPS and non-DPS versions
+    bool argHasTensorType(unsigned inputIdx) {
+      assert(inputIdx < getInputs().size() && "input index out-of-bounds");
+      return isa<RankedTensorType>(getInputs()[inputIdx].getType());
+    }
+
+    BoundsAttr getInputBoundsAttr(unsigned inputIdx) {
+      assert(inputIdx < getInputs().size() && "input index out-of-bounds");
+      return cast<BoundsAttr>(getInputAttrs()[inputIdx]);
+    }
+
+    /// Populate the `input_attrs` from an array of BoundsAttrs.
+    void setInputAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
+      setInputAttrsAttr(::mlir::ArrayAttr::get(
+        getOperation()->getContext(),
+        ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
+      ));
+    }
+
+    void getSuccessorRegionsBase(RegionBranchPoint point,
+                                 SmallVectorImpl<RegionSuccessor> &regions) {
+      // If the predecessor is the InlineClosedGroupOp, branch into the body.
+      if (point.isParent()) {
+        regions.push_back(RegionSuccessor(&getBody(), getBody().getArguments()));
+        return;
+      }
+
+      // Otherwise, the region branches back to the parent operation.
+      regions.push_back(RegionSuccessor(getResults()));
+    }
+
+    OperandRange getEntrySuccessorOperandsBase(RegionBranchPoint point) {
+      return getOperands();
+    }
+  }];
+
+  let extraClassDeclaration = baseInlineClosedExtraClassDeclaration;
+}
+
+//===----------------------------------------------------------------------===//
+// Plan_InlineClosedGroupOp
+//===----------------------------------------------------------------------===//
+
+def Plan_InlineClosedGroupOp : Plan_InlineClosedGroupBase<"inline_closed_group", [
     IsolatedFromAbove,
     AttrSizedOperandSegments,
     DestinationStyleOpInterface,
@@ -226,24 +273,12 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
                    CArg<"ArrayRef<BoundsAttr>", "{}">:$res_attrs)>
   ];

-  let extraClassDeclaration = baseExtraClassDeclaration # [{
+  let extraClassDeclaration = baseInlineClosedExtraClassDeclaration # [{

     MutableOperandRange getDpsInitsMutable() {
       return getOutsMutable();
     }

-    /// Returns true if the `i-th` input argument has a tensor type.
-    bool argHasTensorType(unsigned inputIdx) {
-      assert(inputIdx < getInputs().size() && "input index out-of-bounds");
-      return isa<RankedTensorType>(getInputs()[inputIdx].getType());
-    }
-
-    /// Returns the i-th input argument's bounds attribute.
-    BoundsAttr getInputBoundsAttr(unsigned inputIdx) {
-      assert(inputIdx < getInputs().size() && "input index out-of-bounds");
-      return cast<BoundsAttr>(getInputAttrs()[inputIdx]);
-    }
-
     ArrayRef<BlockArgument> getRegionOutArgs() {
       return getBody().getArguments().take_back(getOuts().size());
     }
@@ -255,16 +290,75 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
         ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
       ));
     }
+  }];
+}

-    /// Populate the `input_attrs` from an array of BoundsAttrs.
-    void setInputAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
-      setInputAttrsAttr(::mlir::ArrayAttr::get(
-        getOperation()->getContext(),
-        ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
-      ));
-    }
+//===----------------------------------------------------------------------===//
+// InlineClosedGroupNonDPSOp
+//===----------------------------------------------------------------------===//
+
+def Plan_InlineClosedGroupNonDPSOp : Plan_InlineClosedGroupBase<"inline_closed_group_non_dps", [
+    IsolatedFromAbove,
+    SingleBlockImplicitTerminator<"plan::YieldOp">,
+    DeclareOpInterfaceMethods<RegionBranchOpInterface,
+                              ["getEntrySuccessorOperands"]>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface,
+                              ["getAsmBlockArgumentNames"]>
+  ]> {
+  let description = [{
+    The `plan.inline_closed_group_non_dps` operation is a variant of the
+    `plan.inline_closed_group` operation that does not use destination-passing
+    style (DPS). It is isolated from above and explicitly captures input
+    operands, but unlike its DPS counterpart, it does not capture destination
+    operands. This operation takes input operands and their corresponding
+    bounds attributes, and produces results. The `input_attrs` hold bounds
+    attribute information for the input operands. The absence of bounds
+    information is allowed (`none` bounds).
+
+    The `target` attribute specifies the execution target for the group.
+
+    #### Example
+
+    Consider the following simple program containing operations with
+    dynamically shaped operands:
+
+    ```mlir
+    %0 = ... : tensor<?xf32> // A dynamically shaped operand
+    %1 = ... : index // A dynamic calculation of %0's extent
+
+    %2 = plan.inline_closed_group_non_dps target(#plan.cluster_target<tensorrt>)
+        inputs(%0, %1 : tensor<?xf32>, index)
+        in_attrs [#plan.bounds<shape, , >, #plan.bounds<none>] -> tensor<?xf32> {
+      %3 = plan.with_shape %0 (%1) : (tensor<?xf32>, index) -> tensor<?xf32>
+      %4 = stablehlo.exponential %3 : tensor<?xf32>
+      yield %4 : tensor<?xf32>
+    }
+    ```

   }];
+  let arguments = (ins Variadic<AnyTypeOf<[AnyRankedTensor, AnySignlessIntegerOrIndex]>>:$inputs,
+                       BoundsAttrArray:$input_attrs,
+                       AnyAttr:$target);
+
+  let results = (outs Variadic<AnyTypeOf<[AnyRankedTensor]>>:$results);
+
+  let assemblyFormat = [{
+    `target` `(` $target `)` `\n`
+    `inputs` `(` ( $inputs^ `:` type($inputs) `)` ) : ( `)` )? `\n`
+    `in_attrs` $input_attrs `\n`
+    attr-dict-with-keyword `->` type($results)
+    $body
+  }];
+
+  let hasVerifier = 1;
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<(ins "TypeRange":$results,
+                   "Attribute":$target,
+                   "ValueRange":$inputs,
+                   CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs)>,
+  ];
+
+  let extraClassDeclaration = baseInlineClosedExtraClassDeclaration;
 }

 //===----------------------------------------------------------------------===//
@@ -276,7 +370,7 @@ def Plan_YieldOp : Plan_Op<"yield", [
     Terminator,
    ReturnLike,
     ParentOneOf<["plan::InlineGroupOp",
-                 "plan::InlineClosedGroupOp"]>]> {
+                 "plan::InlineClosedGroupOp", "plan::InlineClosedGroupNonDPSOp"]>]> {

   let arguments = (ins Variadic<AnyType>:$results);
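
Because `skipDefaultBuilders` is set and only the custom builder above is declared, C++ clients construct the non-DPS op through that builder. A minimal sketch under assumptions: the header path, the `mlir::plan` namespace qualification, and the `createNonDPSGroup` helper are illustrative, while the parameter list follows the TableGen builder declared above.

```c++
#include "mlir-tensorrt/Dialect/Plan/IR/PlanOps.h" // assumed header location
#include "mlir/IR/Builders.h"

// Create a plan.inline_closed_group_non_dps op: results are produced by the
// op itself, so no destination (`outs`) operands are threaded through.
static mlir::plan::InlineClosedGroupNonDPSOp
createNonDPSGroup(mlir::OpBuilder &b, mlir::Location loc,
                  mlir::TypeRange resultTypes, mlir::Attribute target,
                  mlir::ValueRange inputs,
                  llvm::ArrayRef<mlir::plan::BoundsAttr> inputBounds) {
  // Maps onto the TableGen builder:
  // (TypeRange results, Attribute target, ValueRange inputs,
  //  ArrayRef<BoundsAttr> input_attrs).
  return b.create<mlir::plan::InlineClosedGroupNonDPSOp>(
      loc, resultTypes, target, inputs, inputBounds);
}
```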

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.h

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ executorOneShotModuleBufferize(ModuleOp targetOp,
                                const ExecutorBufferizationOptions &options);

 /// Build a pipeline (targeting ModuleOp) for bufferization.
-void buildPlanBufferizationPipeline(OpPassManager &pm);
+void buildPlanBufferizationPipeline(
+    OpPassManager &pm, const plan::PlanAllocTensorsPassOptions &options);

 /// Build a post-bufferization pipeline that performs optimizations on memrefs.
 void buildPlanBufferOptimizationPipeline(OpPassManager &pm);
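
A short sketch of a caller adapting to the new signature. The wrapper name and the `mlir::plan` namespace qualification are assumptions for illustration; the option struct and pipeline-builder names come from the headers above.

```c++
#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

// Populate a module-level pass manager with the plan bufferization pipeline,
// forwarding the calling-convention choice down to plan-alloc-tensors.
void addPlanBufferization(mlir::OpPassManager &pm, bool useNonDPSCallConv) {
  mlir::plan::PlanAllocTensorsPassOptions allocTensorOpts{};
  allocTensorOpts.useNonDPSCallConv = useNonDPSCallConv;
  mlir::plan::buildPlanBufferizationPipeline(pm, allocTensorOpts);
  mlir::plan::buildPlanBufferOptimizationPipeline(pm);
}
```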

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td

Lines changed: 14 additions & 1 deletion
@@ -248,6 +248,9 @@ def StablehloClusteringPass : Pass<"stablehlo-clustering", "::mlir::ModuleOp"> {
     Option<"entrypoint", "entrypoint", "std::string", "\"\"",
            "the name of the entrypoint function; if empty then the clustering runs"
            " on all functions">,
+    Option<"useNonDPSCallConv",
+           "use-non-dps-call-conv", "bool", "false",
+           "allow tensorrt based output allocations using output allocator">,
     Option<"disallowHostTensorsInTensorRTClusters",
            "disallow-host-tensors-in-tensorrt-clusters", "bool", "false",
            "don't cluster host tensors in TensorRT clusters">,
@@ -332,7 +335,10 @@ def CreateClosedRegionsPass : Pass<"plan-create-closed-regions", "::mlir::Module
     Option<"testPreWalkOrder", "test-pre-walk-order", "bool", "false",
            "(used only in testing) specifies to outline regions by walking in "
            " pre-order; used for verifying results are not sensitive "
-           "to traversal order">
+           "to traversal order">,
+    Option<"useNonDPSCallConv", "use-non-dps-call-conv", "bool",
+           /*default=*/"false",
+           "Allow TensorRT-based output allocations using output allocator">
   ];

   let dependentDialects = [
@@ -428,6 +434,13 @@ def PlanAllocTensorsPass : Pass<"plan-alloc-tensors",
     "::mlir::bufferization::BufferizationDialect",
     "::mlir::plan::PlanDialect"
   ];
+
+  let options = [
+    Option<"useNonDPSCallConv", "use-non-dps-call-conv", "bool",
+           /*default=*/"false",
+           "Allow TensorRT-based output allocations using output allocator">
+  ];
+
 }

 //===----------------------------------------------------------------------===//
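
Because these are ordinary TableGen pass options, the flag can also be set from a textual pipeline string using standard MLIR pipeline syntax. A sketch under assumptions: the helper name is illustrative, the Plan passes are registered with the pass registry, and `plan-alloc-tensors` anchors on the module; the pass and option names are taken from the definitions above.

```c++
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"

// Append plan-alloc-tensors to an existing module pass manager, enabling the
// non-DPS calling convention through the `use-non-dps-call-conv` option.
mlir::LogicalResult configurePlanAllocTensors(mlir::PassManager &pm) {
  return mlir::parsePassPipeline(
      "plan-alloc-tensors{use-non-dps-call-conv=true}", pm);
}
```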

mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp

Lines changed: 20 additions & 6 deletions
@@ -222,6 +222,10 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions(
             disallowHostTensorsInTensorRTClusters, llvm::cl::init(false),
             llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
                            "calculations (but they can still be inputs)"));
+  addOption(
+      "use-non-dps-call-conv", useNonDPSCallConv, llvm::cl::init(false),
+      llvm::cl::desc(
+          "allow tensorrt based output allocations using output allocator"));
   addOption("executor-index-bitwidth", executorIndexBitwidth,
             llvm::cl::init(64));
   addOption("device-compute-capability", deviceComputeCapability,
@@ -303,6 +307,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
   plan::StablehloClusteringPassOptions clusteringOpts{};
   clusteringOpts.disallowHostTensorsInTensorRTClusters =
       opts.disallowHostTensorsInTensorRTClusters;
+  clusteringOpts.useNonDPSCallConv = opts.useNonDPSCallConv;
   clusteringOpts.entrypoint = opts.entrypoint;
   plan::buildPlanSegmentationPipeline(pm, clusteringOpts);

@@ -336,7 +341,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline(

   // Perform bufferization.
   pm.addPass(createMemRefCastEliminationPass());
-  pm.addPass(plan::createPlanAllocTensorsPass());
+  plan::PlanAllocTensorsPassOptions allocTensorsOpts{};
+  allocTensorsOpts.useNonDPSCallConv = opts.useNonDPSCallConv;
+  pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts));
   pm.addPass(plan::createPlanBufferizePass());
   pm.addPass(createMemRefCastEliminationPass());
   pm.addPass(createCanonicalizerPass());
@@ -485,13 +492,14 @@ StableHloToExecutableTask::compileStableHLOToExecutable(
     runner = pm.get();
   }

-  runner->printAsTextualPipeline(llvm::dbgs());
+  if (options.debugOptions.dumpTextualPipeline)
+    runner->printAsTextualPipeline(llvm::dbgs());

   // Setup pass manager
-  // if (failed(runner->run(module)))
-  //   return getInternalErrorStatus(
-  //       "failed to run compilation on module with symbol name: {0}",
-  //       module.getName() ? *module.getName() : "no-symbol-name");
+  if (failed(runner->run(module)))
+    return getInternalErrorStatus(
+        "failed to run compilation on module with symbol name: {0}",
+        module.getName() ? *module.getName() : "no-symbol-name");

   // Translate to Runtime Executable
   FailureOr<std::unique_ptr<runtime::ExecutableStorage>> exeStorage =
@@ -524,6 +532,11 @@ struct ClusteringPipelineCliOpts
       *this, "device-compute-capability",
       llvm::cl::desc("target device compute capability (SM version)"),
       llvm::cl::init(60)};
+  Option<bool> useNonDPSCallConv{
+      *this, "use-non-dps-call-conv",
+      llvm::cl::desc(
+          "allow tensorrt based output allocations using output allocator"),
+      llvm::cl::init(false)};
   Option<int64_t> deviceMaxSharedMemoryPerBlockKb{
       *this, "device-max-smem-per-block",
       llvm::cl::desc("max shared memory per block (in kilobytes)"),
@@ -551,6 +564,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
   opts.deviceComputeCapability = cliOpts.deviceComputeCapability;
   opts.deviceMaxSharedMemoryPerBlockKb =
       cliOpts.deviceMaxSharedMemoryPerBlockKb;
+  opts.useNonDPSCallConv = cliOpts.useNonDPSCallConv;
   opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost;
   opts.entrypoint = cliOpts.entrypoint;
   return opts;

mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp

Lines changed: 6 additions & 4 deletions
@@ -263,7 +263,7 @@ struct ConvertEnqueueAllocToCall
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);

     // Function name for the enqueue alloc operation
-    std::string funcName = "_trtrt_alloc_enqueue";
+    std::string funcName = "_trtrt_enqueue_alloc";

     // Create new operands for the call op
     SmallVector<Value> newOperands = {adaptor.getExecutionContext(),
@@ -394,7 +394,7 @@ struct ConvertEnqueueAllocToCall
           ArrayRef<OpFoldResult>{this->createIndexConstant(b, 0),
                                  rewriter.getI64IntegerAttr(offset++)});

-      Value rankValue = b.create<executor::LoadOp>(
+      [[maybe_unused]] Value rankValue = b.create<executor::LoadOp>(
           b.getI64Type(), outputDescriptors, rankOffset);
       Value intPtr = b.create<executor::LoadOp>(
           b.getI64Type(), outputDescriptors, devicePtrOffset);
@@ -429,8 +429,10 @@ struct ConvertEnqueueAllocToCall
       resultRange.append(shapes.begin(), shapes.end());
       resultRange.append(strides.begin(), strides.end());

-      Value result = b.create<executor::CreateTableOp>(executor::TableType::get(
-          b.getContext(), llvm::to_vector(TypeRange(resultRange))));
+      Value result = b.create<executor::CreateTableOp>(
+          executor::TableType::get(b.getContext(),
+                                   llvm::to_vector(TypeRange(resultRange))),
+          resultRange);

       results.push_back(result);
     }
