File tree Expand file tree Collapse file tree 6 files changed +44
-1
lines changed Expand file tree Collapse file tree 6 files changed +44
-1
lines changed Original file line number Diff line number Diff line change 1+ import  torch 
2+ 
3+ 
4+ a  =  torch .randn (5 )
5+ b  =  1  /  torch .sqrt (a )
6+ b  =  1.0  /  torch .sqrt (a )
7+ b  =  a  /  torch .sqrt (a )
8+ # False negative 
9+ b  =  1  /  a .sqrt ()
10+ b  =  1.0  /  a .sqrt ()
Original file line number Diff line number Diff line change 1+ 5:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster.
2+ 6:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster.
3+ 7:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster.
Original file line number Diff line number Diff line change @@ -48,6 +48,7 @@ def pytest_generate_tests(metafunc):
4848                    "TOR106" ,
4949                    "TOR107" ,
5050                    "TOR108" ,
51+                     "TOR109" ,
5152                },
5253            ),
5354            (None , set (GET_ALL_ERROR_CODES ()) -  exclude_set ),
Original file line number Diff line number Diff line change 1818    TorchScopedLibraryVisitor ,
1919    TorchSynchronizedDataLoaderVisitor ,
2020    TorchUnsafeLoadVisitor ,
21+     TorchRsqrtVisitor ,
2122    TorchVisionDeprecatedPretrainedVisitor ,
2223    TorchVisionDeprecatedToTensorVisitor ,
2324    TorchVisionSingletonImportVisitor ,
3536    TorchExpm1Visitor ,
3637    TorchLog1pVisitor ,
3738    TorchLogsumexpVisitor ,
39+     TorchRsqrtVisitor ,
3840    TorchNonPublicAliasVisitor ,
3941    TorchRequireGradVisitor ,
4042    TorchReentrantCheckpointVisitor ,
Original file line number Diff line number Diff line change 66    TorchLogsumexpVisitor ,
77    TorchReentrantCheckpointVisitor ,
88    TorchRequireGradVisitor ,
9+     TorchRsqrtVisitor ,
910)
1011from  .nonpublic  import  TorchNonPublicAliasVisitor 
1112from  .performance  import  (
3031    "TorchScopedLibraryVisitor" ,
3132    "TorchSynchronizedDataLoaderVisitor" ,
3233    "TorchUnsafeLoadVisitor" ,
34+     "TorchRsqrtVisitor" ,
3335    "TorchVisionDeprecatedPretrainedVisitor" ,
3436    "TorchVisionDeprecatedToTensorVisitor" ,
3537    "TorchVisionSingletonImportVisitor" ,
Original file line number Diff line number Diff line change @@ -184,7 +184,6 @@ def visit_Call(self, node):
184184                        )
185185                        ==  "torch.exp" 
186186                    ):
187- 
188187                        # if `dim` is not provided or None for sum, skip: 
189188                        # https://github.com/pytorch/pytorch/issues/144339 
190189                        dim_arg  =  self .get_specific_arg (
@@ -201,3 +200,29 @@ def visit_Call(self, node):
201200                                    message = self .ERRORS [0 ].message (),
202201                                    replacement = None ,
203202                                )
203+ 
204+ 
205+ class  TorchRsqrtVisitor (TorchVisitor ):
206+     """ 
207+     Suggest using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`. 
208+     """ 
209+ 
210+     ERRORS  =  [
211+         TorchError (
212+             "TOR109" ,
213+             ("Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster." ),
214+         )
215+     ]
216+ 
217+     def  visit_BinaryOperation (self , node ):
218+         if  m .matches (
219+             node ,
220+             m .BinaryOperation (operator = m .Divide (), right = m .Call ()),
221+         ):
222+             if  self .get_qualified_name_for_call (node .right ) ==  "torch.sqrt" :
223+                 self .add_violation (
224+                     node ,
225+                     error_code = self .ERRORS [0 ].error_code ,
226+                     message = self .ERRORS [0 ].message (),
227+                     replacement = None ,
228+                 )
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments