Skip to content

Commit

Permalink
[stablehlo] add aten.clamp.Tensor op conversion support (llvm#3185)
Browse files Browse the repository at this point in the history
  • Loading branch information
penguin-wwy authored Apr 19, 2024
1 parent be742a9 commit 6c4f7de
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
36 changes: 36 additions & 0 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1470,6 +1470,41 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
return success();
}

// AtenClampTensorOp
template <>
LogicalResult ConvertAtenOp<AtenClampTensorOp>::matchAndRewrite(
AtenClampTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputType = cast<RankedTensorType>(input.getType());
auto inputElemType = inputType.getElementType();
Value minValue = adaptor.getMin();
Value maxValue = adaptor.getMax();
auto minIsNotNone = checkNotNone(rewriter, op, minValue);
auto maxIsNotNone = checkNotNone(rewriter, op, maxValue);
if (failed(minIsNotNone) && failed(maxIsNotNone)) {
return rewriter.notifyMatchFailure(
op, "this op should be folded as its `min` and `max` both are none");
} else if (failed(minIsNotNone)) {
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
if (failed(minInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to generate min value of dtype");
}
minValue = *minInfo;
} else if (failed(maxIsNotNone)) {
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
if (failed(maxInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to generate max value of dtype");
}
maxValue = *maxInfo;
}
rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
maxValue);
return success();
}

// AtenArangeStartStepOp
// aten.arange.start_step = range(ceil((end-start)/step)) * step + start.
template <>
Expand Down Expand Up @@ -1906,6 +1941,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(

INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenClampOp);
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);

INSERT_ATENOP_PATTERN(AtenBatchNormOp);
Expand Down
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,11 @@
"ElementwiseCeilModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampMinTensorFloatModule_basic",
"ElementwiseClampMinTensorIntModule_basic",
"ElementwiseClampModule_basic",
"ElementwiseClampTensorFloatModule_basic",
"ElementwiseClampTensorIntModule_basic",
"ElementwiseClampTensorInt8Module_basic",
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneContiguousModule_basic",
Expand Down

0 comments on commit 6c4f7de

Please sign in to comment.