Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Issue 7] Update import torchvision.models as models #26

Merged
merged 4 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tests/fixtures/vision/checker/models_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import torchvision.models as models
import torchvision.models as cnn
from torchvision.models import resnet50, resnet101
import torchvision.models
from torchvision.models import *
1 change: 1 addition & 0 deletions tests/fixtures/vision/checker/models_import.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'.
2 changes: 2 additions & 0 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .visitors.vision import (
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchVisionModelsImportVisitor,
)
from .visitors.security import TorchUnsafeLoadVisitor

Expand All @@ -35,6 +36,7 @@
TorchSynchronizedDataLoaderVisitor,
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchVisionModelsImportVisitor,
TorchUnsafeLoadVisitor,
TorchReentrantCheckpointVisitor,
]
Expand Down
1 change: 1 addition & 0 deletions torchfix/visitors/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401
from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401
from .models_import import TorchVisionModelsImportVisitor # noqa: F401
40 changes: 40 additions & 0 deletions torchfix/visitors/vision/models_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import libcst as cst

from ...common import LintViolation, TorchVisitor


class TorchVisionModelsImportVisitor(TorchVisitor):
ERROR_CODE = "TOR203"

def visit_Import(self, node: cst.Import) -> None:
for imported_item in node.names:
if isinstance(imported_item.name, cst.Attribute):
if (
isinstance(imported_item.name.value, cst.Name)
and imported_item.name.value.value == "torchvision"
and isinstance(imported_item.name.attr, cst.Name)
and imported_item.name.attr.value == "models"
and imported_item.asname is not None
and isinstance(imported_item.asname.name, cst.Name)
and imported_item.asname.name.value == "models"
):
position = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
replacement = cst.ImportFrom(
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
)
self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=(
"Consider replacing 'import torchvision.models as"
" models' with 'from torchvision import models'."
),
line=position.start.line,
column=position.start.column,
node=node,
replacement=replacement
)
)
Loading