From 06e4ca7d3d2e882000fa432eae261d7af52e9072 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Mon, 25 Nov 2024 08:57:46 +0000 Subject: [PATCH] Return a failure instead of crashing if shape inference can not be run because of unranked operand types Signed-off-by: Jonas Rickert --- src/Dialect/ONNX/ONNXOps/Math/DFT.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp | 9 ++++++++- src/Dialect/ONNX/ONNXOps/Math/TopK.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/NN/Conv.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc | 3 +++ src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp | 10 +++++++--- src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp | 5 +++++ src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Tensor/NonZero.cpp | 6 +++++- src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp | 8 ++++++-- src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp | 4 +++- src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp | 3 +++ src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp | 3 +++ 22 files changed, 82 insertions(+), 8 deletions(-) diff --git a/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp b/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp index 537ca3a018..787f18ae70 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp @@ -33,6 +33,9 @@ LogicalResult ONNXGenericDFTOpShapeHelper::customComputeShape( // Get info about input data operand. Value input = operandAdaptor.getInput(); // Get the rank to compensate for N dimensions. 
+ if (!hasShapeAndRank(input)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(input); // Check if the dimension for axis is a literal and in range. diff --git a/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp b/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp index c17b62679d..d9fe81fce1 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp @@ -55,6 +55,9 @@ LogicalResult ONNXGenericMatMulOpShapeHelper::computeShape() { std::tie(A, B) = matMulInputs(operandAdaptor); // Size all the arrays to padded length. + if (!hasShapeAndRank(A) || !hasShapeAndRank(B)) { + return failure(); + } uint64_t aRank = createIE->getShapedTypeRank(A); uint64_t bRank = createIE->getShapedTypeRank(B); int paddedRank = std::max(aRank, bRank); diff --git a/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp b/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp index d988f67bc1..536a585f40 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp @@ -29,6 +29,9 @@ LogicalResult ONNXGenericReductionOpShapeHelper::customComputeShape( DimsExpr &axes, int noopWithEmptyAxes) { typename OP_TYPE::Adaptor operandAdaptor(operands, op->getAttrDictionary()); Value data = operandAdaptor.getData(); + if (!hasShapeAndRank(data)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(data); // Normalize the axes: at present, we only support compile time axes, but // with keep_dim on, it might not be too difficult to generate the code. 
@@ -104,7 +107,11 @@ LogicalResult ONNXGenericReductionOpShapeHelper::computeShape() { createIE->getIntFromArrayAsSymbols(operandAdaptor.getAxes(), axes); } else { // When the axis is dynamic, try to infer the rank of output tensor - int64_t dataRank = createIE->getShapedTypeRank(operandAdaptor.getData()); + const auto data = operandAdaptor.getData(); + if (!hasShapeAndRank(data)) { + return failure(); + } + int64_t dataRank = createIE->getShapedTypeRank(data); int64_t axlesSize = createIE->getArraySize(operandAdaptor.getAxes()); if (!operandAdaptor.getKeepdims() && axlesSize < 0 /*undef shape*/) { // Even though we did not compute the shape in ShapeHelper, return diff --git a/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp b/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp index 641faa1e4d..98e0ec45f6 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp @@ -31,6 +31,9 @@ LogicalResult ONNXTopKOpShapeHelper::computeShape() { // Get info about X and K operands. Value X = operandAdaptor.getX(); Value K = operandAdaptor.getK(); + if (!hasShapeAndRank(X)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(X); // Axis to compute TopK. diff --git a/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp b/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp index 558ee9a7ad..0539ad446c 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp @@ -374,6 +374,9 @@ LogicalResult ONNXConvTransposeOpShapeHelper::computeShape() { Value wValue = operandAdaptor.getW(); // Basic information. 
+ if (!hasShapeAndRank(xValue)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(xValue); int64_t spatialOffset = 2; int64_t spatialRank = rank - spatialOffset; diff --git a/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc b/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc index 6a0ee296c6..688624d3f3 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc +++ b/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc @@ -31,6 +31,9 @@ LogicalResult ONNXGenericPoolOpShapeHelper::customComputeShape( std::optional strideOpt, std::optional dilationOpt, bool hasFilter, bool ceilMode) { // Basic information. + if(!hasShapeAndRank(xValue)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(xValue); int64_t spatialOffset = 2; int64_t spatialRank = rank - spatialOffset; diff --git a/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp b/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp index 1a996de63d..59ef14d10f 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp @@ -48,14 +48,18 @@ LogicalResult ONNXGenericGlobalPoolOpShapeHelper::computeShape() { template <> LogicalResult ONNXMaxRoiPoolOpShapeHelper::computeShape() { ONNXMaxRoiPoolOpAdaptor operandAdaptor(operands, op->getAttrDictionary()); - IndexExpr channel = createIE->getShapeAsDim(operandAdaptor.getX(), 1); - uint64_t roisRank = createIE->getShapedTypeRank(operandAdaptor.getRois()); + + const auto rois = operandAdaptor.getRois(); + if (!hasShapeAndRank(rois)) { + return failure(); + } + uint64_t roisRank = createIE->getShapedTypeRank(rois); if (roisRank != 2) return op->emitError("rois rank is expected to be 2d"); // 2d tensor: (num_rois, 5) - IndexExpr numRois = createIE->getShapeAsDim(operandAdaptor.getRois(), 0); + IndexExpr numRois = createIE->getShapeAsDim(rois, 0); DimsExpr pooledDims; createIE->getIntFromArrayAsLiterals( operandAdaptor.getPooledShape(), pooledDims); diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp 
index 4a144ca2f4..1b0cd1dc34 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp @@ -285,6 +285,11 @@ LogicalResult ONNXBroadcastOpShapeHelper::customComputeShape( DimsExpr dimsExpr; uint64_t numOfInputs = initialOperands.size(); + if (!llvm::all_of(initialOperands, + [](Value initalOperand) { return hasShapeAndRank(initalOperand); })) { + return failure(); + } + // Compute rank of the output. Rank of the output is the maximum rank of all // initial operands. uint64_t additionalOperRank = diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp index ec6a7a1cb4..0bf3b36999 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp @@ -31,6 +31,9 @@ LogicalResult ONNXCompressOpShapeHelper::computeShape() { ONNXCompressOpAdaptor operandAdaptor(operands); Value input = operandAdaptor.getInput(); Value cond = operandAdaptor.getCondition(); + if (!hasShapeAndRank(input)) { + return failure(); + } int64_t inputRank = createIE->getShapedTypeRank(input); createIE->assertHasShapeAndRank(cond); std::optional optionalAxis = compressOp.getAxis(); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp index 68c9a213a7..515d947efb 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp @@ -30,6 +30,9 @@ LogicalResult ONNXDepthToSpaceOpShapeHelper::computeShape() { ONNXDepthToSpaceOp depthOp = llvm::cast(op); ONNXDepthToSpaceOpAdaptor operandAdaptor(operands); Value input = operandAdaptor.getInput(); + if (!hasShapeAndRank(input)) { + return failure(); + } int64_t inputRank = createIE->getShapedTypeRank(input); assert(inputRank == 4 && "Unexpected input tensor rank"); int64_t blocksize = depthOp.getBlocksize(); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/NonZero.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/NonZero.cpp index 
2c245032f0..bc9809fea0 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/NonZero.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/NonZero.cpp @@ -27,7 +27,11 @@ namespace onnx_mlir { template <> LogicalResult ONNXNonZeroOpShapeHelper::computeShape() { ONNXNonZeroOpAdaptor operandAdaptor(operands); - int64_t xRank = createIE->getShapedTypeRank(operandAdaptor.getX()); + auto x = operandAdaptor.getX(); + if (!hasShapeAndRank(x)) { + return failure(); + } + int64_t xRank = createIE->getShapedTypeRank(x); // Cannot refine shape as we may otherwise loose the dynamic dim. return setOutputDimsFromLiterals( {xRank, ShapedType::kDynamic}, 0, /*refineShape*/ false); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp index 3b1699b35e..80474ef9e9 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp @@ -28,6 +28,9 @@ LogicalResult ONNXOneHotOpShapeHelper::computeShape() { ONNXOneHotOp oneHotOp = llvm::cast(op); ONNXOneHotOpAdaptor operandAdaptor(operands); Value indices = operandAdaptor.getIndices(); + if (!hasShapeAndRank(indices)) { + return failure(); + } int64_t indicesRank = createIE->getShapedTypeRank(indices); // Axis is a required attribute and should have default value of -1. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp index 16e4713a91..3a384a5db2 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp @@ -31,6 +31,9 @@ LogicalResult ONNXPadOpShapeHelper::computeShape() { DimsExpr outputDims; // Get info about input data operand. 
+ if (!hasShapeAndRank(dataOperand)) { + return failure(); + } uint64_t dataRank = createIE->getShapedTypeRank(dataOperand); // Initialize context and results (pads & output) diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp index 761d80c503..d4559e37de 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp @@ -48,9 +48,13 @@ LogicalResult ONNXResizeOpShapeHelper::computeShape() { ONNXResizeOpAdaptor operandAdaptor(operands, cast(op)); if (operandAdaptor.getAxes().has_value()) return op->emitOpError("axes are unsupported"); - uint64_t rank = createIE->getShapedTypeRank(operandAdaptor.getX()); + const auto x = operandAdaptor.getX(); + if (!hasShapeAndRank(x)) { + return failure(); + } + uint64_t rank = createIE->getShapedTypeRank(x); DimsExpr inputDims, outputDims; - createIE->getShapeAsDims(operandAdaptor.getX(), inputDims); + createIE->getShapeAsDims(x, inputDims); bool scalesIsAbsent = isAbsent(operandAdaptor.getScales()); if (!scalesIsAbsent) { // Read and save scales as float. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp index 374b226b43..187daea201 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp @@ -52,6 +52,9 @@ LogicalResult ONNXShapeOpShapeHelper::computeShape() { Value data = operandAdaptor.getData(); // Compute and store start/end in ONNXShapeOpShapeHelper object. 
+ if (!hasShapeAndRank(data)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(data); start = shapeOp.getStart(); start = normalizeClampedPerSpec(start, rank); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp index ce3f4f84d0..55cb4bdd35 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp @@ -33,7 +33,9 @@ LogicalResult ONNXSpaceToDepthOpShapeHelper::computeShape() { Value input = operandAdaptor.getInput(); int64_t blocksize = operandAdaptor.getBlocksize(); assert(blocksize > 0 && "blocksize should be strictly positive"); - + if (!hasShapeAndRank(input)) { + return failure(); + } int64_t inputRank = createIE->getShapedTypeRank(input); assert(inputRank == 4 && "Unexpected input tensor rank"); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp index be2eeaa887..1a9927afe2 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp @@ -33,6 +33,9 @@ LogicalResult ONNXCommonSplitOpShapeHelper::customComputeShape( unsigned int numOfResults = splitOp.getNumResults(); Value input = operandAdaptor.getInput(); + if (!hasShapeAndRank(input)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(input); // Checking value of axis parameter. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp index 786f1e136a..a3c8221cf3 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp @@ -42,6 +42,9 @@ LogicalResult ONNXCommonSqueezeOpShapeHelper::customComputeShape( typename OP_TYPE::Adaptor operandAdaptor(operands, op->getAttrDictionary()); DimsExpr outputDims; Value data = operandAdaptor.getData(); + if (!hasShapeAndRank(data)) { + return failure(); + } int64_t dataRank = createIE->getShapedTypeRank(data); // Init state. 
diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp index 96f403c409..6819b4c81f 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp @@ -29,6 +29,9 @@ LogicalResult ONNXTileOpShapeHelper::computeShape() { ONNXTileOpAdaptor operandAdaptor(operands); // Get info about input data operand. Value input = operandAdaptor.getInput(); + if (!hasShapeAndRank(input)) { + return failure(); + } int64_t inputRank = createIE->getShapedTypeRank(input); Value repeats = operandAdaptor.getRepeats(); // Compute outputDims diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp index 50e8663983..05a11d8189 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp @@ -30,6 +30,9 @@ LogicalResult ONNXTransposeOpShapeHelper::computeShape() { ONNXTransposeOp transposeOp = llvm::cast(op); Value data = operandAdaptor.getData(); + if (!hasShapeAndRank(data)) { + return failure(); + } auto rank = createIE->getShapedTypeRank(data); // Transposition which handles the default case of diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp index f9177073e0..842ed0b767 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp @@ -22,6 +22,9 @@ LogicalResult ONNXUniqueOpShapeHelper::computeShape() { ONNXUniqueOpAdaptor operandAdaptor(operands, op->getAttrDictionary()); // Get info about X and K operands. Value X = operandAdaptor.getX(); + if (!hasShapeAndRank(X)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(X); std::optional optionalAxis = operandAdaptor.getAxis(); // Generate the output dims. 
diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp index fa6a46cdc5..3bb049aa69 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp @@ -33,6 +33,9 @@ LogicalResult ONNXCommonUnsqueezeOpShapeHelper::customComputeShape( typename OP_TYPE::Adaptor operandAdaptor(operands, op->getAttrDictionary()); DimsExpr outputDims; Value data = operandAdaptor.getData(); + if (!hasShapeAndRank(data)) { + return failure(); + } int64_t dataRank = createIE->getShapedTypeRank(data); // Init state.