From b2d55f8b91ca951ee512c452dc9676046b952b44 Mon Sep 17 00:00:00 2001 From: Eli Uriegas <1700823+seemethere@users.noreply.github.com> Date: Mon, 22 Apr 2024 12:16:43 -0700 Subject: [PATCH] torchfix: Refactor ERROR_CODE to be consistent (#46) --- tests/test_torchfix.py | 4 +-- torchfix/__main__.py | 10 ++---- torchfix/common.py | 22 +++++++++--- torchfix/torchfix.py | 27 ++++++-------- .../visitors/deprecated_symbols/__init__.py | 25 ++++++++----- torchfix/visitors/internal/__init__.py | 26 +++++++++----- torchfix/visitors/misc/__init__.py | 35 +++++++++++-------- torchfix/visitors/nonpublic/__init__.py | 31 ++++++++++++---- torchfix/visitors/performance/__init__.py | 23 +++++++----- torchfix/visitors/security/__init__.py | 25 +++++++------ torchfix/visitors/vision/models_import.py | 34 +++++++++++------- torchfix/visitors/vision/pretrained.py | 19 +++++++--- torchfix/visitors/vision/to_tensor.py | 20 ++++++----- 13 files changed, 187 insertions(+), 114 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index d699ea7..5f5dff9 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -62,8 +62,8 @@ def test_errorcodes_distinct(): seen = set() for visitor in visitors: LOGGER.info("Checking error code for %s", visitor.__class__.__name__) - error_code = visitor.ERROR_CODE - for e in error_code if isinstance(error_code, list) else [error_code]: + errors = visitor.ERRORS + for e in errors: assert e not in seen seen.add(e) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index 5df0cf9..b8413bf 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -63,16 +63,12 @@ def main() -> None: parser.add_argument( "--select", help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. " - f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. " - f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.", + f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. " + f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.", type=str, default=None, ) - parser.add_argument( - "--version", - action="version", - version=f"{TorchFixVersion}" - ) + parser.add_argument("--version", action="version", version=f"{TorchFixVersion}") # XXX TODO: Get rid of this! # Silence "Failed to determine module name" diff --git a/torchfix/common.py b/torchfix/common.py index 7fdd00a..db2fb8c 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -1,10 +1,11 @@ -from dataclasses import dataclass import sys +from abc import ABC +from dataclasses import dataclass +from typing import List, Optional, Set, Tuple + import libcst as cst -from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider from libcst.codemod.visitors import ImportItem -from typing import Optional, List, Set, Tuple, Union -from abc import ABC +from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider IS_TTY = hasattr(sys.stdout, "isatty") and sys.stdout.isatty() CYAN = "\033[96m" if IS_TTY else "" @@ -34,13 +35,24 @@ def codemod_result(self) -> str: return f"{position} {error_code}{fixable} {self.message}" +@dataclass(frozen=True) +class TorchError: + """Defines an error along with an explanation""" + + error_code: str + message_template: str + + def message(self, **kwargs): + return self.message_template.format(**kwargs) + + class TorchVisitor(cst.BatchableCSTVisitor, ABC): METADATA_DEPENDENCIES = ( QualifiedNameProvider, WhitespaceInclusivePositionProvider, ) - ERROR_CODE: Union[str, List[str]] + ERRORS: List[TorchError] def __init__(self) -> None: self.violations: List[LintViolation] = [] diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index dda4057..85c3943 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -14,7 +14,7 @@ from .visitors.internal import TorchScopedLibraryVisitor from .visitors.performance import TorchSynchronizedDataLoaderVisitor -from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor) +from .visitors.misc import TorchRequireGradVisitor, TorchReentrantCheckpointVisitor from .visitors.nonpublic import TorchNonPublicAliasVisitor from .visitors.vision import ( @@ -48,10 +48,7 @@ def GET_ALL_ERROR_CODES(): codes = set() for cls in ALL_VISITOR_CLS: - if isinstance(cls.ERROR_CODE, list): - codes |= set(cls.ERROR_CODE) - else: - codes.add(cls.ERROR_CODE) + codes |= set(error.error_code for error in cls.ERRORS) return codes @@ -86,16 +83,10 @@ def get_visitors_with_error_codes(error_codes): # only correspond to one visitor. found = False for visitor_cls in ALL_VISITOR_CLS: - if isinstance(visitor_cls.ERROR_CODE, list): - if error_code in visitor_cls.ERROR_CODE: - visitor_classes.add(visitor_cls) - found = True - break - else: - if error_code == visitor_cls.ERROR_CODE: - visitor_classes.add(visitor_cls) - found = True - break + if error_code in list(err.error_code for err in visitor_cls.ERRORS): + visitor_classes.add(visitor_cls) + found = True + break if not found: raise AssertionError(f"Unknown error code: {error_code}") out = [] @@ -120,8 +111,10 @@ def process_error_code_str(code_str): if c == "ALL": continue if len(expand_error_codes((c,))) == 0: - raise ValueError(f"Invalid error code: {c}, available error " - f"codes: {list(GET_ALL_ERROR_CODES())}") + raise ValueError( + f"Invalid error code: {c}, available error " + f"codes: {list(GET_ALL_ERROR_CODES())}" + ) if "ALL" in raw_codes: return GET_ALL_ERROR_CODES() diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 9949242..f1cf61f 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -1,11 +1,12 @@ import libcst as cst import pkgutil import yaml -from typing import Optional +from typing import Optional, List from collections.abc import Sequence from ...common import ( TorchVisitor, + TorchError, call_with_name_changes, ) @@ -16,7 +17,12 @@ class TorchDeprecatedSymbolsVisitor(TorchVisitor): - ERROR_CODE = ["TOR001", "TOR101", "TOR004", "TOR103"] + ERRORS: List[TorchError] = [ + TorchError("TOR001", "Use of removed function {qualified_name}"), + TorchError("TOR101", "Import of deprecated function {qualified_name}"), + TorchError("TOR004", "Import of removed function {qualified_name}"), + TorchError("TOR103", "Import of deprecated function {qualified_name}"), + ] def __init__(self, deprecated_config_path=None): def read_deprecated_config(path=None): @@ -67,11 +73,11 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: qualified_name = f"{module}.{name.name.value}" if qualified_name in self.deprecated_config: if self.deprecated_config[qualified_name]["remove_pr"] is None: - error_code = self.ERROR_CODE[3] - message = f"Import of deprecated function {qualified_name}" + error_code = self.ERRORS[3].error_code + message = self.ERRORS[3].message(qualified_name=qualified_name) else: - error_code = self.ERROR_CODE[2] - message = f"Import of removed function {qualified_name}" + error_code = self.ERRORS[2].error_code + message = self.ERRORS[2].message(qualified_name=qualified_name) reference = self.deprecated_config[qualified_name].get("reference") if reference is not None: @@ -86,11 +92,12 @@ def visit_Call(self, node) -> None: if qualified_name in self.deprecated_config: if self.deprecated_config[qualified_name]["remove_pr"] is None: - error_code = self.ERROR_CODE[1] + error_code = self.ERRORS[1].error_code + message = self.ERRORS[1].message(qualified_name=qualified_name) message = f"Use of deprecated function {qualified_name}" else: - error_code = self.ERROR_CODE[0] - message = f"Use of removed function {qualified_name}" + error_code = self.ERRORS[0].error_code + message = self.ERRORS[0].message(qualified_name=qualified_name) replacement = self._call_replacement(node, qualified_name) reference = self.deprecated_config[qualified_name].get("reference") diff --git a/torchfix/visitors/internal/__init__.py b/torchfix/visitors/internal/__init__.py index 14389b3..908527a 100644 --- a/torchfix/visitors/internal/__init__.py +++ b/torchfix/visitors/internal/__init__.py @@ -1,4 +1,4 @@ -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchScopedLibraryVisitor(TorchVisitor): @@ -6,14 +6,24 @@ class TorchScopedLibraryVisitor(TorchVisitor): Suggest `torch.library._scoped_library` for PyTorch tests. """ - ERROR_CODE = "TOR901" - MESSAGE = ( - "Use `torch.library._scoped_library` instead of `torch.library.Library` " - "in PyTorch tests files. See https://github.com/pytorch/pytorch/pull/118318 " - "for details." - ) + ERRORS = [ + TorchError( + "TOR901", + ( + "Use `torch.library._scoped_library` " + "instead of `torch.library.Library` " + "in PyTorch tests files. " + "See https://github.com/pytorch/pytorch/pull/118318 " + "for details." + ), + ) + ] def visit_Call(self, node): qualified_name = self.get_qualified_name_for_call(node) if qualified_name == "torch.library.Library": - self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE) + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + ) diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index ef60b3e..a8ee248 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -1,8 +1,7 @@ import libcst as cst import libcst.matchers as m - -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchRequireGradVisitor(TorchVisitor): @@ -10,8 +9,12 @@ class TorchRequireGradVisitor(TorchVisitor): Find and fix common misspelling `require_grad` (instead of `requires_grad`). """ - ERROR_CODE = "TOR002" - MESSAGE = "Likely typo `require_grad` in assignment. Did you mean `requires_grad`?" + ERRORS = [ + TorchError( + "TOR002", + "Likely typo `require_grad` in assignment. Did you mean `requires_grad`?", + ) + ] def visit_Assign(self, node): # Look for any assignment with `require_grad` attribute on the left. @@ -33,8 +36,8 @@ def visit_Assign(self, node): ) self.add_violation( node, - error_code=self.ERROR_CODE, - message=self.MESSAGE, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), replacement=replacement, ) @@ -44,12 +47,16 @@ class TorchReentrantCheckpointVisitor(TorchVisitor): Find and fix common misuse of reentrant checkpoints. """ - ERROR_CODE = "TOR003" - MESSAGE = ( - "Please pass `use_reentrant` explicitly to `checkpoint`. " - "To maintain old behavior, pass `use_reentrant=True`. " - "It is recommended to use `use_reentrant=False`." - ) + ERRORS = [ + TorchError( + "TOR003", + ( + "Please pass `use_reentrant` explicitly to `checkpoint`. " + "To maintain old behavior, pass `use_reentrant=True`. " + "It is recommended to use `use_reentrant=False`." + ), + ) + ] def visit_Call(self, node): qualified_name = self.get_qualified_name_for_call(node) @@ -65,7 +72,7 @@ def visit_Call(self, node): replacement = node.with_changes(args=node.args + (use_reentrant_arg,)) self.add_violation( node, - error_code=self.ERROR_CODE, - message=self.MESSAGE, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), replacement=replacement, ) diff --git a/torchfix/visitors/nonpublic/__init__.py b/torchfix/visitors/nonpublic/__init__.py index 7240e0f..575ad9d 100644 --- a/torchfix/visitors/nonpublic/__init__.py +++ b/torchfix/visitors/nonpublic/__init__.py @@ -1,10 +1,10 @@ from os.path import commonprefix -from typing import Sequence +from typing import Sequence, List import libcst as cst from libcst.codemod.visitors import ImportItem -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchNonPublicAliasVisitor(TorchVisitor): @@ -17,7 +17,20 @@ class TorchNonPublicAliasVisitor(TorchVisitor): see https://github.com/pytorch/pytorch/pull/69862/files """ - ERROR_CODE = ["TOR104", "TOR105"] + ERRORS: List[TorchError] = [ + TorchError( + "TOR104", ( + "Use of non-public function `{qualified_name}`, " + "please use `{public_name}` instead" + ), + ), + TorchError( + "TOR105", ( + "Import of non-public function `{qualified_name}`, " + "please use `{public_name}` instead" + ), + ), + ] # fmt: off ALIASES = { @@ -33,8 +46,10 @@ def visit_Call(self, node): if qualified_name in self.ALIASES: public_name = self.ALIASES[qualified_name] - error_code = self.ERROR_CODE[0] - message = f"Use of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501 + error_code = self.ERRORS[0].error_code + message = self.ERRORS[0].message( + qualified_name=qualified_name, public_name=public_name + ) call_name = cst.helpers.get_full_name_for_node(node) replacement = None @@ -74,8 +89,10 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: qualified_name = f"{module}.{name.name.value}" if qualified_name in self.ALIASES: public_name = self.ALIASES[qualified_name] - error_code = self.ERROR_CODE[1] - message = f"Import of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501 + error_code = self.ERRORS[1].error_code + message = self.ERRORS[1].message( + qualified_name=qualified_name, public_name=public_name + ) new_module = ".".join(public_name.split(".")[:-1]) new_name = public_name.split(".")[-1] diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py index 427eb78..249df4c 100644 --- a/torchfix/visitors/performance/__init__.py +++ b/torchfix/visitors/performance/__init__.py @@ -1,7 +1,6 @@ import libcst.matchers as m - -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchSynchronizedDataLoaderVisitor(TorchVisitor): @@ -10,12 +9,16 @@ class TorchSynchronizedDataLoaderVisitor(TorchVisitor): https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py """ - ERROR_CODE = "TOR401" - MESSAGE = ( - "Detected DataLoader running with synchronized implementation. " - "Please enable asynchronous dataloading by setting num_workers > 0 when " - "initializing DataLoader." - ) + ERRORS = [ + TorchError( + "TOR401", + ( + "Detected DataLoader running with synchronized implementation." + " Please enable asynchronous dataloading by setting " + "num_workers > 0 when initializing DataLoader." + ), + ) + ] def visit_Call(self, node): qualified_name = self.get_qualified_name_for_call(node) @@ -25,5 +28,7 @@ def visit_Call(self, node): num_workers_arg.value, m.Integer(value="0") ): self.add_violation( - node, error_code=self.ERROR_CODE, message=self.MESSAGE + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), ) diff --git a/torchfix/visitors/security/__init__.py b/torchfix/visitors/security/__init__.py index 5dfdf6e..775bed9 100644 --- a/torchfix/visitors/security/__init__.py +++ b/torchfix/visitors/security/__init__.py @@ -1,5 +1,6 @@ import libcst as cst -from ...common import TorchVisitor + +from ...common import TorchError, TorchVisitor class TorchUnsafeLoadVisitor(TorchVisitor): @@ -8,13 +9,17 @@ class TorchUnsafeLoadVisitor(TorchVisitor): See https://github.com/pytorch/pytorch/issues/31875. """ - ERROR_CODE = "TOR102" - MESSAGE = ( - "`torch.load` without `weights_only` parameter is unsafe. " - "Explicitly set `weights_only` to False only if you trust the data you load " - "and full pickle functionality is needed, otherwise set " - "`weights_only=True`." - ) + ERRORS = [ + TorchError( + "TOR102", + ( + "`torch.load` without `weights_only` parameter is unsafe. " + "Explicitly set `weights_only` to False only if you trust " + "the data you load " "and full pickle functionality is needed," + " otherwise set `weights_only=True`." + ), + ) + ] def visit_Call(self, node): qualified_name = self.get_qualified_name_for_call(node) @@ -40,7 +45,7 @@ def visit_Call(self, node): ) self.add_violation( node, - error_code=self.ERROR_CODE, - message=self.MESSAGE, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), replacement=replacement, ) diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py index de75d5a..f3b0797 100644 --- a/torchfix/visitors/vision/models_import.py +++ b/torchfix/visitors/vision/models_import.py @@ -1,24 +1,32 @@ import libcst as cst import libcst.matchers as m -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchVisionModelsImportVisitor(TorchVisitor): - ERROR_CODE = "TOR203" - MESSAGE = ( - "Consider replacing 'import torchvision.models as models' " - "with 'from torchvision import models'." - ) + ERRORS = [ + TorchError( + "TOR203", + ( + "Consider replacing 'import torchvision.models as models' " + "with 'from torchvision import models'." + ), + ) + ] def visit_Import(self, node: cst.Import) -> None: replacement = None for imported_item in node.names: - if m.matches(imported_item, m.ImportAlias( - name=m.Attribute(value=m.Name("torchvision"), - attr=m.Name("models")), - asname=m.AsName(name=m.Name("models")) - )): + if m.matches( + imported_item, + m.ImportAlias( + name=m.Attribute( + value=m.Name("torchvision"), attr=m.Name("models") + ), + asname=m.AsName(name=m.Name("models")), + ), + ): # Replace only if the import statement has no other names if len(node.names) == 1: replacement = cst.ImportFrom( @@ -27,8 +35,8 @@ def visit_Import(self, node: cst.Import) -> None: ) self.add_violation( node, - error_code=self.ERROR_CODE, - message=self.MESSAGE, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), replacement=replacement, ) break diff --git a/torchfix/visitors/vision/pretrained.py b/torchfix/visitors/vision/pretrained.py index 99dd845..af52a0f 100644 --- a/torchfix/visitors/vision/pretrained.py +++ b/torchfix/visitors/vision/pretrained.py @@ -3,7 +3,7 @@ import libcst as cst from libcst.codemod.visitors import ImportItem -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor): @@ -16,7 +16,12 @@ class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor): otherwise only lint violation is emitted. """ - ERROR_CODE = "TOR201" + ERRORS = [ + TorchError( + "TOR201", + "Parameter `{old_arg_name}` is deprecated, please use `{new_arg_name}` instead.", + ) + ] # flake8: noqa: E105 # fmt: off @@ -215,13 +220,17 @@ def _new_arg_and_import( message = None pretrained_arg = self.get_specific_arg(node, "pretrained", 0) if pretrained_arg is not None: - message = "Parameter `pretrained` is deprecated, please use `weights` instead." + message = self.ERRORS[0].message( + old_arg_name="pretrained", new_arg_name="weights" + ) pretrained_backbone_arg = self.get_specific_arg( node, "pretrained_backbone", 1 ) if pretrained_backbone_arg is not None: - message = "Parameter `pretrained_backbone` is deprecated, please use `weights_backbone` instead." + message = self.ERRORS[0].message( + old_arg_name="pretrained_backbone", new_arg_name="weights_backbone" + ) replacement_args = list(node.args) @@ -250,7 +259,7 @@ def _new_arg_and_import( if message is not None: self.add_violation( node, - error_code=self.ERROR_CODE, + error_code=self.ERRORS[0].error_code, message=message, replacement=replacement, ) diff --git a/torchfix/visitors/vision/to_tensor.py b/torchfix/visitors/vision/to_tensor.py index 3395dd9..791a9e5 100644 --- a/torchfix/visitors/vision/to_tensor.py +++ b/torchfix/visitors/vision/to_tensor.py @@ -1,21 +1,25 @@ from collections.abc import Sequence + import libcst as cst -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor + +MESSAGE = ( + "The transform `v2.ToTensor()` is deprecated and will be removed " + "in a future release. Instead, please use " + "`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." # noqa: E501 +) class TorchVisionDeprecatedToTensorVisitor(TorchVisitor): - ERROR_CODE = "TOR202" - MESSAGE = ( - "The transform `v2.ToTensor()` is deprecated and will be removed " - "in a future release. Instead, please use " - "`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." # noqa: E501 - ) + ERRORS = [TorchError("TOR202", MESSAGE)] def _maybe_add_violation(self, qualified_name, node): if qualified_name != "torchvision.transforms.v2.ToTensor": return - self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE) + self.add_violation( + node, error_code=self.ERRORS[0].error_code, message=self.ERRORS[0].message() + ) def visit_ImportFrom(self, node): module_path = cst.helpers.get_absolute_module_from_package_for_import(