From a809e97d5b0c47eb317aee59a59e707c456984e2 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 5 Feb 2024 15:36:55 -0800 Subject: [PATCH] Fix bug with function name replacement --- .../deprecated_symbols/codemod/ger-outer.py | 3 ++ .../codemod/ger-outer.py.out | 3 ++ torchfix/common.py | 34 ++++++++++++++----- .../visitors/deprecated_symbols/__init__.py | 6 ++-- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py b/tests/fixtures/deprecated_symbols/codemod/ger-outer.py index c5e64c4..6fce087 100644 --- a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py +++ b/tests/fixtures/deprecated_symbols/codemod/ger-outer.py @@ -1,6 +1,9 @@ import torch +from torch import ger deprecated = torch.norm() sinusoid_inp = torch.ger(pos_seq, inv_freq) other = something.ger(pos_seq, inv_freq) deprecated = torch.norm() one_more = torch.ger(pos_seq, inv_freq) + +just_name = ger(pos_seq, inv_freq) diff --git a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out b/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out index 45f3d84..3303ed0 100644 --- a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out +++ b/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out @@ -1,6 +1,9 @@ import torch +from torch import outer, ger deprecated = torch.norm() sinusoid_inp = torch.outer(pos_seq, inv_freq) other = something.ger(pos_seq, inv_freq) deprecated = torch.norm() one_more = torch.outer(pos_seq, inv_freq) + +just_name = outer(pos_seq, inv_freq) diff --git a/torchfix/common.py b/torchfix/common.py index 52f2f52..b302346 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -3,7 +3,7 @@ import libcst as cst from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider from libcst.codemod.visitors import ImportItem -from typing import Optional, List, Set, Union +from typing import Optional, List, Set, Tuple, Union from abc import ABC IS_TTY = hasattr(sys.stdout, "isatty") and sys.stdout.isatty() @@ -83,19 +83,34 @@ def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]: def call_with_name_changes( node: cst.Call, old_qualified_name: str, new_qualified_name: str -) -> Optional[cst.Call]: +) -> Optional[Tuple[cst.Call, Set[ImportItem]]]: """ - Return new `Call` node with name changes. + Return an optional tuple: + new `Call` node with name changes + and a set of newly needed imports. """ old_begin, _, old_last = old_qualified_name.rpartition(".") new_begin, _, new_last = new_qualified_name.rpartition(".") + needed_imports: Set[ImportItem] = set() # If the only difference is the last name part. if old_begin == new_begin: - replacement = node.with_deep_changes( - old_node=cst.ensure_type(node.func, cst.Attribute).attr, - value=new_last, - ) + if isinstance(node.func, cst.Attribute): + replacement = node.with_deep_changes( + old_node=node.func.attr, + value=new_last, + ) + elif isinstance(node.func, cst.Name): + replacement = node.with_deep_changes( + old_node=node.func, + value=new_last, + ) + needed_imports.add( + ImportItem( + module_name=new_begin, + obj_name=new_last, + ) + ) # If the last name part is the same and # originally called without a dot: don't change the call site, @@ -106,7 +121,10 @@ def call_with_name_changes( # Replace with new_qualified_name. else: replacement = node.with_changes(func=cst.parse_expression(new_qualified_name)) - return replacement + if replacement is None: + return None + else: + return replacement, needed_imports def deep_multi_replace(tree, replacement_map): diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index fed7032..93a9082 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -49,10 +49,12 @@ def _call_replacement( qualified_name, {} ).get("replacement", "") if function_name_replacement: - replacement = call_with_name_changes( + replacement_and_imports = call_with_name_changes( node, qualified_name, function_name_replacement ) - + if replacement_and_imports is not None: + replacement, imports = replacement_and_imports + self.needed_imports.update(imports) return replacement def visit_Call(self, node):