Skip to content

Commit d31ba52

Browse files
[mlir][Interface] Factor out common IndexingMapOpInterface behavior in a new generic interface (#145313)
Refactor the verifiers to make use of the common bits and make `vector.contract` also use this interface. In the process, the confusingly named getStaticShape has disappeared. Note: the verifier for IndexingMapOpInterface is currently called manually from other verifiers as it was unclear how to avoid it taking precedence over more meaningful error messages
1 parent 0c33799 commit d31ba52

File tree

17 files changed

+369
-238
lines changed

17 files changed

+369
-238
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/ImplicitLocOpBuilder.h"
2121
#include "mlir/IR/OpDefinition.h"
2222
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
23+
#include "mlir/Interfaces/IndexingMapOpInterface.h"
2324
#include "mlir/Interfaces/InferTypeOpInterface.h"
2425
#include "mlir/Interfaces/ViewLikeInterface.h"
2526
#include "mlir/Support/RawOstreamExtras.h"

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 29 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define LINALG_IR_LINALGINTERFACES
1515

1616
include "mlir/Interfaces/DestinationStyleOpInterface.td"
17+
include "mlir/Interfaces/IndexingMapOpInterface.td"
1718
include "mlir/IR/OpBase.td"
1819

1920
// The 'LinalgContractionOpInterface' provides access to the
@@ -222,59 +223,11 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
222223
];
223224
}
224225

225-
def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
226-
let description = [{
227-
Interface for operations that connect an iteration domain to operands via
228-
affine maps. Provides methods to access indexing maps between iteration
229-
domain and operand index spaces.
230-
}];
231-
let cppNamespace = "::mlir::linalg";
232-
let methods = [
233-
InterfaceMethod<
234-
/*desc=*/[{
235-
Return the indexing maps attribute within the current operation.
236-
}],
237-
/*retTy=*/"ArrayAttr",
238-
/*methodName=*/"getIndexingMaps"
239-
>,
240-
InterfaceMethod<
241-
/*desc=*/[{
242-
Return the indexing maps within the current operation.
243-
}],
244-
/*retTy=*/"SmallVector<AffineMap>",
245-
/*methodName=*/"getIndexingMapsArray",
246-
/*args=*/(ins),
247-
/*methodBody=*/"",
248-
/*defaultImplementation=*/[{
249-
auto range = $_op.getIndexingMaps()
250-
.template getAsValueRange<AffineMapAttr>();
251-
return {range.begin(), range.end()};
252-
}]
253-
>,
254-
InterfaceMethod<
255-
/*desc=*/[{
256-
Return the input or output indexing map for `opOperand`.
257-
}],
258-
/*retTy=*/"AffineMap",
259-
/*methodName=*/"getMatchingIndexingMap",
260-
/*args=*/(ins "OpOperand*":$opOperand),
261-
/*methodBody=*/"",
262-
/*defaultImplementation=*/[{
263-
assert(opOperand->getOwner() == this->getOperation());
264-
auto indexingMaps =
265-
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
266-
return *(indexingMaps.begin() + opOperand->getOperandNumber());
267-
}]
268-
>,
269-
];
270-
}
271-
272226
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
273227
def LinalgStructuredInterface
274-
: OpInterface<"LinalgOp", [
275-
DestinationStyleOpInterface,
276-
IndexingMapOpInterface
277-
]> {
228+
: OpInterface<"LinalgOp",
229+
[DestinationStyleOpInterface, IndexingMapOpInterface]
230+
> {
278231
let cppNamespace = "::mlir::linalg";
279232
let methods = [
280233
//===------------------------------------------------------------------===//
@@ -464,30 +417,6 @@ def LinalgStructuredInterface
464417
return getBlock()->getArguments().take_back($_op.getNumDpsInits());
465418
}]
466419
>,
467-
InterfaceMethod<
468-
/*desc=*/[{
469-
Return the `opOperand` shape or an empty vector for scalars or vectors
470-
not wrapped within a tensor or a memref.
471-
}],
472-
/*retTy=*/"ArrayRef<int64_t>",
473-
/*methodName=*/"getShape",
474-
/*args=*/(ins "OpOperand*":$opOperand),
475-
/*methodBody=*/"",
476-
/*defaultImplementation=*/[{
477-
assert(opOperand->getOwner() == this->getOperation());
478-
Type t = opOperand->get().getType();
479-
// A VectorType is an elemental type, do not consider its rank for the operand.
480-
if (isa<VectorType>(t))
481-
return {};
482-
if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
483-
// Failsafe.
484-
assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
485-
"expected a ranked tensor or memref in LinalgInterface::getRank");
486-
return shapedType.getShape();
487-
}
488-
return {};
489-
}]
490-
>,
491420
InterfaceMethod<
492421
/*desc=*/[{
493422
Return the block argument for an `opOperand`.
@@ -620,7 +549,12 @@ def LinalgStructuredInterface
620549
/*args=*/(ins),
621550
/*methodBody=*/"",
622551
/*defaultImplementation=*/[{
623-
return llvm::any_of(getStaticShape(), ShapedType::isDynamic);
552+
for (OpOperand &opOperand : this->getOperation()->getOpOperands()) {
553+
if (auto shapedType = dyn_cast<ShapedType>(opOperand.get().getType())) {
554+
if (ShapedType::isDynamicShape(shapedType.getShape())) return true;
555+
}
556+
}
557+
return false;
624558
}]
625559
>,
626560
InterfaceMethod<
@@ -738,53 +672,6 @@ def LinalgStructuredInterface
738672
//===------------------------------------------------------------------===//
739673
// Linalg generalization hooks.
740674
//===------------------------------------------------------------------===//
741-
InterfaceMethod<
742-
/*desc=*/[{
743-
Hook to provide a custom AffineMap used to compute all the operand
744-
subshapes given loop bounds. This is used to answer the question: "given
745-
an iteration space over the codomain, what are the subshapes of the
746-
operands involved in the computation".
747-
The default behavior is to just concatenate all the indexing maps.
748-
A custom AffineMap allows providing a map that can be used to
749-
compute subshapes even in cases where the concatenation of indexing maps
750-
(i.e. the data traversal order) is not a simple permutation of the loop
751-
traversal order. It is then possible to define ops with skewed data
752-
traversal order for which we can still easily compute hyperrectangular
753-
loop bounds and subviews.
754-
}],
755-
/*retTy=*/"AffineMap",
756-
/*methodName=*/"getLoopsToShapesMap",
757-
/*args=*/(ins),
758-
/*methodBody=*/"",
759-
/*defaultImplementation=*/[{
760-
auto maps = $_op.getIndexingMapsArray();
761-
return concatAffineMaps(maps, $_op.getContext());
762-
}]
763-
>,
764-
InterfaceMethod<
765-
/*desc=*/[{
766-
Hook to provide a custom AffineMap used to construct the
767-
hyperrectangular loop iteration space given all the operand subshapes.
768-
This is used to answer the question:
769-
"Given a list of operand ranges, what is the subportion of the iteration
770-
space involved in the computation".
771-
This is the inverse problem of `getLoopsToShapesMap`.
772-
Return the empty AffineMap when such an AffineMap cannot be constructed.
773-
The default behavior is based on a very simple inference procedure that
774-
only works with permutation affine maps.
775-
A more advanced Tensor-Comprehension like inference is possible but has
776-
proven to be ambiguous in unfavorable case.
777-
A safer and more robust alternative is to allow each op to define
778-
its own AffineMap.
779-
}],
780-
/*retTy=*/"AffineMap",
781-
/*methodName=*/"getShapesToLoopsMap",
782-
/*args=*/(ins),
783-
/*methodBody=*/"",
784-
/*defaultImplementation=*/[{
785-
return inversePermutation(getLoopsToShapesMap());
786-
}]
787-
>,
788675
InterfaceMethod<
789676
/*desc=*/[{
790677
Checks if the given operands can be dropped, and the remaining
@@ -798,39 +685,30 @@ def LinalgStructuredInterface
798685
return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
799686
}]
800687
>,
688+
//===------------------------------------------------------------------===//
689+
// IndexingMapOpInterface interface methods implementation.
690+
//===------------------------------------------------------------------===//
801691
InterfaceMethod<
802692
/*desc=*/[{
803-
Like `getShape`, but only returns statically-known information, without
804-
generating any new IR. For each shape dimension, returns >=0 if that
805-
dimension is statically known, or ShapedType::kDynamic otherwise.
806-
}],
807-
/*retTy=*/"SmallVector<int64_t>",
808-
/*methodName=*/"getStaticShape",
809-
/*args=*/(ins),
810-
/*methodBody=*/"",
811-
/*defaultImplementation=*/[{
812-
SmallVector<int64_t> res;
813-
for (OpOperand &opOperand : this->getOperation()->getOpOperands())
814-
llvm::append_range(res, getShape(&opOperand));
815-
return res;
816-
}]
817-
>,
818-
InterfaceMethod<
819-
/*desc=*/[{
820-
Returns the statically-known loop ranges. Composes
821-
`getShapesToLoopsMap()` with the result of `getStaticShape`.
822-
Returns ShapedType::kDynamic for non-statically-known loop ranges.
823-
This is expected to be called by a valid Linalg op
693+
Return the `opOperand` shape or an empty vector for scalars or vectors
694+
not wrapped within a tensor or a memref.
824695
}],
825-
/*retTy=*/"SmallVector<int64_t, 4>",
826-
/*methodName=*/"getStaticLoopRanges",
827-
/*args=*/(ins),
696+
/*retTy=*/"ArrayRef<int64_t>",
697+
/*methodName=*/"getShape",
698+
/*args=*/(ins "OpOperand*":$opOperand),
828699
/*methodBody=*/"",
829700
/*defaultImplementation=*/[{
830-
SmallVector<int64_t> viewSizes = getStaticShape();
831-
AffineMap invertedMap = getShapesToLoopsMap();
832-
assert(invertedMap && "expected a valid Linalg op to call the method");
833-
return invertedMap.compose(viewSizes);
701+
Type t = opOperand->get().getType();
702+
// A VectorType is an elemental type, do not consider its rank for the operand.
703+
if (isa<VectorType>(t))
704+
return {};
705+
if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
706+
// Failsafe.
707+
assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
708+
"expected a ranked tensor or memref in LinalgInterface::getRank");
709+
return shapedType.getShape();
710+
}
711+
return {};
834712
}]
835713
>,
836714
//===------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/IR/PatternMatch.h"
2626
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2727
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
28+
#include "mlir/Interfaces/IndexingMapOpInterface.h"
2829
#include "mlir/Interfaces/InferTypeOpInterface.h"
2930
#include "mlir/Interfaces/SideEffectInterfaces.h"
3031
#include "mlir/Interfaces/VectorInterfaces.h"

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ include "mlir/Dialect/Vector/IR/Vector.td"
2121
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
2222
include "mlir/Interfaces/ControlFlowInterfaces.td"
2323
include "mlir/Interfaces/DestinationStyleOpInterface.td"
24+
include "mlir/Interfaces/IndexingMapOpInterface.td"
2425
include "mlir/Interfaces/InferIntRangeInterface.td"
2526
include "mlir/Interfaces/InferTypeOpInterface.td"
2627
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -33,6 +34,7 @@ include "mlir/IR/EnumAttr.td"
3334
// than the current set: {*, +}.
3435
def Vector_ContractionOp :
3536
Vector_Op<"contract", [
37+
IndexingMapOpInterface,
3638
Pure,
3739
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
3840
PredOpTrait<"third operand acc and result have same element type",
@@ -207,6 +209,16 @@ def Vector_ContractionOp :
207209
.template getAsValueRange<IteratorTypeAttr, IteratorType>();
208210
return {range.begin(), range.end()};
209211
}
212+
213+
//===------------------------------------------------------------------===//
214+
// IndexingMapOpInterface interface methods implementation.
215+
//===------------------------------------------------------------------===//
216+
ArrayRef<int64_t> getShape(OpOperand * opOperand) {
217+
Type t = opOperand->get().getType();
218+
if (auto vt = dyn_cast<VectorType>(t))
219+
return vt.getShape();
220+
return {};
221+
}
210222
}];
211223

212224
let hasCanonicalizer = 1;

mlir/include/mlir/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_mlir_interface(CopyOpInterface)
55
add_mlir_interface(DerivedAttributeOpInterface)
66
add_mlir_interface(DestinationStyleOpInterface)
77
add_mlir_interface(FunctionInterfaces)
8+
add_mlir_interface(IndexingMapOpInterface)
89
add_mlir_interface(InferIntRangeInterface)
910
add_mlir_interface(InferTypeOpInterface)
1011
add_mlir_interface(LoopLikeInterface)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- IndexingMapOpInterface.h ----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
10+
#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
11+
12+
#include "mlir/IR/AffineMap.h"
13+
#include "mlir/IR/BuiltinAttributes.h"
14+
#include "mlir/IR/BuiltinTypes.h"
15+
#include "mlir/IR/OpDefinition.h"
16+
17+
namespace mlir {
18+
namespace detail {
19+
/// Verify that `op` conforms to the invariants of StructuredOpInterface
20+
LogicalResult verifyIndexingMapOpInterface(Operation *op);
21+
} // namespace detail
22+
} // namespace mlir
23+
24+
/// Include the generated interface declarations.
25+
#include "mlir/Interfaces/IndexingMapOpInterface.h.inc"
26+
27+
#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_

0 commit comments

Comments
 (0)