From e227b959212791cafd6a0d02931ef4663e1ae46b Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 12 Dec 2024 17:34:40 -0800 Subject: [PATCH] Add rules for deprecated AMP APIs --- .../deprecated_symbols/checker/amp.py | 10 +++++++ .../deprecated_symbols/checker/amp.txt | 6 +++++ .../deprecated_symbols/codemod/amp.in.py | 11 ++++++++ .../deprecated_symbols/codemod/amp.out.py | 11 ++++++++ torchfix/deprecated_symbols.yaml | 16 ++++++++++++ .../visitors/deprecated_symbols/__init__.py | 17 +++++++----- torchfix/visitors/deprecated_symbols/amp.py | 26 +++++++++++++++++++ 7 files changed, 91 insertions(+), 6 deletions(-) create mode 100644 tests/fixtures/deprecated_symbols/checker/amp.py create mode 100644 tests/fixtures/deprecated_symbols/checker/amp.txt create mode 100644 tests/fixtures/deprecated_symbols/codemod/amp.in.py create mode 100644 tests/fixtures/deprecated_symbols/codemod/amp.out.py create mode 100644 torchfix/visitors/deprecated_symbols/amp.py diff --git a/tests/fixtures/deprecated_symbols/checker/amp.py b/tests/fixtures/deprecated_symbols/checker/amp.py new file mode 100644 index 0000000..278ac39 --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/amp.py @@ -0,0 +1,10 @@ +import torch + +torch.cuda.amp.autocast() +torch.cuda.amp.custom_fwd() +torch.cuda.amp.custom_bwd() + +dtype = torch.float32 +maybe_autocast = torch.cpu.amp.autocast() +maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16) +maybe_autocast = torch.cpu.amp.autocast(dtype=dtype) diff --git a/tests/fixtures/deprecated_symbols/checker/amp.txt b/tests/fixtures/deprecated_symbols/checker/amp.txt new file mode 100644 index 0000000..71939e9 --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/amp.txt @@ -0,0 +1,6 @@ +3:1 TOR101 Use of deprecated function torch.cuda.amp.autocast +4:1 TOR101 Use of deprecated function torch.cuda.amp.custom_fwd +5:1 TOR101 Use of deprecated function torch.cuda.amp.custom_bwd +8:18 TOR101 Use of deprecated function torch.cpu.amp.autocast +9:18 TOR101 Use of deprecated function torch.cpu.amp.autocast +10:18 TOR101 Use of deprecated function torch.cpu.amp.autocast diff --git a/tests/fixtures/deprecated_symbols/codemod/amp.in.py b/tests/fixtures/deprecated_symbols/codemod/amp.in.py new file mode 100644 index 0000000..6a1227c --- /dev/null +++ b/tests/fixtures/deprecated_symbols/codemod/amp.in.py @@ -0,0 +1,11 @@ +import torch + +dtype = torch.float32 + +maybe_autocast = torch.cuda.amp.autocast() +maybe_autocast = torch.cuda.amp.autocast(dtype=torch.bfloat16) +maybe_autocast = torch.cuda.amp.autocast(dtype=dtype) + +maybe_autocast = torch.cpu.amp.autocast() +maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16) +maybe_autocast = torch.cpu.amp.autocast(dtype=dtype) diff --git a/tests/fixtures/deprecated_symbols/codemod/amp.out.py b/tests/fixtures/deprecated_symbols/codemod/amp.out.py new file mode 100644 index 0000000..da39d0a --- /dev/null +++ b/tests/fixtures/deprecated_symbols/codemod/amp.out.py @@ -0,0 +1,11 @@ +import torch + +dtype = torch.float32 + +maybe_autocast = torch.amp.autocast("cuda") +maybe_autocast = torch.amp.autocast("cuda", dtype=torch.bfloat16) +maybe_autocast = torch.amp.autocast("cuda", dtype=dtype) + +maybe_autocast = torch.amp.autocast("cpu") +maybe_autocast = torch.amp.autocast("cpu", dtype=torch.bfloat16) +maybe_autocast = torch.amp.autocast("cpu", dtype=dtype) diff --git a/torchfix/deprecated_symbols.yaml b/torchfix/deprecated_symbols.yaml index e219b2c..eaa5119 100644 --- a/torchfix/deprecated_symbols.yaml +++ b/torchfix/deprecated_symbols.yaml @@ -83,6 +83,22 @@ remove_pr: reference: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel +- name: torch.cuda.amp.autocast + deprecate_pr: TBA + remove_pr: + +- name: torch.cuda.amp.custom_fwd + deprecate_pr: TBA + remove_pr: + +- name: torch.cuda.amp.custom_bwd + deprecate_pr: TBA + remove_pr: + +- name: torch.cpu.amp.autocast + deprecate_pr: TBA + remove_pr: + # functorch - name: functorch.vmap deprecate_pr: TBA diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 6a05472..40885ee 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -1,20 +1,23 @@ -import libcst as cst import pkgutil +from typing import List, Optional + +import libcst as cst import yaml -from typing import Optional, List from ...common import ( - TorchVisitor, - TorchError, call_with_name_changes, check_old_names_in_import_from, + TorchError, + TorchVisitor, ) -from .range import call_replacement_range -from .cholesky import call_replacement_cholesky +from .amp import call_replacement_cpu_amp_autocast, call_replacement_cuda_amp_autocast from .chain_matmul import call_replacement_chain_matmul +from .cholesky import call_replacement_cholesky from .qr import call_replacement_qr +from .range import call_replacement_range + class TorchDeprecatedSymbolsVisitor(TorchVisitor): ERRORS: List[TorchError] = [ @@ -49,6 +52,8 @@ def _call_replacement( "torch.range": call_replacement_range, "torch.chain_matmul": call_replacement_chain_matmul, "torch.qr": call_replacement_qr, + "torch.cuda.amp.autocast": call_replacement_cuda_amp_autocast, + "torch.cpu.amp.autocast": call_replacement_cpu_amp_autocast, } replacement = None diff --git a/torchfix/visitors/deprecated_symbols/amp.py b/torchfix/visitors/deprecated_symbols/amp.py new file mode 100644 index 0000000..9aa87c7 --- /dev/null +++ b/torchfix/visitors/deprecated_symbols/amp.py @@ -0,0 +1,26 @@ +import libcst as cst + +from ...common import get_module_name + + +def call_replacement_cpu_amp_autocast(node: cst.Call) -> cst.CSTNode: + return _call_replacement_amp(node, "cpu") + + +def call_replacement_cuda_amp_autocast(node: cst.Call) -> cst.CSTNode: + return _call_replacement_amp(node, "cuda") + + +def _call_replacement_amp(node: cst.Call, device: str) -> cst.CSTNode: + """ + Replace `torch.cuda.amp.autocast()` with `torch.amp.autocast("cuda")` and + Replace `torch.cpu.amp.autocast()` with `torch.amp.autocast("cpu")`. + """ + device_arg = cst.ensure_type(cst.parse_expression(f'f("{device}")'), cst.Call).args[ + 0 + ] + + module_name = get_module_name(node, "torch") + replacement = cst.parse_expression(f"{module_name}.amp.autocast(args)") + replacement = replacement.with_changes(args=(device_arg, *node.args)) + return replacement