14
14
#define LINALG_IR_LINALGINTERFACES
15
15
16
16
include "mlir/Interfaces/DestinationStyleOpInterface.td"
17
+ include "mlir/Interfaces/IndexingMapOpInterface.td"
17
18
include "mlir/IR/OpBase.td"
18
19
19
20
// The 'LinalgContractionOpInterface' provides access to the
@@ -222,59 +223,11 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
222
223
];
223
224
}
224
225
225
- def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
226
- let description = [{
227
- Interface for operations that connect an iteration domain to operands via
228
- affine maps. Provides methods to access indexing maps between iteration
229
- domain and operand index spaces.
230
- }];
231
- let cppNamespace = "::mlir::linalg";
232
- let methods = [
233
- InterfaceMethod<
234
- /*desc=*/[{
235
- Return the indexing maps attribute within the current operation.
236
- }],
237
- /*retTy=*/"ArrayAttr",
238
- /*methodName=*/"getIndexingMaps"
239
- >,
240
- InterfaceMethod<
241
- /*desc=*/[{
242
- Return the indexing maps within the current operation.
243
- }],
244
- /*retTy=*/"SmallVector<AffineMap>",
245
- /*methodName=*/"getIndexingMapsArray",
246
- /*args=*/(ins),
247
- /*methodBody=*/"",
248
- /*defaultImplementation=*/[{
249
- auto range = $_op.getIndexingMaps()
250
- .template getAsValueRange<AffineMapAttr>();
251
- return {range.begin(), range.end()};
252
- }]
253
- >,
254
- InterfaceMethod<
255
- /*desc=*/[{
256
- Return the input or output indexing map for `opOperand`.
257
- }],
258
- /*retTy=*/"AffineMap",
259
- /*methodName=*/"getMatchingIndexingMap",
260
- /*args=*/(ins "OpOperand*":$opOperand),
261
- /*methodBody=*/"",
262
- /*defaultImplementation=*/[{
263
- assert(opOperand->getOwner() == this->getOperation());
264
- auto indexingMaps =
265
- $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
266
- return *(indexingMaps.begin() + opOperand->getOperandNumber());
267
- }]
268
- >,
269
- ];
270
- }
271
-
272
226
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
273
227
def LinalgStructuredInterface
274
- : OpInterface<"LinalgOp", [
275
- DestinationStyleOpInterface,
276
- IndexingMapOpInterface
277
- ]> {
228
+ : OpInterface<"LinalgOp",
229
+ [DestinationStyleOpInterface, IndexingMapOpInterface]
230
+ > {
278
231
let cppNamespace = "::mlir::linalg";
279
232
let methods = [
280
233
//===------------------------------------------------------------------===//
@@ -464,30 +417,6 @@ def LinalgStructuredInterface
464
417
return getBlock()->getArguments().take_back($_op.getNumDpsInits());
465
418
}]
466
419
>,
467
- InterfaceMethod<
468
- /*desc=*/[{
469
- Return the `opOperand` shape or an empty vector for scalars or vectors
470
- not wrapped within a tensor or a memref.
471
- }],
472
- /*retTy=*/"ArrayRef<int64_t>",
473
- /*methodName=*/"getShape",
474
- /*args=*/(ins "OpOperand*":$opOperand),
475
- /*methodBody=*/"",
476
- /*defaultImplementation=*/[{
477
- assert(opOperand->getOwner() == this->getOperation());
478
- Type t = opOperand->get().getType();
479
- // A VectorType is an elemental type, do not consider its rank for the operand.
480
- if (isa<VectorType>(t))
481
- return {};
482
- if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
483
- // Failsafe.
484
- assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
485
- "expected a ranked tensor or memref in LinalgInterface::getRank");
486
- return shapedType.getShape();
487
- }
488
- return {};
489
- }]
490
- >,
491
420
InterfaceMethod<
492
421
/*desc=*/[{
493
422
Return the block argument for an `opOperand`.
@@ -620,7 +549,12 @@ def LinalgStructuredInterface
620
549
/*args=*/(ins),
621
550
/*methodBody=*/"",
622
551
/*defaultImplementation=*/[{
623
- return llvm::any_of(getStaticShape(), ShapedType::isDynamic);
552
+ for (OpOperand &opOperand : this->getOperation()->getOpOperands()) {
553
+ if (auto shapedType = dyn_cast<ShapedType>(opOperand.get().getType())) {
554
+ if (ShapedType::isDynamicShape(shapedType.getShape())) return true;
555
+ }
556
+ }
557
+ return false;
624
558
}]
625
559
>,
626
560
InterfaceMethod<
@@ -738,53 +672,6 @@ def LinalgStructuredInterface
738
672
//===------------------------------------------------------------------===//
739
673
// Linalg generalization hooks.
740
674
//===------------------------------------------------------------------===//
741
- InterfaceMethod<
742
- /*desc=*/[{
743
- Hook to provide a custom AffineMap used to compute all the operand
744
- subshapes given loop bounds. This is used to answer the question: "given
745
- an iteration space over the codomain, what are the subshapes of the
746
- operands involved in the computation".
747
- The default behavior is to just concatenate all the indexing maps.
748
- A custom AffineMap allows providing a map that can be used to
749
- compute subshapes even in cases where the concatenation of indexing maps
750
- (i.e. the data traversal order) is not a simple permutation of the loop
751
- traversal order. It is then possible to define ops with skewed data
752
- traversal order for which we can still easily compute hyperrectangular
753
- loop bounds and subviews.
754
- }],
755
- /*retTy=*/"AffineMap",
756
- /*methodName=*/"getLoopsToShapesMap",
757
- /*args=*/(ins),
758
- /*methodBody=*/"",
759
- /*defaultImplementation=*/[{
760
- auto maps = $_op.getIndexingMapsArray();
761
- return concatAffineMaps(maps, $_op.getContext());
762
- }]
763
- >,
764
- InterfaceMethod<
765
- /*desc=*/[{
766
- Hook to provide a custom AffineMap used to construct the
767
- hyperrectangular loop iteration space given all the operand subshapes.
768
- This is used to answer the question:
769
- "Given a list of operand ranges, what is the subportion of the iteration
770
- space involved in the computation".
771
- This is the inverse problem of `getLoopsToShapesMap`.
772
- Return the empty AffineMap when such an AffineMap cannot be constructed.
773
- The default behavior is based on a very simple inference procedure that
774
- only works with permutation affine maps.
775
- A more advanced Tensor-Comprehension like inference is possible but has
776
- proven to be ambiguous in unfavorable case.
777
- A safer and more robust alternative is to allow each op to define
778
- its own AffineMap.
779
- }],
780
- /*retTy=*/"AffineMap",
781
- /*methodName=*/"getShapesToLoopsMap",
782
- /*args=*/(ins),
783
- /*methodBody=*/"",
784
- /*defaultImplementation=*/[{
785
- return inversePermutation(getLoopsToShapesMap());
786
- }]
787
- >,
788
675
InterfaceMethod<
789
676
/*desc=*/[{
790
677
Checks if the given operands can be dropped, and the remaining
@@ -798,39 +685,30 @@ def LinalgStructuredInterface
798
685
return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
799
686
}]
800
687
>,
688
+ //===------------------------------------------------------------------===//
689
+ // IndexingMapOpInterface interface methods implementation.
690
+ //===------------------------------------------------------------------===//
801
691
InterfaceMethod<
802
692
/*desc=*/[{
803
- Like `getShape`, but only returns statically-known information, without
804
- generating any new IR. For each shape dimension, returns >=0 if that
805
- dimension is statically known, or ShapedType::kDynamic otherwise.
806
- }],
807
- /*retTy=*/"SmallVector<int64_t>",
808
- /*methodName=*/"getStaticShape",
809
- /*args=*/(ins),
810
- /*methodBody=*/"",
811
- /*defaultImplementation=*/[{
812
- SmallVector<int64_t> res;
813
- for (OpOperand &opOperand : this->getOperation()->getOpOperands())
814
- llvm::append_range(res, getShape(&opOperand));
815
- return res;
816
- }]
817
- >,
818
- InterfaceMethod<
819
- /*desc=*/[{
820
- Returns the statically-known loop ranges. Composes
821
- `getShapesToLoopsMap()` with the result of `getStaticShape`.
822
- Returns ShapedType::kDynamic for non-statically-known loop ranges.
823
- This is expected to be called by a valid Linalg op
693
+ Return the `opOperand` shape or an empty vector for scalars or vectors
694
+ not wrapped within a tensor or a memref.
824
695
}],
825
- /*retTy=*/"SmallVector <int64_t, 4 >",
826
- /*methodName=*/"getStaticLoopRanges ",
827
- /*args=*/(ins),
696
+ /*retTy=*/"ArrayRef <int64_t>",
697
+ /*methodName=*/"getShape ",
698
+ /*args=*/(ins "OpOperand*":$opOperand ),
828
699
/*methodBody=*/"",
829
700
/*defaultImplementation=*/[{
830
- SmallVector<int64_t> viewSizes = getStaticShape();
831
- AffineMap invertedMap = getShapesToLoopsMap();
832
- assert(invertedMap && "expected a valid Linalg op to call the method");
833
- return invertedMap.compose(viewSizes);
701
+ Type t = opOperand->get().getType();
702
+ // A VectorType is an elemental type, do not consider its rank for the operand.
703
+ if (isa<VectorType>(t))
704
+ return {};
705
+ if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
706
+ // Failsafe.
707
+ assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
708
+ "expected a ranked tensor or memref in LinalgInterface::getRank");
709
+ return shapedType.getShape();
710
+ }
711
+ return {};
834
712
}]
835
713
>,
836
714
//===------------------------------------------------------------------===//
0 commit comments