Skip to content

Commit

Permalink
Return a failure instead of crashing if shape inference cannot be ru…
Browse files Browse the repository at this point in the history
…n because of unranked operand types (#3023)

Signed-off-by: Jonas Rickert <[email protected]>
  • Loading branch information
jorickert authored Dec 9, 2024
1 parent 32d2c8b commit 49fc9c1
Show file tree
Hide file tree
Showing 22 changed files with 82 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Math/DFT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ LogicalResult ONNXGenericDFTOpShapeHelper<OP_TYPE>::customComputeShape(
// Get info about input data operand.
Value input = operandAdaptor.getInput();
// Get the rank to compensate for N dimensions.
if (!hasShapeAndRank(input)) {
return failure();
}
int64_t rank = createIE->getShapedTypeRank(input);

// Check if the dimension for axis is a literal and in range.
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ LogicalResult ONNXGenericMatMulOpShapeHelper<OP_TYPE>::computeShape() {
std::tie(A, B) = matMulInputs(operandAdaptor);

// Size all the arrays to padded length.
if (!hasShapeAndRank(A) || !hasShapeAndRank(B)) {
return failure();
}
uint64_t aRank = createIE->getShapedTypeRank(A);
uint64_t bRank = createIE->getShapedTypeRank(B);
int paddedRank = std::max(aRank, bRank);
Expand Down
9 changes: 8 additions & 1 deletion src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ LogicalResult ONNXGenericReductionOpShapeHelper<OP_TYPE>::customComputeShape(
DimsExpr &axes, int noopWithEmptyAxes) {
typename OP_TYPE::Adaptor operandAdaptor(operands, op->getAttrDictionary());
Value data = operandAdaptor.getData();
if (!hasShapeAndRank(data)) {
return failure();
}
int64_t rank = createIE->getShapedTypeRank(data);
// Normalize the axes: at present, we only support compile time axes, but
// with keep_dim on, it might not be too difficult to generate the code.
Expand Down Expand Up @@ -104,7 +107,11 @@ LogicalResult ONNXGenericReductionOpShapeHelper<OP_TYPE>::computeShape() {
createIE->getIntFromArrayAsSymbols(operandAdaptor.getAxes(), axes);
} else {
// When the axis is dynamic, try to infer the rank of output tensor
int64_t dataRank = createIE->getShapedTypeRank(operandAdaptor.getData());
const auto data = operandAdaptor.getData();
if (!hasShapeAndRank(data)) {
return failure();
}
int64_t dataRank = createIE->getShapedTypeRank(data);
int64_t axlesSize = createIE->getArraySize(operandAdaptor.getAxes());
if (!operandAdaptor.getKeepdims() && axlesSize < 0 /*undef shape*/) {
// Even though we did not compute the shape in ShapeHelper, return
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Math/TopK.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ LogicalResult ONNXTopKOpShapeHelper::computeShape() {
// Get info about X and K operands.
Value X = operandAdaptor.getX();
Value K = operandAdaptor.getK();
if (!hasShapeAndRank(X)) {
return failure();
}
int64_t rank = createIE->getShapedTypeRank(X);

// Axis to compute TopK.
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/NN/Conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ LogicalResult ONNXConvTransposeOpShapeHelper::computeShape() {
Value wValue = operandAdaptor.getW();

// Basic information.
if (!hasShapeAndRank(xValue)) {
return failure();
}
int64_t rank = createIE->getShapedTypeRank(xValue);
int64_t spatialOffset = 2;
int64_t spatialRank = rank - spatialOffset;
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ LogicalResult ONNXGenericPoolOpShapeHelper<OP_TYPE>::customComputeShape(
std::optional<ArrayAttr> strideOpt, std::optional<ArrayAttr> dilationOpt,
bool hasFilter, bool ceilMode) {
// Basic information.
if(!hasShapeAndRank(xValue)) {
return failure();
}
int64_t rank = createIE->getShapedTypeRank(xValue);
int64_t spatialOffset = 2;
int64_t spatialRank = rank - spatialOffset;
Expand Down
10 changes: 7 additions & 3 deletions src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,18 @@ LogicalResult ONNXGenericGlobalPoolOpShapeHelper<OP_TYPE>::computeShape() {
template <>
LogicalResult ONNXMaxRoiPoolOpShapeHelper::computeShape() {
ONNXMaxRoiPoolOpAdaptor operandAdaptor(operands, op->getAttrDictionary());

IndexExpr channel = createIE->getShapeAsDim(operandAdaptor.getX(), 1);
uint64_t roisRank = createIE->getShapedTypeRank(operandAdaptor.getRois());

const auto rois = operandAdaptor.getRois();
if (!hasShapeAndRank(rois)) {
return failure();
}
uint64_t roisRank = createIE->getShapedTypeRank(rois);
if (roisRank != 2)
return op->emitError("rois rank is expected to be 2d");

// 2d tensor: (num_rois, 5)
IndexExpr numRois = createIE->getShapeAsDim(operandAdaptor.getRois(), 0);
IndexExpr numRois = createIE->getShapeAsDim(rois, 0);
DimsExpr pooledDims;
createIE->getIntFromArrayAsLiterals(
operandAdaptor.getPooledShape(), pooledDims);
Expand Down
5 changes: 5 additions & 0 deletions src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ LogicalResult ONNXBroadcastOpShapeHelper::customComputeShape(
DimsExpr dimsExpr;
uint64_t numOfInputs = initialOperands.size();

if (!llvm::all_of(initialOperands,
[](Value initalOperand) { return hasShapeAndRank(initalOperand); })) {
return failure();
}

// Compute rank of the output. Rank of the output is the maximum rank of all
// initial operands.
uint64_t additionalOperRank =
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ LogicalResult ONNXCompressOpShapeHelper::computeShape() {
ONNXCompressOpAdaptor operandAdaptor(operands);
Value input = operandAdaptor.getInput();
Value cond = operandAdaptor.getCondition();
if (!hasShapeAndRank(input)) {
return failure();
}
int64_t inputRank = createIE->getShapedTypeRank(input);
createIE->assertHasShapeAndRank(cond);
std::optional<int64_t> optionalAxis = compressOp.getAxis();
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ LogicalResult ONNXDepthToSpaceOpShapeHelper::computeShape() {
ONNXDepthToSpaceOp depthOp = llvm::cast<ONNXDepthToSpaceOp>(op);
ONNXDepthToSpaceOpAdaptor operandAdaptor(operands);
Value input = operandAdaptor.getInput();
if (!hasShapeAndRank(input)) {
return failure();
}
int64_t inputRank = createIE->getShapedTypeRank(input);
assert(inputRank == 4 && "Unexpected input tensor rank");
int64_t blocksize = depthOp.getBlocksize();
Expand Down
6 changes: 5 additions & 1 deletion src/Dialect/ONNX/ONNXOps/Tensor/NonZero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ namespace onnx_mlir {
template <>
LogicalResult ONNXNonZeroOpShapeHelper::computeShape() {
ONNXNonZeroOpAdaptor operandAdaptor(operands);
int64_t xRank = createIE->getShapedTypeRank(operandAdaptor.getX());
auto x = operandAdaptor.getX();
if (!hasShapeAndRank(x)) {
return failure();
}
int64_t xRank = createIE->getShapedTypeRank(x);
// Cannot refine shape as we may otherwise lose the dynamic dim.
return setOutputDimsFromLiterals(
{xRank, ShapedType::kDynamic}, 0, /*refineShape*/ false);
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ LogicalResult ONNXOneHotOpShapeHelper::computeShape() {
ONNXOneHotOp oneHotOp = llvm::cast<ONNXOneHotOp>(op);
ONNXOneHotOpAdaptor operandAdaptor(operands);
Value indices = operandAdaptor.getIndices();
if (!hasShapeAndRank(indices)) {
return failure();
}
int64_t indicesRank = createIE->getShapedTypeRank(indices);

// Axis is a required attribute and should have default value of -1.
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ LogicalResult ONNXPadOpShapeHelper::computeShape() {
DimsExpr outputDims;

// Get info about input data operand.
if (!hasShapeAndRank(dataOperand)) {
return failure();
}
uint64_t dataRank = createIE->getShapedTypeRank(dataOperand);

// Initialize context and results (pads & output)
Expand Down
8 changes: 6 additions & 2 deletions src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,13 @@ LogicalResult ONNXResizeOpShapeHelper::computeShape() {
ONNXResizeOpAdaptor operandAdaptor(operands, cast<ONNXResizeOp>(op));
if (operandAdaptor.getAxes().has_value())
return op->emitOpError("axes are unsupported");
uint64_t rank = createIE->getShapedTypeRank(operandAdaptor.getX());
const auto x = operandAdaptor.getX();
if (!hasShapeAndRank(x)) {
return failure();
}
uint64_t rank = createIE->getShapedTypeRank(x);
DimsExpr inputDims, outputDims;
createIE->getShapeAsDims(operandAdaptor.getX(), inputDims);
createIE->getShapeAsDims(x, inputDims);
bool scalesIsAbsent = isAbsent(operandAdaptor.getScales());
if (!scalesIsAbsent) {
// Read and save scales as float.
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ LogicalResult ONNXShapeOpShapeHelper::computeShape() {
Value data = operandAdaptor.getData();

// Compute and store start/end in ONNXShapeOpShapeHelper object.
if (!hasShapeAndRank(data)) {
return failure();
}
int64_t rank = createIE->getShapedTypeRank(data);
start = shapeOp.getStart();
start = normalizeClampedPerSpec(start, rank);
Expand Down
4 changes: 3 additions & 1 deletion src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ LogicalResult ONNXSpaceToDepthOpShapeHelper::computeShape() {
Value input = operandAdaptor.getInput();
int64_t blocksize = operandAdaptor.getBlocksize();
assert(blocksize > 0 && "blocksize should be strictly positive");

if (!hasShapeAndRank(input)) {
return failure();
}
int64_t inputRank = createIE->getShapedTypeRank(input);
assert(inputRank == 4 && "Unexpected input tensor rank");

Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ LogicalResult ONNXCommonSplitOpShapeHelper<OP_TYPE>::customComputeShape(

unsigned int numOfResults = splitOp.getNumResults();
Value input = operandAdaptor.getInput();
if (!hasShapeAndRank(input)) {
return failure();
}
int64_t rank = createIE->getShapedTypeRank(input);

// Checking value of axis parameter.
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ LogicalResult ONNXCommonSqueezeOpShapeHelper<OP_TYPE>::customComputeShape(
typename OP_TYPE::Adaptor operandAdaptor(operands, op->getAttrDictionary());
DimsExpr outputDims;
Value data = operandAdaptor.getData();
if (!hasShapeAndRank(data)) {
return failure();
}
int64_t dataRank = createIE->getShapedTypeRank(data);

// Init state.
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ LogicalResult ONNXTileOpShapeHelper::computeShape() {
ONNXTileOpAdaptor operandAdaptor(operands);
// Get info about input data operand.
Value input = operandAdaptor.getInput();
if (!hasShapeAndRank(input)) {
return failure();
}
int64_t inputRank = createIE->getShapedTypeRank(input);
Value repeats = operandAdaptor.getRepeats();
// Compute outputDims
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ LogicalResult ONNXTransposeOpShapeHelper::computeShape() {
ONNXTransposeOp transposeOp = llvm::cast<ONNXTransposeOp>(op);

Value data = operandAdaptor.getData();
if (!hasShapeAndRank(data)) {
return failure();
}
auto rank = createIE->getShapedTypeRank(data);

// Transposition which handles the default case of
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ LogicalResult ONNXUniqueOpShapeHelper::computeShape() {
ONNXUniqueOpAdaptor operandAdaptor(operands, op->getAttrDictionary());
// Get info about X and K operands.
Value X = operandAdaptor.getX();
if (!hasShapeAndRank(X)) {
return failure();
}
int64_t rank = createIE->getShapedTypeRank(X);
std::optional<int64_t> optionalAxis = operandAdaptor.getAxis();
// Generate the output dims.
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ LogicalResult ONNXCommonUnsqueezeOpShapeHelper<OP_TYPE>::customComputeShape(
typename OP_TYPE::Adaptor operandAdaptor(operands, op->getAttrDictionary());
DimsExpr outputDims;
Value data = operandAdaptor.getData();
if (!hasShapeAndRank(data)) {
return failure();
}
int64_t dataRank = createIE->getShapedTypeRank(data);

// Init state.
Expand Down

0 comments on commit 49fc9c1

Please sign in to comment.