@@ -22,30 +22,27 @@ class TorchUnsafeLoadVisitor(TorchVisitor):
22
22
]
23
23
24
24
def visit_Call (self , node ):
25
- qualified_name = self .get_qualified_name_for_call (node )
26
- if qualified_name == "torch.load" :
27
- weights_only_arg = self .get_specific_arg (node , "weights_only" , - 1 )
28
- if weights_only_arg is None :
29
- # Add `weights_only=True` if there is no `pickle_module`.
30
- # (do not add `weights_only=False` with `pickle_module`, as it
31
- # needs to be an explicit choice).
32
- #
33
- # This codemod is somewhat unsafe correctness-wise
34
- # because full pickling functionality may still be needed
35
- # even without `pickle_module`,
36
- # so the changes need to be verified/tested.
37
- replacement = None
38
- pickle_module_arg = self .get_specific_arg (node , "pickle_module" , 2 )
39
- if pickle_module_arg is None :
40
- weights_only_arg = cst .ensure_type (
41
- cst .parse_expression ("f(weights_only=True)" ), cst .Call
42
- ).args [0 ]
43
- replacement = node .with_changes (
44
- args = node .args + (weights_only_arg ,)
45
- )
46
- self .add_violation (
47
- node ,
48
- error_code = self .ERRORS [0 ].error_code ,
49
- message = self .ERRORS [0 ].message (),
50
- replacement = replacement ,
25
+ if self .get_qualified_name_for_call (node ) == "torch.load" and \
26
+ not self .has_specific_arg (node , "weights_only" ):
27
+ # Add `weights_only=True` if there is no `pickle_module`.
28
+ # (do not add `weights_only=False` with `pickle_module`, as it
29
+ # needs to be an explicit choice).
30
+ #
31
+ # This codemod is somewhat unsafe correctness-wise
32
+ # because full pickling functionality may still be needed
33
+ # even without `pickle_module`,
34
+ # so the changes need to be verified/tested.
35
+ replacement = None
36
+ if not self .has_specific_arg (node , "pickle_module" , 2 ):
37
+ weights_only_arg = cst .ensure_type (
38
+ cst .parse_expression ("f(weights_only=True)" ), cst .Call
39
+ ).args [0 ]
40
+ replacement = node .with_changes (
41
+ args = node .args + (weights_only_arg ,)
51
42
)
43
+ self .add_violation (
44
+ node ,
45
+ error_code = self .ERRORS [0 ].error_code ,
46
+ message = self .ERRORS [0 ].message (),
47
+ replacement = replacement ,
48
+ )
0 commit comments