Skip to content

Commit bff3e5f

Browse files
committed
Add rsqrt check
Signed-off-by: Yuanyuan Chen <[email protected]>
1 parent 0d9c3fe commit bff3e5f

File tree

6 files changed

+44
-1
lines changed

6 files changed

+44
-1
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch
2+
3+
4+
a = torch.randn(5)
5+
b = 1 / torch.sqrt(a)
6+
b = 1.0 / torch.sqrt(a)
7+
b = a / torch.sqrt(a)
8+
# False negative
9+
b = 1 / a.sqrt()
10+
b = 1.0 / a.sqrt()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
5:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster.
2+
6:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster.
3+
7:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster.

tests/test_torchfix.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def pytest_generate_tests(metafunc):
4848
"TOR106",
4949
"TOR107",
5050
"TOR108",
51+
"TOR109",
5152
},
5253
),
5354
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),

torchfix/torchfix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
TorchScopedLibraryVisitor,
1919
TorchSynchronizedDataLoaderVisitor,
2020
TorchUnsafeLoadVisitor,
21+
TorchRsqrtVisitor,
2122
TorchVisionDeprecatedPretrainedVisitor,
2223
TorchVisionDeprecatedToTensorVisitor,
2324
TorchVisionSingletonImportVisitor,
@@ -35,6 +36,7 @@
3536
TorchExpm1Visitor,
3637
TorchLog1pVisitor,
3738
TorchLogsumexpVisitor,
39+
TorchRsqrtVisitor,
3840
TorchNonPublicAliasVisitor,
3941
TorchRequireGradVisitor,
4042
TorchReentrantCheckpointVisitor,

torchfix/visitors/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
TorchLogsumexpVisitor,
77
TorchReentrantCheckpointVisitor,
88
TorchRequireGradVisitor,
9+
TorchRsqrtVisitor,
910
)
1011
from .nonpublic import TorchNonPublicAliasVisitor
1112
from .performance import (
@@ -30,6 +31,7 @@
3031
"TorchScopedLibraryVisitor",
3132
"TorchSynchronizedDataLoaderVisitor",
3233
"TorchUnsafeLoadVisitor",
34+
"TorchRsqrtVisitor",
3335
"TorchVisionDeprecatedPretrainedVisitor",
3436
"TorchVisionDeprecatedToTensorVisitor",
3537
"TorchVisionSingletonImportVisitor",

torchfix/visitors/misc/__init__.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,6 @@ def visit_Call(self, node):
184184
)
185185
== "torch.exp"
186186
):
187-
188187
# if `dim` is not provided or None for sum, skip:
189188
# https://github.com/pytorch/pytorch/issues/144339
190189
dim_arg = self.get_specific_arg(
@@ -201,3 +200,29 @@ def visit_Call(self, node):
201200
message=self.ERRORS[0].message(),
202201
replacement=None,
203202
)
203+
204+
205+
class TorchRsqrtVisitor(TorchVisitor):
206+
"""
207+
Suggest using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`.
208+
"""
209+
210+
ERRORS = [
211+
TorchError(
212+
"TOR109",
213+
("Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster."),
214+
)
215+
]
216+
217+
def visit_BinaryOperation(self, node):
218+
if m.matches(
219+
node,
220+
m.BinaryOperation(operator=m.Divide(), right=m.Call()),
221+
):
222+
if self.get_qualified_name_for_call(node.right) == "torch.sqrt":
223+
self.add_violation(
224+
node,
225+
error_code=self.ERRORS[0].error_code,
226+
message=self.ERRORS[0].message(),
227+
replacement=None,
228+
)

0 commit comments

Comments
 (0)