Skip to content

Commit

Permalink
Move torchvision.models visitor to vision dir
Browse files Browse the repository at this point in the history
  • Loading branch information
gesuwen committed Mar 3, 2024
1 parent ec17643 commit 05aaad2
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 74 deletions.
22 changes: 0 additions & 22 deletions tests/fixtures/deprecated_symbols/codemod/torchvision_models.py

This file was deleted.

This file was deleted.

5 changes: 5 additions & 0 deletions tests/fixtures/vision/checker/models_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import torchvision.models as models
import torchvision.models as cnn
from torchvision.models import resnet50, resnet101
import torchvision.models
from torchvision.models import *
1 change: 1 addition & 0 deletions tests/fixtures/vision/checker/models_import.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'.
8 changes: 3 additions & 5 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .visitors.deprecated_symbols import (
TorchDeprecatedSymbolsVisitor,
_UpdateFunctorchImports,
_UpdateTorchvisionModelsImports,
)

from .visitors.internal import TorchScopedLibraryVisitor
Expand All @@ -20,6 +19,7 @@
from .visitors.vision import (
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchVisionModelsImportVisitor,
)
from .visitors.security import TorchUnsafeLoadVisitor

Expand All @@ -36,6 +36,7 @@
TorchSynchronizedDataLoaderVisitor,
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchVisionModelsImportVisitor,
TorchUnsafeLoadVisitor,
TorchReentrantCheckpointVisitor,
]
Expand Down Expand Up @@ -230,11 +231,8 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
update_functorch_imports_visitor = _UpdateFunctorchImports()
new_module = new_module.visit(update_functorch_imports_visitor)

update_torchvision_models_visitor = _UpdateTorchvisionModelsImports()
new_module = new_module.visit(update_torchvision_models_visitor)

if fixes_count == 0 and not update_functorch_imports_visitor.changed \
and not update_torchvision_models_visitor.changed:
if fixes_count == 0 and not update_functorch_imports_visitor.changed:
raise codemod.SkipFile("No changes")

return new_module
25 changes: 0 additions & 25 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,28 +117,3 @@ def leave_ImportFrom(
self.changed = True
return updated_node.with_changes(module=cst.parse_expression("torch.func"))
return updated_node

# TODO: refactor/generalize this.
class _UpdateTorchvisionModelsImports(cst.CSTTransformer):

def __init__(self):
self.changed = False

def leave_Import(
self, node: cst.Import, updated_node: cst.Import
) -> cst.CSTNode:
if len(updated_node.names) == 1:
alias = updated_node.names[0]
if isinstance(alias.name, cst.Attribute) and \
alias.name.value.value == 'torchvision' and \
alias.name.attr.value == 'models' and \
alias.asname and alias.asname.name.value == 'models':

self.changed = True
new_import = cst.ImportFrom(
module=cst.Name(value='torchvision'),
names=[cst.ImportAlias(name=cst.Name(value='models'))]
)
return new_import

return updated_node
1 change: 1 addition & 0 deletions torchfix/visitors/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401
from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401
from .models_import import TorchVisionModelsImportVisitor # noqa: F401
40 changes: 40 additions & 0 deletions torchfix/visitors/vision/models_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import libcst as cst

from ...common import LintViolation, TorchVisitor


class TorchVisionModelsImportVisitor(TorchVisitor):
ERROR_CODE = "TOR203"

def visit_Import(self, node: cst.Import) -> None:
for imported_item in node.names:
if isinstance(imported_item.name, cst.Attribute):
if (
isinstance(imported_item.name.value, cst.Name)
and imported_item.name.value.value == "torchvision"
and imported_item.name.attr.value == "models"
and imported_item.asname is not None
and imported_item.asname.name.value == "models"
):
print(imported_item.asname.name.value)
position = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
# print(position)
replacement = cst.ImportFrom(
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
)
self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=(
"Consider replacing 'import torchvision.models as"
" models' with 'from torchvision import models'."
),
line=position.start.line,
column=position.start.column,
node=node,
replacement=replacement
)
)

0 comments on commit 05aaad2

Please sign in to comment.