diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 4c783fd3b495..a4757463d0a2 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6351,6 +6351,7 @@ def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
       printDefaultTorchOp(printer, *this, 4, 1);
     }
   }];
+  let hasVerifier = 1;
 }
 
 def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index a000b7ab2f98..53598e451751 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -158,6 +158,10 @@ LogicalResult getPermutedType(BaseTensorType inType,
                               SmallVector<int64_t> permuteDims,
                               Type &permutedType);
 
+// Check whether the given shapes of two tensors are broadcastable.
+LogicalResult areStaticallyBroadcastCompatible(ArrayRef<int64_t> shapeA,
+                                               ArrayRef<int64_t> shapeB);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index eb2f697c2596..64ca695b8178 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -6086,6 +6086,91 @@ LogicalResult AtenCountNonzeroDimIntListOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtenIndexPutOp
+//===----------------------------------------------------------------------===//
+
+// Determine the common broadcast shape of all the index tensors.
+SmallVector<int64_t>
+getIndexBroadcastShape(SmallVector<BaseTensorType> indicesTypes) {
+  int64_t indicesBroadcastRank = 0;
+  SmallVector<int64_t> indicesRank;
+  SmallVector<ArrayRef<int64_t>> indicesShape;
+  for (auto indexTy : indicesTypes) {
+    indicesShape.push_back(indexTy.getSizes());
+    int64_t rank = indexTy.getSizes().size();
+    indicesRank.push_back(rank);
+    indicesBroadcastRank = std::max(rank, indicesBroadcastRank);
+  }
+
+  auto maxDim = [](int64_t dim0, int64_t dim1) {
+    if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
+      return Torch::kUnknownSize;
+    return std::max(dim0, dim1);
+  };
+
+  SmallVector<int64_t> broadcastShape(indicesBroadcastRank, 0);
+  for (unsigned i = 0; i < indicesTypes.size(); i++) {
+    for (int32_t j = 0; j < indicesRank[i]; ++j) {
+      auto size = indicesShape[i][j];
+      int32_t idx = broadcastShape.size() - indicesRank[i] + j;
+      broadcastShape[idx] = maxDim(size, broadcastShape[idx]);
+    }
+  }
+  return broadcastShape;
+}
+
+LogicalResult AtenIndexPutOp::verify() {
+  if (isa<Torch::NoneType>(getIndices().getType()))
+    return success();
+
+  SmallVector<Value> indices;
+  if (!getListConstructElements(getIndices(), indices))
+    return success();
+
+  SmallVector<BaseTensorType> indicesTypes;
+  for (auto index : indices) {
+    // Skip the none values in the indices list.
+    if (auto indexTy = dyn_cast<BaseTensorType>(index.getType())) {
+      if (!indexTy.hasSizes())
+        return success();
+      indicesTypes.push_back(indexTy);
+    }
+  }
+
+  auto inputType = cast<BaseTensorType>(getSelf().getType());
+  if (!inputType.hasSizes())
+    return success();
+  SmallVector<int64_t> inputShape(inputType.getSizes());
+
+  auto valuesType = cast<BaseTensorType>(getValues().getType());
+  if (!valuesType.hasSizes())
+    return success();
+  SmallVector<int64_t> valuesShape(valuesType.getSizes());
+
+  SmallVector<int64_t> indicesBroadcastShape(
+      getIndexBroadcastShape(indicesTypes));
+  // In the case where the input rank is greater than the number of index
+  // tensors, the remaining dimensions of the input are indexed in their
+  // entirety. Thus, we need to append the remaining dimensions to get the
+  // shape of the indexed slice.
+  for (size_t i = indices.size(); i < inputShape.size(); i++) {
+    indicesBroadcastShape.push_back(inputShape[i]);
+  }
+
+  // Check whether the values tensor is broadcast compatible with the indexing
+  // result shape. Only the static dimensions are checked here; the dynamic
+  // ones will be caught by the downstream lowering through runtime checks.
+  if (failed(
+          areStaticallyBroadcastCompatible(valuesShape, indicesBroadcastShape)))
+    return emitOpError("values tensor shape [")
+           << valuesShape
+           << "] cannot be broadcasted to indexing result shape ["
+           << indicesBroadcastShape << "]\n";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // OnnxVariantRotaryEmbeddingOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 388e31353571..11e3884c1e4c 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -709,3 +709,29 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
     return rewriter.getI64Type();
   return inputType;
 }
+
+// Check whether the shapes of the tensors are broadcastable.
+// Two tensors are "broadcastable" if the following rules hold:
+// 1.) Each tensor has at least one dimension.
+// 2.) When iterating over the dimension sizes, starting at the trailing
+// dimension, the dimension sizes must either be equal, one of them is 1, or
+// one of them does not exist.
+LogicalResult
+Torch::areStaticallyBroadcastCompatible(ArrayRef<int64_t> shapeA,
+                                        ArrayRef<int64_t> shapeB) {
+  unsigned rankA = shapeA.size();
+  unsigned rankB = shapeB.size();
+  unsigned minRank = std::min(rankA, rankB);
+
+  for (unsigned i = 0; i < minRank; i++) {
+    int64_t dimA = shapeA[rankA - i - 1];
+    int64_t dimB = shapeB[rankB - i - 1];
+    // Here, we only check the static dimensions for compatibility.
+    if (dimA == Torch::kUnknownSize || dimB == Torch::kUnknownSize)
+      continue;
+    if (!(dimA == dimB || dimA == 1 || dimB == 1))
+      return failure();
+  }
+
+  return success();
+}
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 0f6694132aec..234cf99a9807 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -557,7 +557,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
     emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
     emit_with_mutating_variants(
-        "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)"
+        "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)",
+        has_verifier=True,
     )
     emit_with_mutating_variants(
         "aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)"
diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir
index c863e93fa5fa..c383c75f256b 100644
--- a/test/Dialect/Torch/invalid.mlir
+++ b/test/Dialect/Torch/invalid.mlir
@@ -403,3 +403,13 @@ func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
   torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
   return %arg0 : !torch.vtensor<[?],f32>
 }
+
+// -----
+
+func.func @index_put_values_shape_broadcast_incompatible(%arg0: !torch.vtensor<[?,32,16,192],f16>, %arg1: !torch.vtensor<[?],si64>, %arg2: !torch.vtensor<[?,32,128,192],f16>) -> !torch.vtensor<[?,32,16,192],f16> attributes {torch.onnx_meta.opset_version = 10 : si64} {
+  %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+  %false = torch.constant.bool false
+  // expected-error @+1 {{'torch.aten.index_put' op values tensor shape [-1, 32, 128, 192] cannot be broadcasted to indexing result shape [-1, 32, 16, 192]}}
+  %1 = torch.aten.index_put %arg0, %0, %arg2, %false : !torch.vtensor<[?,32,16,192],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,128,192],f16>, !torch.bool -> !torch.vtensor<[?,32,16,192],f16>
+  return %1 : !torch.vtensor<[?,32,16,192],f16>
+}
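Note: the check performed by the new verifier can be illustrated with a small
stand-alone Python sketch. This is an illustration only, not part of the
patch; the helper names are hypothetical, -1 plays the role of
Torch::kUnknownSize, and the sketch assumes all index tensors come first in
the indices list.

    def index_result_shape(index_shapes, input_shape):
        # Broadcast all index tensor shapes together (right-aligned), as
        # getIndexBroadcastShape does, then append the input dimensions not
        # covered by an index tensor, which are indexed in their entirety.
        rank = max(len(s) for s in index_shapes)
        result = [0] * rank
        for s in index_shapes:
            for j, size in enumerate(s):
                idx = rank - len(s) + j
                result[idx] = -1 if -1 in (size, result[idx]) else max(size, result[idx])
        return result + list(input_shape[len(index_shapes):])

    def are_statically_broadcast_compatible(shape_a, shape_b):
        # Walk both shapes from the trailing dimension; a dimension missing
        # from the shorter shape is always compatible, and unknown (-1) sizes
        # are left to the downstream lowering's runtime checks.
        for dim_a, dim_b in zip(reversed(shape_a), reversed(shape_b)):
            if -1 in (dim_a, dim_b):
                continue
            if dim_a != dim_b and dim_a != 1 and dim_b != 1:
                return False
        return True

    # The invalid.mlir test case: input [-1, 32, 16, 192] indexed by a single
    # [-1] tensor leaves dimensions 1..3 indexed in their entirety, so the
    # indexing result shape is [-1, 32, 16, 192].
    result = index_result_shape([[-1]], [-1, 32, 16, 192])
    assert result == [-1, 32, 16, 192]
    # The values tensor [-1, 32, 128, 192] cannot broadcast to it (128 vs. 16,
    # neither equal nor 1), which is exactly what the verifier reports.
    assert not are_statically_broadcast_compatible([-1, 32, 128, 192], result)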