[Torch] Add folder for AtenIntOp, AtenFloatOp (llvm#3189)
This adds folders that fold `torch.aten.tensor.int` and `torch.aten.tensor.float` with a constant scalar operand into a `torch.vtensor.literal`. See the unit tests below:
```
// CHECK-LABEL:   func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %float1.000000e01 = torch.constant.float 1.000000e+01
  %67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
  return %67 : !torch.vtensor<[],f32>
}

// CHECK-LABEL:   func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
  %none = torch.constant.none
  %false = torch.constant.bool false 
  %int45 = torch.constant.int 45
  %67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
  return %67 : !torch.vtensor<[],si32>
}

```
Xinyu Yang authored Apr 19, 2024
1 parent 5a98c72 commit 790a697
Showing 5 changed files with 70 additions and 3 deletions.
2 changes: 2 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9092,6 +9092,7 @@ def Torch_AtenTensorIntOp : Torch_Op<"aten.tensor.int", [
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenScalarTensorOp : Torch_Op<"aten.scalar_tensor", [
@@ -11577,6 +11578,7 @@ def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
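The two `let hasFolder = 1;` additions are the only changes to the generated ODS file; each one makes TableGen declare a fold hook on the corresponding op class, which the C++ below implements. For a single-result op the declaration is roughly the following (a sketch of standard ODS output, not copied from the generated header):

```
// Declared inside the generated op class (e.g. AtenTensorIntOp)
// whenever `hasFolder = 1` is set in the .td definition.
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
```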
44 changes: 43 additions & 1 deletion lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3747,6 +3747,8 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
// If a torch.aten.tensor op is initialized by a list with a constant, single
// element, fold it into a torch.vtensor.literal
auto resultTy = dyn_cast<ValueTensorType>(getType());
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
return nullptr;
Type eTy = resultTy.getDtype();
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);

@@ -3761,7 +3763,47 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
}

//===----------------------------------------------------------------------===//
// AtenTensorOp
// AtenTensorIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) {
auto resultTy = dyn_cast<ValueTensorType>(getType());
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
return nullptr;
Type eTy = resultTy.getDtype();
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);

int64_t data;
if (matchPattern(getT(), m_TorchConstantInt(&data))) {
Attribute attribute = IntegerAttr::get(eTy, data);
return DenseElementsAttr::get(shapedTy, attribute);
}

return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenTensorFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) {
auto resultTy = dyn_cast<ValueTensorType>(getType());
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
return nullptr;
Type eTy = resultTy.getDtype();
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);

double data;
if (matchPattern(getT(), m_TorchConstantFloat(&data))) {
Attribute attribute = FloatAttr::get(eTy, data);
return DenseElementsAttr::get(shapedTy, attribute);
}

return nullptr;
}

//===----------------------------------------------------------------------===//
// Aten_ShapeAsTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
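Both new folders mirror the existing `AtenTensorOp` folder: they bail out unless the result type has known sizes and a dtype (a `DenseElementsAttr` requires a fully static builtin tensor type), and they only fold when the scalar operand is a compile-time constant, returning a splat attribute that canonicalization then materializes as the `torch.vtensor.literal` seen in the tests. A minimal self-contained sketch of the attribute built for the `tensor.int` test case (an illustration, not code from this commit):

```
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// Sketch: the splat attribute AtenTensorIntOp::fold returns for
// `torch.aten.tensor.int %int45` whose result is !torch.vtensor<[],si32>.
static mlir::DenseElementsAttr makeFoldedLiteral(mlir::MLIRContext *ctx) {
  auto eTy = mlir::IntegerType::get(ctx, /*width=*/32,
                                    mlir::IntegerType::Signed);      // si32
  mlir::ShapedType shapedTy =
      mlir::RankedTensorType::get(/*shape=*/{}, eTy);                // tensor<si32>
  mlir::Attribute splatVal = mlir::IntegerAttr::get(eTy, 45);
  return mlir::DenseElementsAttr::get(shapedTy, splatVal);           // dense<45>
}
```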
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1307,6 +1307,7 @@
"TModuleRank0_basic",
"TModuleRank1_basic",
"TModuleRank2_basic",
"TensorFloatModule_basic",
"TensorIntModule_basic",
"TensorLiteralModule_basic",
"TensorOpaqueLiteralModule_basic",
@@ -1838,6 +1839,8 @@
"TModuleRank1_basic",
"TModuleRank2_basic",
"TanhBackward_basic",
"TensorFloatModule_basic",
"TensorIntModule_basic",
"TensorLiteralModule_basic",
"TensorOpaqueLiteralModule_basic",
"TensorsConcatNegativeDimStaticModule_basic",
4 changes: 2 additions & 2 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -597,7 +597,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)", has_folder=True)
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)", has_folder=True)
emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::isnan : (Tensor) -> (Tensor)")
@@ -691,7 +691,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)")
emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True)
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
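`torch_ods_gen.py` is the source of truth for `GeneratedTorchOps.td`: flipping `has_folder=True` on these two `emit(...)` calls is what produces the `let hasFolder = 1;` lines in the first file of this diff. When making a similar change, regenerate the .td (typically via the repository's `build_tools/update_torch_ods.sh` script) rather than editing the generated file by hand.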
20 changes: 20 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -1481,6 +1481,26 @@ func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) {
return %67 : !torch.vtensor<[1],si64>
}

// CHECK-LABEL: func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
%none = torch.constant.none
%false = torch.constant.bool false
%float1.000000e01 = torch.constant.float 1.000000e+01
%67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
return %67 : !torch.vtensor<[],f32>
}

// CHECK-LABEL: func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
%none = torch.constant.none
%false = torch.constant.bool false
%int45 = torch.constant.int 45
%67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
return %67 : !torch.vtensor<[],si32>
}

// CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32>
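These are the same two functions quoted in the commit message. They run under the file's existing RUN line (roughly, `torch-mlir-opt -canonicalize` piped into `FileCheck`), and the `CHECK-NEXT` lines verify that each op folds to a single `torch.vtensor.literal` with the expected splat value.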
