Skip to content

Commit

Permalink
[stablehlo] add aten.fmod.Tensor op conversion support (llvm#3198)
Browse files Browse the repository at this point in the history
  • Loading branch information
penguin-wwy authored Apr 21, 2024
1 parent ea0ecb6 commit b6b0160
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
32 changes: 32 additions & 0 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1833,6 +1833,37 @@ LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(
return success();
}

// AtenFmodTensorOp
// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
template <>
LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
AtenFmodTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op->getLoc();
Value lhs = adaptor.getSelf();
Value rhs = adaptor.getOther();

auto resultType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);

stablehlo::MulOp mul;
auto div = rewriter.create<stablehlo::DivOp>(loc, lhs, rhs);
if (isa<mlir::FloatType>(resultType.getElementType())) {
// rounding mode is trunc
auto sign = rewriter.create<stablehlo::SignOp>(loc, div);
auto abs = rewriter.create<stablehlo::AbsOp>(loc, div);
auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
auto trunc = rewriter.create<stablehlo::MulOp>(loc, sign, floor);
mul = rewriter.create<stablehlo::MulOp>(loc, trunc, rhs);
} else {
mul = rewriter.create<stablehlo::MulOp>(loc, div, rhs);
}
rewriter.replaceOpWithNewOp<stablehlo::SubtractOp>(op, lhs, mul);
return success();
}

void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
Expand Down Expand Up @@ -1976,6 +2007,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
INSERT_ATENOP_PATTERN(AtenFlipOp);
INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
Expand Down
6 changes: 3 additions & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,9 +650,6 @@
"ElementwiseErfIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog10Module_basic",
"ElementwiseLog2IntModule_basic",
Expand Down Expand Up @@ -1056,6 +1053,9 @@
"ElementwiseExpModule_basic",
"ElementwiseFloorIntModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseGeluApproximateTanhModule_basic",
"ElementwiseGeluModule_basic",
"ElementwiseLeakyReluStaticModule_basic",
Expand Down

0 comments on commit b6b0160

Please sign in to comment.