[uncategorized_lowerings] Add lowering for torch.aten.round.decimals #3811

Draft: wants to merge 1 commit into base: main
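
This patch adds a TorchToLinalg lowering for the `torch.aten.round.decimals` operator, which reaches the conversion as a generic `torch.operator`. The elementwise payload scales the input by 10^decimals, rounds half-to-even via `math.roundeven`, and divides the result by the same scale. To keep the op from being rejected earlier in the pipeline, the `torch-reduce-op-variants` pass gains a `legal-ops` list option wired to the pipelines' `backend-legal-ops`, and `OperatorOp` legality in the Linalg conversions now also consults the type converter.

Below is a minimal scalar sketch of the decomposition, for illustration only: the pass emits the equivalent `arith`/`math` ops inside a `linalg.generic` payload, and the helper name and use of `double` here are assumptions of this sketch, not part of the change.

```cpp
#include <cmath>
#include <cstdint>

// Scalar model of the lowering: scale, round half-to-even, unscale.
// std::nearbyint follows the current rounding mode, which defaults to
// round-to-nearest-even and therefore matches math.roundeven.
double roundDecimals(double x, std::int64_t decimals) {
  double scale = std::pow(10.0, static_cast<double>(decimals));
  return std::nearbyint(x * scale) / scale;
}

// Example: roundDecimals(3.14159, 2) yields 3.14.
```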
2 changes: 1 addition & 1 deletion include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -109,7 +109,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createReduceOpVariantsPass(StringRef extraLibrary);
createReduceOpVariantsPass(StringRef extraLibrary, ArrayRef<std::string> = {});

std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();

3 changes: 3 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.td
@@ -148,6 +148,9 @@ def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
let options = [
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
"MLIR module for verifying custom op value semantics">,
ListOption<"legalOps", "legal-ops", "std::string",
"Comma separated list of operation names that should be considered legal",
"llvm::cl::ZeroOrMore">
];
let description = [{
Replaces ops with other ops to reduce the number of variants that
5 changes: 4 additions & 1 deletion lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -2874,7 +2874,10 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenDiagEmbedOp>(typeConverter, context);
// Rewrite all special sparse conversions hidden as operators.
target.addDynamicallyLegalOp<OperatorOp>([&](Torch::OperatorOp op) {
return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
// Note: The legality of torch.operator ops that are not sparse primitives
// should be preserved, not modified, by this block.
return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr()) &&
typeConverter.isLegal(op);
});
patterns.add<ConvertSparseOperatorOp>(typeConverter, context);
}
46 changes: 44 additions & 2 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1557,6 +1557,43 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
threshold);
}

if (auto operatorOp = dyn_cast<OperatorOp>(op)) {
// We do not yet implement lowering for other variants of the op.
if (operatorOp.getNameAttr().str() != "torch.aten.round.decimals")
return nullptr;

// Lower the op in a fashion similar to the decomposition described here:
// https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/_inductor/decomposition.py#L223.
// Note that `aten.round` normally lowers to `math.roundeven`; we emit
// `math.roundeven` directly here because `aten.round` cannot operate on the
// single scalar element we receive as the payload argument.

Location loc = op->getLoc();
Type i64Type = b.getI64Type();

auto torchIntOp = dyn_cast_or_null<ConstantIntOp>(
operatorOp.getOperands().back().getDefiningOp());
if (!torchIntOp)
return nullptr;
int64_t numDecimalsArg = torchIntOp.getValue();

Value inputTensorElem = payloadArgs[0];
Type inputTensorElemType = inputTensorElem.getType();

auto numDecimals = b.create<arith::ConstantOp>(
loc, i64Type, IntegerAttr::get(i64Type, numDecimalsArg));
auto const10 = b.create<arith::ConstantOp>(
loc, inputTensorElemType, FloatAttr::get(inputTensorElemType, 10));
auto tenPowDecimals = b.create<math::FPowIOp>(loc, const10, numDecimals);

auto mulTenPowDecimalsinputTensorElem =
b.create<arith::MulFOp>(loc, inputTensorElem, tenPowDecimals);
auto roundOp =
b.create<math::RoundEvenOp>(loc, mulTenPowDecimalsinputTensorElem);
auto res = b.create<arith::DivFOp>(loc, roundOp, tenPowDecimals);
return res;
}

op->emitError("unimplemented lowering in "
"createLinalgPayloadCalculationForElementwiseOp");
return nullptr;
@@ -1616,9 +1653,14 @@ class ConvertElementwiseOp : public ConversionPattern {
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp, AtenIscloseOp>(op))
AtenQuantizePerTensorOp, AtenIscloseOp, OperatorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

if (auto operatorOp = dyn_cast<OperatorOp>(op))
if (operatorOp.getNameAttr().str() != "torch.aten.round.decimals")
return rewriter.notifyMatchFailure(op,
"not a supported elementwise op");

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

@@ -3375,7 +3417,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp, AtenIscloseOp>();
AtenQuantizePerTensorOp, AtenIscloseOp, OperatorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
8 changes: 4 additions & 4 deletions lib/Dialect/Torch/Transforms/Passes.cpp
@@ -70,8 +70,8 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(

void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(
createReduceOpVariantsPass(options.extraLibrary));
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass(
options.extraLibrary, options.backendLegalOps));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
if (options.decompose) {
pm.addNestedPass<func::FuncOp>(
@@ -161,8 +161,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOpsPass());
// Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<func::FuncOp>(
createReduceOpVariantsPass(options.extraLibrary));
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass(
options.extraLibrary, options.backendLegalOps));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Remove dead global slots.
pm.addPass(createSymbolDCEPass());
24 changes: 14 additions & 10 deletions lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -403,8 +403,9 @@ namespace {
struct ReduceOpVariantsPass
: public ReduceOpVariantsBase<ReduceOpVariantsPass> {
ReduceOpVariantsPass() = default;
ReduceOpVariantsPass(StringRef extraLibrary) {
ReduceOpVariantsPass(StringRef extraLibrary, ArrayRef<std::string> legalOps) {
this->extraLibrary = extraLibrary.str();
this->legalOps = legalOps;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
@@ -439,13 +440,15 @@ struct ReduceOpVariantsPass
target.addIllegalOp<NonValueTensorLiteralOp>();
target.addIllegalOp<AtenBernoulli_FloatOp>();
target.addIllegalOp<AtenArangeStartOutOp>();
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable,
&specializedNames](Operation *op) {
if (isa<OperatorOp>(op)) {
if (specializedNames.contains(cast<OperatorOp>(op).getNameAttr())) {
return false;
}
}

target.addDynamicallyLegalOp<OperatorOp>([&](OperatorOp op) {
auto opNameAttr = op.getNameAttr();
return llvm::find(legalOps, opNameAttr.str()) != legalOps.end() &&
!specializedNames.contains(opNameAttr);
});

target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
Operation *op) {
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
(isa<OperatorOp>(op) &&
operatorOpHasValueSemantics(cast<OperatorOp>(op),
@@ -479,6 +482,7 @@ struct ReduceOpVariantsPass
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) {
return std::make_unique<ReduceOpVariantsPass>(extraLibrary);
mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary,
ArrayRef<std::string> legalOps) {
return std::make_unique<ReduceOpVariantsPass>(extraLibrary, legalOps);
}
33 changes: 33 additions & 0 deletions test/Conversion/TorchToLinalg/elementwise.mlir
@@ -102,3 +102,36 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
%0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}

// -----

// CHECK-LABEL: func.func @torch_aten_round_decimals
// CHECK: %[[VAL2:.*]] = linalg.generic
// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
// CHECK-NEXT: %[[CONST_64:.*]] = arith.constant
// CHECK-NEXT: %[[CONST_10:.*]] = arith.constant 1.000000e+01
// CHECK-NEXT: %[[VAL4:.*]] = math.fpowi %[[CONST_10]], %[[CONST_64]]
// CHECK-NEXT: %[[VAL5:.*]] = arith.mulf %[[IN]], %[[VAL4]]
// CHECK-NEXT: %[[VAL6:.*]] = math.roundeven %[[VAL5]]
// CHECK-NEXT: %[[VAL7:.*]] = arith.divf %[[VAL6]], %[[VAL4]]
// CHECK-NEXT: linalg.yield %[[VAL7]]
// CHECK: %[[CAST:.*]] = tensor.cast %[[VAL2]]
// CHECK-NEXT: %[[VAL3:.*]] = torch_c.from_builtin_tensor %[[CAST]]
// CHECK-NEXT: return %[[VAL3]]
func.func @torch_aten_round_decimals(%0: !torch.vtensor<[1,1024,1024,3],f32>) -> !torch.vtensor<[1, 1024,1024,3],f32> {
%int0 = torch.constant.int 0
%1 = torch.operator "torch.aten.round.decimals"(%0, %int0) : (!torch.vtensor<[1,1024,1024,3],f32>, !torch.int) -> !torch.vtensor<[1,1024,1024,3],f32>
return %1 : !torch.vtensor<[1, 1024,1024,3],f32>
}

// -----

// Test that unhandled variants of the `torch.operator` op are not legalized.
func.func @torch.prims.device_put(%arg0: !torch.vtensor<[1,77],si64>) -> !torch.vtensor<[1,77],si64> {
%cuda3A0 = torch.constant.device "cuda:0"
// expected-error @+1 {{failed to legalize operation 'torch.operator' that was explicitly marked illegal}}
%0 = torch.operator "torch.prims.device_put"(%arg0, %cuda3A0) : (!torch.vtensor<[1,77],si64>, !torch.Device) -> !torch.vtensor<[1,77],si64>
%int4 = torch.constant.int 4
%1 = torch.prims.convert_element_type %0, %int4 : !torch.vtensor<[1,77],si64>, !torch.int -> !torch.vtensor<[1,77],si64>
return %1 : !torch.vtensor<[1,77],si64>
}
@@ -1,4 +1,4 @@
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax})' -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax,torch.aten.round.decimals})' -split-input-file %s | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.square
func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -25,3 +25,11 @@ func.func @torch.uint8(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[
%1 = torch.aten.reshape %arg0, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor
return %1 : !torch.tensor
}

// Test that "torch.aten.round.decimals" was considered legal after explicitly specifying it in pass options.
// CHECK-LABEL: func.func @torch_aten_round_decimals
func.func @torch_aten_round_decimals(%0: !torch.vtensor<[1,1024,1024,3],f32>) -> !torch.vtensor<[1, 1024,1024,3],f32> {
%int0 = torch.constant.int 0
%1 = torch.operator "torch.aten.round.decimals"(%0, %int0) : (!torch.vtensor<[1,1024,1024,3],f32>, !torch.int) -> !torch.vtensor<[1,1024,1024,3],f32>
return %1 : !torch.vtensor<[1, 1024,1024,3],f32>
}
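
For downstream users of the C++ API, here is a minimal sketch of how the extended pass constructor might be used; only the `createReduceOpVariantsPass(StringRef, ArrayRef<std::string>)` signature comes from this change, while the include set and the pipeline wiring below are assumptions.

```cpp
#include <string>
#include <vector>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

using namespace mlir;

// Hypothetical pipeline fragment: list torch.aten.round.decimals as legal so
// ReduceOpVariants leaves the torch.operator intact and it can later be
// picked up by the TorchToLinalg elementwise lowering.
static void addReduceOpVariants(OpPassManager &pm) {
  std::vector<std::string> legalOps = {"torch.aten.round.decimals"};
  pm.addNestedPass<func::FuncOp>(
      torch::Torch::createReduceOpVariantsPass(/*extraLibrary=*/"", legalOps));
}
```

The equivalent textual pipeline option is exercised by the updated RUN line above, which appends `torch.aten.round.decimals` to `backend-legal-ops`.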