From bff3e5f15ccdf2f6415e67113f108444545025d7 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Sun, 14 Sep 2025 11:15:28 +0800 Subject: [PATCH 1/2] Add rsqrt check Signed-off-by: Yuanyuan Chen --- tests/fixtures/misc/checker/rsqrt.py | 10 ++++++++++ tests/fixtures/misc/checker/rsqrt.txt | 3 +++ tests/test_torchfix.py | 1 + torchfix/torchfix.py | 2 ++ torchfix/visitors/__init__.py | 2 ++ torchfix/visitors/misc/__init__.py | 27 ++++++++++++++++++++++++++- 6 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 tests/fixtures/misc/checker/rsqrt.py create mode 100644 tests/fixtures/misc/checker/rsqrt.txt diff --git a/tests/fixtures/misc/checker/rsqrt.py b/tests/fixtures/misc/checker/rsqrt.py new file mode 100644 index 0000000..d596e49 --- /dev/null +++ b/tests/fixtures/misc/checker/rsqrt.py @@ -0,0 +1,10 @@ +import torch + + +a = torch.randn(5) +b = 1 / torch.sqrt(a) +b = 1.0 / torch.sqrt(a) +b = a / torch.sqrt(a) +# False negative +b = 1 / a.sqrt() +b = 1.0 / a.sqrt() diff --git a/tests/fixtures/misc/checker/rsqrt.txt b/tests/fixtures/misc/checker/rsqrt.txt new file mode 100644 index 0000000..1f882e4 --- /dev/null +++ b/tests/fixtures/misc/checker/rsqrt.txt @@ -0,0 +1,3 @@ +5:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster. +6:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster. +7:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster. diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 5baa12a..b008ecd 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -48,6 +48,7 @@ def pytest_generate_tests(metafunc): "TOR106", "TOR107", "TOR108", + "TOR109", }, ), (None, set(GET_ALL_ERROR_CODES()) - exclude_set), diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 5e96e38..3968213 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -18,6 +18,7 @@ TorchScopedLibraryVisitor, TorchSynchronizedDataLoaderVisitor, TorchUnsafeLoadVisitor, + TorchRsqrtVisitor, TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, TorchVisionSingletonImportVisitor, @@ -35,6 +36,7 @@ TorchExpm1Visitor, TorchLog1pVisitor, TorchLogsumexpVisitor, + TorchRsqrtVisitor, TorchNonPublicAliasVisitor, TorchRequireGradVisitor, TorchReentrantCheckpointVisitor, diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index 45f2438..0a488f6 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -6,6 +6,7 @@ TorchLogsumexpVisitor, TorchReentrantCheckpointVisitor, TorchRequireGradVisitor, + TorchRsqrtVisitor, ) from .nonpublic import TorchNonPublicAliasVisitor from .performance import ( @@ -30,6 +31,7 @@ "TorchScopedLibraryVisitor", "TorchSynchronizedDataLoaderVisitor", "TorchUnsafeLoadVisitor", + "TorchRsqrtVisitor", "TorchVisionDeprecatedPretrainedVisitor", "TorchVisionDeprecatedToTensorVisitor", "TorchVisionSingletonImportVisitor", diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index 8f0c70c..f669274 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -184,7 +184,6 @@ def visit_Call(self, node): ) == "torch.exp" ): - # if `dim` is not provided or None for sum, skip: # https://github.com/pytorch/pytorch/issues/144339 dim_arg = self.get_specific_arg( @@ -201,3 +200,29 @@ def visit_Call(self, node): message=self.ERRORS[0].message(), replacement=None, ) + + +class TorchRsqrtVisitor(TorchVisitor): + """ + Suggest using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`. + """ + + ERRORS = [ + TorchError( + "TOR109", + ("Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster."), + ) + ] + + def visit_BinaryOperation(self, node): + if m.matches( + node, + m.BinaryOperation(operator=m.Divide(), right=m.Call()), + ): + if self.get_qualified_name_for_call(node.right) == "torch.sqrt": + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) From b4cee8dde72fb9fba04f131d72738562a9e719aa Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Mon, 15 Sep 2025 07:59:18 +0800 Subject: [PATCH 2/2] Fix flake8 warnings Signed-off-by: Yuanyuan Chen --- tests/fixtures/misc/checker/rsqrt.txt | 6 +++--- torchfix/visitors/misc/__init__.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/fixtures/misc/checker/rsqrt.txt b/tests/fixtures/misc/checker/rsqrt.txt index 1f882e4..1e480be 100644 --- a/tests/fixtures/misc/checker/rsqrt.txt +++ b/tests/fixtures/misc/checker/rsqrt.txt @@ -1,3 +1,3 @@ -5:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster. -6:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster. -7:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster. +5:5 TOR109 Consider faster `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`. +6:5 TOR109 Consider faster `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`. +7:5 TOR109 Consider faster `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`. diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index f669274..e5dafcb 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -210,7 +210,7 @@ class TorchRsqrtVisitor(TorchVisitor): ERRORS = [ TorchError( "TOR109", - ("Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster."), + ("Consider faster `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`."), ) ]