Skip to content

Commit af37f69

Browse files
authored
[refactoring] Extract helper method has_specific_arg (#49)
fix #8, extract helper method `has_specific_arg` that checks for the call argument presence, and simplify all relevant call sites
1 parent b2d55f8 commit af37f69

File tree

3 files changed

+55
-44
lines changed

3 files changed

+55
-44
lines changed

torchfix/common.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ def __init__(self) -> None:
6262
def get_specific_arg(
6363
node: cst.Call, arg_name: str, arg_pos: int
6464
) -> Optional[cst.Arg]:
65-
# `arg_pos` is zero-based.
65+
"""
66+
:param arg_pos: `arg_pos` is zero-based. -1 means it's a keyword argument.
67+
:note: consider using `has_specific_arg` if you only need to check for presence.
68+
"""
6669
curr_pos = 0
6770
for arg in node.args:
6871
if arg.keyword is None:
@@ -73,6 +76,18 @@ def get_specific_arg(
7376
return arg
7477
return None
7578

79+
@staticmethod
80+
def has_specific_arg(
81+
node: cst.Call, arg_name: str, position: Optional[int] = None
82+
) -> bool:
83+
"""
84+
Check if the specific argument is present in a call.
85+
"""
86+
return TorchVisitor.get_specific_arg(
87+
node, arg_name,
88+
position if position is not None else -1
89+
) is not None
90+
7691
def add_violation(
7792
self,
7893
node: cst.CSTNode,

torchfix/visitors/misc/__init__.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,19 @@ class TorchReentrantCheckpointVisitor(TorchVisitor):
5959
]
6060

6161
def visit_Call(self, node):
62-
qualified_name = self.get_qualified_name_for_call(node)
63-
if qualified_name == "torch.utils.checkpoint.checkpoint":
64-
use_reentrant_arg = self.get_specific_arg(node, "use_reentrant", -1)
65-
if use_reentrant_arg is None:
66-
# This codemod maybe unsafe correctness-wise
67-
# if reentrant behavior is actually needed,
68-
# so the changes need to be verified/tested.
69-
use_reentrant_arg = cst.ensure_type(
70-
cst.parse_expression("f(use_reentrant=False)"), cst.Call
71-
).args[0]
72-
replacement = node.with_changes(args=node.args + (use_reentrant_arg,))
73-
self.add_violation(
74-
node,
75-
error_code=self.ERRORS[0].error_code,
76-
message=self.ERRORS[0].message(),
77-
replacement=replacement,
78-
)
62+
if (self.get_qualified_name_for_call(node) ==
63+
"torch.utils.checkpoint.checkpoint" and
64+
not self.has_specific_arg(node, "use_reentrant")):
65+
# This codemod maybe unsafe correctness-wise
66+
# if reentrant behavior is actually needed,
67+
# so the changes need to be verified/tested.
68+
use_reentrant_arg = cst.ensure_type(
69+
cst.parse_expression("f(use_reentrant=False)"), cst.Call
70+
).args[0]
71+
replacement = node.with_changes(args=node.args + (use_reentrant_arg,))
72+
self.add_violation(
73+
node,
74+
error_code=self.ERRORS[0].error_code,
75+
message=self.ERRORS[0].message(),
76+
replacement=replacement,
77+
)

torchfix/visitors/security/__init__.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,30 +22,27 @@ class TorchUnsafeLoadVisitor(TorchVisitor):
2222
]
2323

2424
def visit_Call(self, node):
25-
qualified_name = self.get_qualified_name_for_call(node)
26-
if qualified_name == "torch.load":
27-
weights_only_arg = self.get_specific_arg(node, "weights_only", -1)
28-
if weights_only_arg is None:
29-
# Add `weights_only=True` if there is no `pickle_module`.
30-
# (do not add `weights_only=False` with `pickle_module`, as it
31-
# needs to be an explicit choice).
32-
#
33-
# This codemod is somewhat unsafe correctness-wise
34-
# because full pickling functionality may still be needed
35-
# even without `pickle_module`,
36-
# so the changes need to be verified/tested.
37-
replacement = None
38-
pickle_module_arg = self.get_specific_arg(node, "pickle_module", 2)
39-
if pickle_module_arg is None:
40-
weights_only_arg = cst.ensure_type(
41-
cst.parse_expression("f(weights_only=True)"), cst.Call
42-
).args[0]
43-
replacement = node.with_changes(
44-
args=node.args + (weights_only_arg,)
45-
)
46-
self.add_violation(
47-
node,
48-
error_code=self.ERRORS[0].error_code,
49-
message=self.ERRORS[0].message(),
50-
replacement=replacement,
25+
if self.get_qualified_name_for_call(node) == "torch.load" and \
26+
not self.has_specific_arg(node, "weights_only"):
27+
# Add `weights_only=True` if there is no `pickle_module`.
28+
# (do not add `weights_only=False` with `pickle_module`, as it
29+
# needs to be an explicit choice).
30+
#
31+
# This codemod is somewhat unsafe correctness-wise
32+
# because full pickling functionality may still be needed
33+
# even without `pickle_module`,
34+
# so the changes need to be verified/tested.
35+
replacement = None
36+
if not self.has_specific_arg(node, "pickle_module", 2):
37+
weights_only_arg = cst.ensure_type(
38+
cst.parse_expression("f(weights_only=True)"), cst.Call
39+
).args[0]
40+
replacement = node.with_changes(
41+
args=node.args + (weights_only_arg,)
5142
)
43+
self.add_violation(
44+
node,
45+
error_code=self.ERRORS[0].error_code,
46+
message=self.ERRORS[0].message(),
47+
replacement=replacement,
48+
)

0 commit comments

Comments
 (0)