Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0932cc9

Browse files
authoredOct 30, 2018
Merge pull request #179 from bcaller/ifexp
Better handling of IfExp (ternary)
2 parents 5d7a94b + 2e4f8c9 commit 0932cc9

File tree

7 files changed

+152
-1
lines changed

7 files changed

+152
-1
lines changed
 

‎examples/example_inputs/ternary.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
result = (
2+
"abc"
3+
if t.u == v.w else
4+
"def"
5+
if x else
6+
y # This is the only RHS variable which taints result
7+
if func(z if 1 + 1 == 2 else z) else
8+
"ghi"
9+
)

‎pyt/core/transformer.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,80 @@ def visit_Return(self, node):
6464
return self.visit_chain(node)
6565

6666

67-
class PytTransformer(AsyncTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
67+
class IfExpRewriter(ast.NodeTransformer):
68+
"""Splits IfExp ternary expressions containing complex tests into multiple statements
69+
70+
Will change
71+
72+
a if b(c) else d
73+
74+
into
75+
76+
a if __if_exp_0 else d
77+
78+
with Assign nodes in assignments [__if_exp_0 = b(c)]
79+
"""
80+
81+
def __init__(self, starting_index=0):
82+
self._temporary_variable_index = starting_index
83+
self.assignments = []
84+
super().__init__()
85+
86+
def visit_IfExp(self, node):
87+
if isinstance(node.test, (ast.Name, ast.Attribute)):
88+
return self.generic_visit(node)
89+
else:
90+
temp_var_id = '__if_exp_{}'.format(self._temporary_variable_index)
91+
self._temporary_variable_index += 1
92+
assignment_of_test = ast.Assign(
93+
targets=[ast.Name(id=temp_var_id, ctx=ast.Store())],
94+
value=self.visit(node.test),
95+
)
96+
ast.copy_location(assignment_of_test, node)
97+
self.assignments.append(assignment_of_test)
98+
transformed_if_exp = ast.IfExp(
99+
test=ast.Name(id=temp_var_id, ctx=ast.Load()),
100+
body=self.visit(node.body),
101+
orelse=self.visit(node.orelse),
102+
)
103+
ast.copy_location(transformed_if_exp, node)
104+
return transformed_if_exp
105+
106+
def visit_FunctionDef(self, node):
107+
return node
108+
109+
110+
class IfExpTransformer:
111+
"""Goes through module and function bodies, adding extra Assign nodes due to IfExp expressions."""
112+
113+
def visit_body(self, nodes):
114+
new_nodes = []
115+
count = 0
116+
for node in nodes:
117+
rewriter = IfExpRewriter(count)
118+
possibly_transformed_node = rewriter.visit(node)
119+
if rewriter.assignments:
120+
new_nodes.extend(rewriter.assignments)
121+
count += len(rewriter.assignments)
122+
new_nodes.append(possibly_transformed_node)
123+
return new_nodes
124+
125+
def visit_FunctionDef(self, node):
126+
transformed = ast.FunctionDef(
127+
name=node.name,
128+
args=node.args,
129+
body=self.visit_body(node.body),
130+
decorator_list=node.decorator_list,
131+
returns=node.returns
132+
)
133+
ast.copy_location(transformed, node)
134+
return self.generic_visit(transformed)
135+
136+
def visit_Module(self, node):
137+
transformed = ast.Module(self.visit_body(node.body))
138+
ast.copy_location(transformed, node)
139+
return self.generic_visit(transformed)
140+
141+
142+
class PytTransformer(AsyncTransformer, IfExpTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
68143
pass

‎pyt/helper_visitors/label_visitor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,12 @@ def visit_FormattedValue(self, node):
324324
def visit_Starred(self, node):
325325
self.result += '*'
326326
self.visit(node.value)
327+
328+
def visit_IfExp(self, node):
329+
self.result += '('
330+
self.visit(node.test)
331+
self.result += ') ? ('
332+
self.visit(node.body)
333+
self.result += ') : ('
334+
self.visit(node.orelse)
335+
self.result += ')'

‎pyt/helper_visitors/right_hand_side_visitor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ def visit_Call(self, node):
2222
for keyword in node.keywords:
2323
self.visit(keyword)
2424

25+
def visit_IfExp(self, node):
26+
# The test doesn't taint the assignment
27+
self.visit(node.body)
28+
self.visit(node.orelse)
29+
2530
@classmethod
2631
def result_for_node(cls, node):
2732
visitor = cls()

‎tests/cfg/cfg_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,39 @@ def test_if_not(self):
580580
(_exit, _if)
581581
])
582582

583+
def test_ternary_ifexp(self):
584+
self.cfg_create_from_file('examples/example_inputs/ternary.py')
585+
586+
# entry = 0
587+
tmp_if_1 = 1
588+
# tmp_if_inner = 2
589+
call = 3
590+
# tmp_if_call = 4
591+
actual_if_exp = 5
592+
exit = 6
593+
594+
self.assert_length(self.cfg.nodes, expected_length=exit + 1)
595+
self.assertInCfg([
596+
(i + 1, i) for i in range(exit)
597+
])
598+
599+
self.assertCountEqual(
600+
self.cfg.nodes[actual_if_exp].right_hand_side_variables,
601+
['y'],
602+
"The variables in the test expressions shouldn't appear as RHS variables"
603+
)
604+
605+
self.assertCountEqual(
606+
self.cfg.nodes[tmp_if_1].right_hand_side_variables,
607+
['t', 'v'],
608+
)
609+
610+
self.assertIn(
611+
'ret_func(',
612+
self.cfg.nodes[call].label,
613+
"Function calls inside the test expressions should still appear in the CFG",
614+
)
615+
583616

584617
class CFGWhileTest(CFGBaseTestCase):
585618

‎tests/core/transformer_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,19 @@ def test_chained_function(self):
5252

5353
transformed = PytTransformer().visit(chained_tree)
5454
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))
55+
56+
def test_if_exp(self):
57+
complex_if_exp_tree = ast.parse("\n".join([
58+
"def a():",
59+
" b = c if d.e(f) else g if h else i if j.k(l) else m",
60+
]))
61+
62+
separated_tree = ast.parse("\n".join([
63+
"def a():",
64+
" __if_exp_0 = d.e(f)",
65+
" __if_exp_1 = j.k(l)",
66+
" b = c if __if_exp_0 else g if h else i if __if_exp_1 else m",
67+
]))
68+
69+
transformed = PytTransformer().visit(complex_if_exp_tree)
70+
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))

‎tests/helper_visitors/label_visitor_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,7 @@ def test_joined_str_with_format_spec(self):
8383
def test_starred(self):
8484
label = self.perform_labeling_on_expression('[a, *b] = *c, d')
8585
self.assertEqual(label.result, '[a, *b] = (*c, d)')
86+
87+
def test_if_exp(self):
88+
label = self.perform_labeling_on_expression('a = b if c else d')
89+
self.assertEqual(label.result, 'a = (c) ? (b) : (d)')

0 commit comments

Comments
 (0)
Please sign in to comment.