Skip to content

Commit

Permalink
[MLIR][TORCH] Add OnnxToTorch lowering for Einsum op (llvm#3117)
Browse files Browse the repository at this point in the history
Signed-Off By: Vivek Khandelwal <[email protected]>
  • Loading branch information
vivekkhandelwal1 authored Apr 8, 2024
1 parent 84c24e5 commit 1d6e4c3
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 7 deletions.
26 changes: 26 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1951,4 +1951,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
/*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
return success();
});
// Lower onnx.Einsum (opset >= 12) to torch.aten.einsum.
// aten.einsum takes (equation : str, tensors : list<tensor>, path : int?),
// whereas ONNX passes the tensors as variadic operands and the equation as
// a string attribute, so we repack the operands into a PrimListConstruct.
patterns.onOp(
    "Einsum", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      SmallVector<Value> tensors;
      std::string equation;
      // Match fails unless every operand, the required "equation" string
      // attribute, and the result tensor type can all be bound.
      if (binder.tensorOperands(tensors, binder.op->getNumOperands()) ||
          binder.customOpNameStringAttr(equation, "equation") ||
          binder.tensorResultType(resultType))
        return failure();
      // Element type for the tensor list: the first operand's tensor type
      // with sizes and dtype erased (i.e. an unranked/unspecified vtensor),
      // so the list accepts operands of differing shapes/dtypes.
      Type listElemType =
          tensors[0]
              .getType()
              .cast<Torch::BaseTensorType>()
              .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
                                    /*optionalDtype=*/nullptr);
      Type listType = Torch::ListType::get(listElemType);
      Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
          binder.op->getLoc(), listType, tensors);
      // The equation becomes a torch string constant operand.
      Value cstEquation = rewriter.create<Torch::ConstantStrOp>(
          binder.getLoc(), rewriter.getType<Torch::StringType>(),
          rewriter.getStringAttr(equation));
      // No contraction-path hint is provided; pass `none` for `path`.
      Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      rewriter.replaceOpWithNewOp<Torch::AtenEinsumOp>(
          binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone);
      return success();
    });
}
7 changes: 0 additions & 7 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1977,13 +1977,6 @@
# Failure - onnx_lowering: onnx.Clip
"NormalizeModule_basic",

# Failure - onnx_lowering: onnx.Einsum
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",

# Failure - onnx_lowering: onnx.MaxPool
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
Expand Down
48 changes: 48 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1743,3 +1743,51 @@ func.func @test_compress_neg_axis(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.
%0 = torch.operator "onnx.Compress"(%arg0, %cst) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,2,4],f32>
return %0 : !torch.vtensor<[2,2,4],f32>
}

// -----

// Einsum extracting the per-batch diagonal of a [3,5,5] tensor ("...ii ->...i").
// CHECK-LABEL: func.func @test_einsum_batch_diagonal
func.func @test_einsum_batch_diagonal(%arg0: !torch.vtensor<[3,5,5],f64>) -> !torch.vtensor<[3,5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,5,5],f64>) -> !torch.list<vtensor>
  // CHECK: %[[EQUATION:.*]] = torch.constant.str "...ii ->...i"
  // CHECK: %[[PATH:.*]] = torch.constant.none
  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[3,5],f64>
  %0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "...ii ->...i"} : (!torch.vtensor<[3,5,5],f64>) -> !torch.vtensor<[3,5],f64>
  return %0 : !torch.vtensor<[3,5],f64>
}

// -----

// Einsum expressing a batched matmul over two operands ("bij, bjk -> bik");
// exercises the multi-operand path of the lowering.
// CHECK-LABEL: func.func @test_einsum_batch_matmul
func.func @test_einsum_batch_matmul(%arg0: !torch.vtensor<[5,2,3],f64>, %arg1: !torch.vtensor<[5,3,4],f64>) -> !torch.vtensor<[5,2,4],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[5,2,3],f64>, !torch.vtensor<[5,3,4],f64>) -> !torch.list<vtensor>
  // CHECK: %[[EQUATION:.*]] = torch.constant.str "bij, bjk -> bik"
  // CHECK: %[[PATH:.*]] = torch.constant.none
  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[5,2,4],f64>
  %0 = torch.operator "onnx.Einsum"(%arg0, %arg1) {torch.onnx.equation = "bij, bjk -> bik"} : (!torch.vtensor<[5,2,3],f64>, !torch.vtensor<[5,3,4],f64>) -> !torch.vtensor<[5,2,4],f64>
  return %0 : !torch.vtensor<[5,2,4],f64>
}

// -----

// Einsum reducing over the second axis ("ij->i"), i.e. a row-wise sum.
// CHECK-LABEL: func.func @test_einsum_sum
func.func @test_einsum_sum(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,4],f64>) -> !torch.list<vtensor>
  // CHECK: %[[EQUATION:.*]] = torch.constant.str "ij->i"
  // CHECK: %[[PATH:.*]] = torch.constant.none
  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[3],f64>
  %0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "ij->i"} : (!torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3],f64>
  return %0 : !torch.vtensor<[3],f64>
}

// -----

// Einsum permuting the two axes of a matrix ("ij->ji"), i.e. a transpose.
// CHECK-LABEL: func.func @test_einsum_transpose
func.func @test_einsum_transpose(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[4,3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,4],f64>) -> !torch.list<vtensor>
  // CHECK: %[[EQUATION:.*]] = torch.constant.str "ij->ji"
  // CHECK: %[[PATH:.*]] = torch.constant.none
  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[4,3],f64>
  %0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "ij->ji"} : (!torch.vtensor<[3,4],f64>) -> !torch.vtensor<[4,3],f64>
  return %0 : !torch.vtensor<[4,3],f64>
}

0 comments on commit 1d6e4c3

Please sign in to comment.