Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SSA] Ensure phi args are only added when defined. #118

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/test/to_ssa/if-const.bril
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# the order that basic blocks get renamed.
.zexit:
print a;
}
}
2 changes: 1 addition & 1 deletion examples/test/to_ssa/if-const.out
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
b.0: int = const 1;
jmp .zexit;
.zexit:
b.1: int = phi b.0 b.0 .false .true;
b.1: int = phi b.0 __undefined .false .true;
a.1: int = phi __undefined a.0 .false .true;
print a.1;
ret;
Expand Down
20 changes: 20 additions & 0 deletions examples/test/to_ssa/if-double-overwrite.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
@main(foo: bool) {
.entry:
cond: bool = const true;
b: int = const 42;
br cond .true .false;
.true:
b: int = const 1;
br foo .left .right;
.false:
b: int = const 0;
jmp .zexit;
.left:
b: int = const 2;
jmp .zexit;
.right:
b: int = const 3;
jmp .zexit;
.zexit:
print b;
}
22 changes: 22 additions & 0 deletions examples/test/to_ssa/if-double-overwrite.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
@main(foo: bool) {
.entry:
cond.0: bool = const true;
b.0: int = const 42;
br cond.0 .true .false;
.true:
b.2: int = const 1;
br foo .left .right;
.false:
b.1: int = const 0;
jmp .zexit;
.left:
b.3: int = const 2;
jmp .zexit;
.right:
b.4: int = const 3;
jmp .zexit;
.zexit:
b.5: int = phi b.1 b.3 b.4 .false .left .right;
print b.5;
ret;
}
16 changes: 16 additions & 0 deletions examples/test/to_ssa/if-overwrite.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@main() {
.entry:
cond: bool = const true;
# This is overwritten in both the `.true` and `.false` branches,
# so should not be passed to the phi in `.zexit`.
b: int = const 42;
br cond .true .false;
.true:
b: int = const 1;
jmp .zexit;
.false:
b: int = const 0;
jmp .zexit;
.zexit:
print b;
}
16 changes: 16 additions & 0 deletions examples/test/to_ssa/if-overwrite.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@main {
.entry:
cond.0: bool = const true;
b.0: int = const 42;
br cond.0 .true .false;
.true:
b.2: int = const 1;
jmp .zexit;
.false:
b.1: int = const 0;
jmp .zexit;
.zexit:
b.3: int = phi b.1 b.2 .false .true;
print b.3;
ret;
}
14 changes: 14 additions & 0 deletions examples/test/to_ssa/if-predecessor.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@main() {
.entry:
cond: bool = const true;
b: int = const 42;
br cond .true .false;
.true:
# `b` is defined in predecessor `.entry`, so it should be passed along to the phi in `.zexit`.
jmp .zexit;
.false:
b: int = const 1;
jmp .zexit;
.zexit:
print b;
}
15 changes: 15 additions & 0 deletions examples/test/to_ssa/if-predecessor.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@main {
.entry:
cond.0: bool = const true;
b.0: int = const 42;
br cond.0 .true .false;
.true:
jmp .zexit;
.false:
b.1: int = const 1;
jmp .zexit;
.zexit:
b.2: int = phi b.1 b.0 .false .true;
print b.2;
ret;
}
20 changes: 20 additions & 0 deletions examples/test/to_ssa/if-reaching.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
@main(cond: bool) {
.entry:
# Can be reached through `.left`, so it should be within a phi in `.zexit`.
a: int = const 47;

# Can be reached through `.right`, so it should be within a phi in `.zexit`.
b: int = const 42;
br cond .left .right;
.left:
b: int = const 1;
c: int = const 5;
jmp .zexit;
.right:
a: int = const 2;
c: int = const 10;
jmp .zexit;
.zexit:
d: int = sub a c;
print d;
}
21 changes: 21 additions & 0 deletions examples/test/to_ssa/if-reaching.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
@main(cond: bool) {
.entry:
a.0: int = const 47;
b.0: int = const 42;
br cond .left .right;
.left:
b.1: int = const 1;
c.0: int = const 5;
jmp .zexit;
.right:
a.1: int = const 2;
c.1: int = const 10;
jmp .zexit;
.zexit:
c.2: int = phi c.0 c.1 .left .right;
b.2: int = phi b.1 b.0 .left .right;
a.2: int = phi a.0 a.1 .left .right;
d.0: int = sub a.2 c.2;
print d.0;
ret;
}
2 changes: 1 addition & 1 deletion examples/test/to_ssa/if-ssa.bril
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
.zexit:
a.4: int = phi .left a.2 .right a.3;
print a.4;
}
}
2 changes: 1 addition & 1 deletion examples/test/to_ssa/if-ssa.out
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
jmp .zexit;
.zexit:
a.3.1: int = phi __undefined a.3.0 .left .right;
a.2.1: int = phi a.2.0 a.2.0 .left .right;
a.2.1: int = phi a.2.0 __undefined .left .right;
a.4.0: int = phi a.2.1 a.3.1 .left .right;
print a.4.0;
ret;
Expand Down
15 changes: 15 additions & 0 deletions examples/test/to_ssa/if-successor.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@main() {
.entry:
cond: bool = const true;
b: int = const 42;
br cond .true .false;
.true:
b: int = const 1;
jmp .zexit;
.false:
b: int = const 2;
jmp .zexit;
.zexit:
b: int = const 3;
print b;
}
17 changes: 17 additions & 0 deletions examples/test/to_ssa/if-successor.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
@main {
.entry:
cond.0: bool = const true;
b.0: int = const 42;
br cond.0 .true .false;
.true:
b.2: int = const 1;
jmp .zexit;
.false:
b.1: int = const 2;
jmp .zexit;
.zexit:
b.3: int = phi b.1 b.2 .false .true;
b.4: int = const 3;
print b.4;
ret;
}
10 changes: 10 additions & 0 deletions examples/test/to_ssa/jmp.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@main() {
.entry:
cond: bool = const true;
# No `phi` needed here since there is no branching.
b: int = const 42;
jmp .exit;
.exit:
b: int = const 0;
print b;
}
10 changes: 10 additions & 0 deletions examples/test/to_ssa/jmp.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@main {
.entry:
cond.0: bool = const true;
b.0: int = const 42;
jmp .exit;
.exit:
b.1: int = const 0;
print b.1;
ret;
}
63 changes: 53 additions & 10 deletions examples/to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,53 @@
import sys
from collections import defaultdict

from cfg import block_map, successors, add_terminators, add_entry, reassemble
from cfg import block_map, successors, add_terminators, add_entry, reassemble, edges
from form_blocks import form_blocks
from dom import get_dom, dom_fronts, dom_tree, map_inv


def defined_analysis(init, blocks, pred, succ):
"""Accumulates all currently-defined variables at
the beginning and end of each block. Returns a
mapping:
block -> {v1, v2, ...}
where `v1, v2, ...` are defined in block.

This is done for both the entrance and exit
of a block, resulting in two mappings.
"""

def merge(_initial, _pred, _out):
return set().union(_initial, *[_out[name] for name in _pred])

def transfer(_in, _block):
return _in.union(definitions[_block])

# Initialize.
block_names = list(blocks.keys())
in_ = {name: {} for name in block_names}
out_ = {name: {} for name in block_names}

entry = block_names[0]
for a in init:
in_[entry][a] = {entry}

# Mapping from block name to a set of definitions within the block.
definitions = {n: set(i['dest'] for i in blocks[n] if 'dest' in i)
for n in block_names}

worklist = [name for name in block_names]
while worklist:
b = worklist.pop()
in_[b] = merge(in_[b], pred[b], out_)
copy = out_[b].copy()
out_[b] = transfer(in_[b], b)
if out_[b] == copy:
continue
worklist.extend(succ[b])
return in_, out_


def def_blocks(blocks):
"""Get a map from variable names to defining blocks.
"""
Expand All @@ -20,7 +62,6 @@ def def_blocks(blocks):

def get_phis(blocks, df, defs):
"""Find where to insert phi-nodes in the blocks.

Produce a map from block names to variable names that need phi-nodes
in those blocks. (We will need to generate names and actually insert
instructions later.)
Expand All @@ -39,10 +80,12 @@ def get_phis(blocks, df, defs):
return phis


def ssa_rename(blocks, phis, succ, domtree, args):
def ssa_rename(blocks, phis, pred, succ, domtree, args):
stack = defaultdict(list, {v: [v] for v in args})
phi_args = {b: {p: [] for p in phis[b]} for b in blocks}
phi_dests = {b: {p: None for p in phis[b]} for b in blocks}

_, defined_out = defined_analysis(args, blocks, pred, succ)
counters = defaultdict(int)

def _push_fresh(var):
Expand Down Expand Up @@ -72,10 +115,12 @@ def _rename(block):
# Rename phi-node arguments (in successors).
for s in succ[block]:
for p in phis[s]:
if stack[p]:
# We want to ensure that `p` is defined in the
# path of predecessors, or in the current block.
if p in defined_out[block] and stack[p]:
phi_args[s][p].append((block, stack[p][0]))
else:
# The variable is not defined on this path
# The variable is not defined on this path.
phi_args[s][p].append((block, "__undefined"))

# Recursive calls.
Expand All @@ -91,7 +136,6 @@ def _rename(block):
return phi_args, phi_dests



def insert_phis(blocks, phi_args, phi_dests, types):
for block, instrs in blocks.items():
for dest, pairs in sorted(phi_args[block].items()):
Expand Down Expand Up @@ -120,8 +164,7 @@ def func_to_ssa(func):
blocks = block_map(form_blocks(func['instrs']))
add_entry(blocks)
add_terminators(blocks)
succ = {name: successors(block[-1]) for name, block in blocks.items()}
pred = map_inv(succ)
pred, succ = edges(blocks)
dom = get_dom(succ, list(blocks.keys())[0])

df = dom_fronts(dom, succ)
Expand All @@ -130,8 +173,8 @@ def func_to_ssa(func):
arg_names = {a['name'] for a in func['args']} if 'args' in func else set()

phis = get_phis(blocks, df, defs)
phi_args, phi_dests = ssa_rename(blocks, phis, succ, dom_tree(dom),
arg_names)
phi_args, phi_dests = ssa_rename(blocks, phis, pred, succ,
dom_tree(dom), arg_names)
insert_phis(blocks, phi_args, phi_dests, types)

func['instrs'] = reassemble(blocks)
Expand Down