@@ -15,15 +15,17 @@ class TorchUnsafeLoadVisitor(TorchVisitor):
15
15
(
16
16
"`torch.load` without `weights_only` parameter is unsafe. "
17
17
"Explicitly set `weights_only` to False only if you trust "
18
- "the data you load " "and full pickle functionality is needed,"
18
+ "the data you load "
19
+ "and full pickle functionality is needed,"
19
20
" otherwise set `weights_only=True`."
20
21
),
21
22
)
22
23
]
23
24
24
25
def visit_Call (self , node ):
25
- if self .get_qualified_name_for_call (node ) == "torch.load" and \
26
- not self .has_specific_arg (node , "weights_only" ):
26
+ if self .get_qualified_name_for_call (
27
+ node
28
+ ) == "torch.load" and not self .has_specific_arg (node , "weights_only" ):
27
29
# Add `weights_only=True` if there is no `pickle_module`.
28
30
# (do not add `weights_only=False` with `pickle_module`, as it
29
31
# needs to be an explicit choice).
@@ -37,9 +39,7 @@ def visit_Call(self, node):
37
39
weights_only_arg = cst .ensure_type (
38
40
cst .parse_expression ("f(weights_only=True)" ), cst .Call
39
41
).args [0 ]
40
- replacement = node .with_changes (
41
- args = node .args + (weights_only_arg ,)
42
- )
42
+ replacement = node .with_changes (args = node .args + (weights_only_arg ,))
43
43
self .add_violation (
44
44
node ,
45
45
error_code = self .ERRORS [0 ].error_code ,
0 commit comments