@@ -131,10 +131,57 @@ def Plan_InlineGroupOp : Plan_GroupOpBase<"inline_group", [
131
131
}
132
132
133
133
//===----------------------------------------------------------------------===//
134
- // InlineClosedGroupOp
134
+ // Plan_InlineClosedGroupBase
135
135
//===----------------------------------------------------------------------===//
136
136
137
- def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
137
+ class Plan_InlineClosedGroupBase<string mnemonic, list<Trait> traits = []> :
138
+ Plan_GroupOpBase<mnemonic, traits> {
139
+
140
+ code baseInlineClosedExtraClassDeclaration = baseExtraClassDeclaration # [{
141
+ // Common methods for both DPS and non-DPS versions
142
+ bool argHasTensorType(unsigned inputIdx) {
143
+ assert(inputIdx < getInputs().size() && "input index out-of-bounds");
144
+ return isa<RankedTensorType>(getInputs()[inputIdx].getType());
145
+ }
146
+
147
+ BoundsAttr getInputBoundsAttr(unsigned inputIdx) {
148
+ assert(inputIdx < getInputs().size() && "input index out-of-bounds");
149
+ return cast<BoundsAttr>(getInputAttrs()[inputIdx]);
150
+ }
151
+
152
+ /// Populate the `input_attrs` from an array of BoundsAttrs.
153
+ void setInputAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
154
+ setInputAttrsAttr(::mlir::ArrayAttr::get(
155
+ getOperation()->getContext(),
156
+ ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
157
+ ));
158
+ }
159
+
160
+ void getSuccessorRegionsBase(RegionBranchPoint point,
161
+ SmallVectorImpl<RegionSuccessor> ®ions) {
162
+ // If the predecessor is the InlineClosedGroupOp, branch into the body.
163
+ if (point.isParent()) {
164
+ regions.push_back(RegionSuccessor(&getBody(), getBody().getArguments()));
165
+ return;
166
+ }
167
+
168
+ // Otherwise, the region branches back to the parent operation.
169
+ regions.push_back(RegionSuccessor(getResults()));
170
+ }
171
+
172
+ OperandRange getEntrySuccessorOperandsBase(RegionBranchPoint point) {
173
+ return getOperands();
174
+ }
175
+ }];
176
+
177
+ let extraClassDeclaration = baseInlineClosedExtraClassDeclaration;
178
+ }
179
+
180
+ //===----------------------------------------------------------------------===//
181
+ // Plan_InlineClosedGroupOp
182
+ //===----------------------------------------------------------------------===//
183
+
184
+ def Plan_InlineClosedGroupOp : Plan_InlineClosedGroupBase<"inline_closed_group", [
138
185
IsolatedFromAbove,
139
186
AttrSizedOperandSegments,
140
187
DestinationStyleOpInterface,
@@ -226,24 +273,12 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
226
273
CArg<"ArrayRef<BoundsAttr>", "{}">:$res_attrs)>
227
274
];
228
275
229
- let extraClassDeclaration = baseExtraClassDeclaration # [{
276
+ let extraClassDeclaration = baseInlineClosedExtraClassDeclaration # [{
230
277
231
278
MutableOperandRange getDpsInitsMutable() {
232
279
return getOutsMutable();
233
280
}
234
281
235
- /// Returns true if the `i-th` input argument has a tensor type.
236
- bool argHasTensorType(unsigned inputIdx) {
237
- assert(inputIdx < getInputs().size() && "input index out-of-bounds");
238
- return isa<RankedTensorType>(getInputs()[inputIdx].getType());
239
- }
240
-
241
- /// Returns the i-th input argument's bounds attribute.
242
- BoundsAttr getInputBoundsAttr(unsigned inputIdx) {
243
- assert(inputIdx < getInputs().size() && "input index out-of-bounds");
244
- return cast<BoundsAttr>(getInputAttrs()[inputIdx]);
245
- }
246
-
247
282
ArrayRef<BlockArgument> getRegionOutArgs() {
248
283
return getBody().getArguments().take_back(getOuts().size());
249
284
}
@@ -255,16 +290,75 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
255
290
ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
256
291
));
257
292
}
293
+ }];
294
+ }
258
295
259
- /// Populate the `input_attrs` from an array of BoundsAttrs.
260
- void setInputAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
261
- setInputAttrsAttr(::mlir::ArrayAttr::get(
262
- getOperation()->getContext(),
263
- ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
264
- ));
265
- }
296
+ //===----------------------------------------------------------------------===//
297
+ // InlineClosedGroupNonDPSOp
298
+ //===----------------------------------------------------------------------===//
299
+
300
+ def Plan_InlineClosedGroupNonDPSOp : Plan_InlineClosedGroupBase<"inline_closed_group_non_dps", [
301
+ IsolatedFromAbove,
302
+ SingleBlockImplicitTerminator<"plan::YieldOp">,
303
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
304
+ ["getEntrySuccessorOperands"]>,
305
+ DeclareOpInterfaceMethods<OpAsmOpInterface,
306
+ ["getAsmBlockArgumentNames"]>
307
+ ]> {
308
+ let description = [{
309
+ The `plan.inline_closed_group_non_dps` operation is a variant of the
310
+ `plan.inline_closed_group` operation that does not use destination-passing style
311
+ (DPS). It is isolated from above and explicitly captures input operands,
312
+ but unlike its DPS counterpart, it does not capture destination operands.
313
+ This operation takes input operands and their corresponding bounds attributes,
314
+ and produces results. The `input_attrs` hold bounds attribute information for
315
+ the input operands. The absence of bounds information is allowed (`none` bounds).
316
+
317
+ The `target` attribute specifies the execution target for the group.
318
+
319
+ #### Example
320
+
321
+ Consider the following simple program containing operations with dynamically shaped operands:
322
+
323
+ ```mlir
324
+ %0 = ... : tensor<?xf32> // A dynamically shaped operand
325
+ %1 = ... : index // A dynamic calculation of %0's extent
326
+
327
+ %2 = plan.inline_closed_group_non_dps target(#plan.cluster_target<tensorrt>)
328
+ inputs(%0, %1 : tensor<?xf32>, index)
329
+ in_attrs [#plan.bounds<shape, , >, #plan.bounds<none>] -> tensor<?xf32> {
330
+ %3 = plan.with_shape %0 (%1) : (tensor<?xf32>, index) -> tensor<?xf32>
331
+ %4 = stablehlo.exponential %3 : tensor<?xf32>
332
+ yield %4 : tensor<?xf32>
333
+ }
266
334
267
335
}];
336
+ let arguments = (ins Variadic<AnyTypeOf<[AnyRankedTensor, AnySignlessIntegerOrIndex]>>:$inputs,
337
+ BoundsAttrArray:$input_attrs,
338
+ AnyAttr:$target);
339
+
340
+ let results = (outs Variadic<AnyTypeOf<[AnyRankedTensor]>>:$results);
341
+
342
+ let assemblyFormat = [{
343
+ `target` `(` $target `)` `\n`
344
+ `inputs` `(` ( $inputs^ `:` type($inputs) `)` ) : ( `)` ) ? `\n`
345
+ `in_attrs` $input_attrs `\n`
346
+ attr-dict-with-keyword `->` type($results)
347
+ $body
348
+ }];
349
+
350
+ let hasVerifier = 1;
351
+
352
+ let skipDefaultBuilders = 1;
353
+
354
+ let builders = [
355
+ OpBuilder<(ins "TypeRange":$results,
356
+ "Attribute":$target,
357
+ "ValueRange":$inputs,
358
+ CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs)>,
359
+ ];
360
+
361
+ let extraClassDeclaration = baseInlineClosedExtraClassDeclaration;
268
362
}
269
363
270
364
//===----------------------------------------------------------------------===//
@@ -276,7 +370,7 @@ def Plan_YieldOp : Plan_Op<"yield", [
276
370
Terminator,
277
371
ReturnLike,
278
372
ParentOneOf<["plan::InlineGroupOp",
279
- "plan::InlineClosedGroupOp"]>]> {
373
+ "plan::InlineClosedGroupOp", "plan::InlineClosedGroupNonDPSOp" ]>]> {
280
374
281
375
let arguments = (ins Variadic<AnyType>:$results);
282
376
0 commit comments