Skip to content

Commit

Permalink
[Torch] Add decompose of AtenToPrimDeviceOp (llvm#3131)
Browse files Browse the repository at this point in the history
As device information isn't relevant to torch-mlir
  • Loading branch information
Xinyu Yang authored Apr 10, 2024
1 parent 8951a8c commit 5eb0cf9
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 2 deletions.
10 changes: 9 additions & 1 deletion lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10529,8 +10529,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.to.prim_Device\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<Device>, %arg2: !torch.optional<int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" %1 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" torch.prim.If.yield %0#1 : !torch.int\n"
" } else {\n"
" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %3 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.transpose.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
Expand Down
28 changes: 28 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5450,6 +5450,33 @@ class DecomposeAtenToDtypeLayoutOp
};
} // namespace

namespace {
// Decompose `aten.to.prim_Device` op into `aten.to.dtype` op.
class DecomposeAtenToPrimDeviceOp
: public OpRewritePattern<AtenToPrimDeviceOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenToPrimDeviceOp op,
PatternRewriter &rewriter) const override {

// Device information isn't relevant to torch-mlir, so we can drop that info
// here.
auto loc = op.getLoc();
Value constNone = rewriter.create<ConstantNoneOp>(loc);

Value dtype = op.getDtype();
if (dtype.getType().template isa<Torch::NoneType>()) {
dtype = rewriter.create<Torch::PrimDtypeOp>(loc, op.getSelf());
}
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
dtype, op.getNonBlocking(),
op.getCopy(), constNone);

return success();
}
};
} // namespace

namespace {
// Decompose `aten.to.device` op into `aten.to.dtype` op.
class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
Expand Down Expand Up @@ -7559,6 +7586,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenPreluOp>();
target.addIllegalOp<AtenToDtypeLayoutOp>();
target.addIllegalOp<AtenToDeviceOp>();
target.addIllegalOp<AtenToPrimDeviceOp>();
target.addIllegalOp<AtenAdaptiveAvgPool1dOp>();
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
target.addIllegalOp<AtenClampMinOp>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2755,7 +2755,9 @@ def aten〇t〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
@check_dtype_function(_check_tensors_with_the_same_dtype(1, tensor_device="meta", device=torch.device("meta")))
def aten〇to〇prim_Device〡dtype(self_rank_dtype: Tuple[int, int], device: Optional[device], dtype: Optional[int] = None, non_blocking: bool = False, copy: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
if dtype is None:
return self_dtype
return dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim0=0, dim1=1))
def aten〇transpose〇int〡dtype(self_rank_dtype: Tuple[int, int], dim0: int, dim1: int) -> int:
Expand Down

0 comments on commit 5eb0cf9

Please sign in to comment.