Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring jump targets and backedges to update inplace #83

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 33 additions & 60 deletions numba_rvsdg/core/datastructures/basic_block.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dis
from typing import Tuple, Dict, List, Optional
from dataclasses import dataclass, replace, field
from typing import Dict, List, Optional
from dataclasses import dataclass, field

from numba_rvsdg.core.utils import _next_inst_offset
from numba_rvsdg.core.datastructures import block_names
Expand All @@ -26,9 +26,12 @@ class BasicBlock:

name: str

_jump_targets: Tuple[str, ...] = tuple()
_jump_targets: list[str] = field(default_factory=list)
backedges: list[str] = field(default_factory=list)

backedges: Tuple[str, ...] = tuple()
def __post_init__(self) -> None:
assert isinstance(self._jump_targets, list)
assert isinstance(self.backedges, list)

@property
def is_exiting(self) -> bool:
Expand Down Expand Up @@ -58,7 +61,7 @@ def fallthrough(self) -> bool:
return len(self._jump_targets) == 1

@property
def jump_targets(self) -> Tuple[str, ...]:
def jump_targets(self) -> list[str]:
"""Retrieves the jump targets for this block,
excluding any jump targets that are also backedges.

Expand All @@ -73,31 +76,22 @@ def jump_targets(self) -> Tuple[str, ...]:
for j in self._jump_targets:
if j not in self.backedges:
acc.append(j)
return tuple(acc)
return acc

def declare_backedge(self, target: str) -> "BasicBlock":
def declare_backedge(self, target: str) -> None:
"""Declare one of the jump targets as a backedge of this block.

Parameters
----------
target: str
The jump target that is to be declared as a backedge.

Returns
-------
basic_block: BasicBlock
The resulting block.

"""
if target in self.jump_targets:
assert not self.backedges
return replace(self, backedges=(target,))
return self
self.backedges.append(target)

def replace_jump_targets(
self, jump_targets: Tuple[str, ...]
) -> "BasicBlock":
"""Replaces jump targets of this block by the given tuple.
def replace_jump_targets(self, jump_targets: list[str]) -> None:
"""Replaces jump targets of this block by the given list.

This method replaces the jump targets of the current BasicBlock.
The provided jump targets must be in the same order as their
Expand All @@ -109,36 +103,26 @@ def replace_jump_targets(

Parameters
----------
jump_targets: Tuple
The new jump target tuple. Must be ordered.

Returns
-------
basic_block: BasicBlock
The resulting BasicBlock.

jump_targets: List
The new jump target list. Must be ordered.
"""
return replace(self, _jump_targets=jump_targets)
self._jump_targets.clear()
self._jump_targets.extend(jump_targets)

def replace_backedges(self, backedges: Tuple[str, ...]) -> "BasicBlock":
"""Replaces back edges of this block by the given tuple.
def replace_backedges(self, backedges: list[str]) -> None:
"""Replaces back edges of this block by the given list.

This method replaces the back edges of the current BasicBlock.
The provided back edges must be in the same order as their
intended original replacements.

Parameters
----------
backedges: Tuple
The new back edges tuple. Must be ordered.

Returns
-------
basic_block: BasicBlock
The resulting BasicBlock.

backedges: List
The new back edges list. Must be ordered.
"""
return replace(self, backedges=backedges)
self.backedges.clear()
self.backedges.extend(backedges)


@dataclass(frozen=True)
Expand Down Expand Up @@ -271,10 +255,8 @@ class SyntheticBranch(SyntheticBlock):
variable: str = ""
branch_value_table: Dict[int, str] = field(default_factory=lambda: {})

def replace_jump_targets(
self, jump_targets: Tuple[str, ...]
) -> "BasicBlock":
"""Replaces jump targets of this block by the given tuple.
def replace_jump_targets(self, jump_targets: list[str]) -> None:
"""Replaces jump targets of this block by the given list.

This method replaces the jump targets of the current BasicBlock.
The provided jump targets must be in the same order as their
Expand All @@ -287,18 +269,12 @@ def replace_jump_targets(

Parameters
----------
jump_targets: Tuple
The new jump target tuple. Must be ordered.

Returns
-------
basic_block: BasicBlock
The resulting BasicBlock.

jump_targets: List
The new jump target list. Must be ordered.
"""

old_branch_value_table = self.branch_value_table
new_branch_value_table = {}
old_branch_value_table = self.branch_value_table.copy()
self.branch_value_table.clear()
for target in self._jump_targets:
if target not in jump_targets:
# ASSUMPTION: only one jump_target is being updated
Expand All @@ -307,18 +283,15 @@ def replace_jump_targets(
new_target = next(iter(diff))
for k, v in old_branch_value_table.items():
if v == target:
new_branch_value_table[k] = new_target
self.branch_value_table[k] = new_target
else:
# copy all old values
for k, v in old_branch_value_table.items():
if v == target:
new_branch_value_table[k] = v
self.branch_value_table[k] = v

return replace(
self,
_jump_targets=jump_targets,
branch_value_table=new_branch_value_table,
)
self._jump_targets.clear()
self._jump_targets.extend(jump_targets)


@dataclass(frozen=True)
Expand Down
86 changes: 23 additions & 63 deletions numba_rvsdg/core/datastructures/byte_flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dis
from copy import deepcopy
from dataclasses import dataclass
from typing import Generator, Callable

Expand Down Expand Up @@ -61,90 +60,51 @@ def from_bytecode(code: Callable) -> "ByteFlow": # type: ignore
scfg = flowinfo.build_basicblocks()
return ByteFlow(bc=bc, scfg=scfg)

def _join_returns(self) -> "ByteFlow":
def _join_returns(self) -> None:
"""Joins the return blocks within the corresponding SCFG.

This method creates a deep copy of the SCFG and performs
operation to join return blocks within the control flow.
It returns a new ByteFlow object with the updated SCFG.

Returns
-------
byteflow: ByteFlow
The new ByteFlow object with updated SCFG.
This method performs operation to join return blocks within
the control flow.
"""
scfg = deepcopy(self.scfg)
scfg.join_returns()
return ByteFlow(bc=self.bc, scfg=scfg)
self.scfg.join_returns()

def _restructure_loop(self) -> "ByteFlow":
def _restructure_loop(self) -> None:
"""Restructures the loops within the corresponding SCFG.

Creates a deep copy of the SCFG and performs the operation to
restructure loop constructs within the control flow using
the algorithm LOOP RESTRUCTURING from section 4.1 of Bahmann2015.
Performs the operation to restructure loop constructs within
the control flow using the algorithm LOOP RESTRUCTURING from
section 4.1 of Bahmann2015.
It applies the restructuring operation to both the main SCFG
and any subregions within it. It returns a new ByteFlow object
with the updated SCFG.

Returns
-------
byteflow: ByteFlow
The new ByteFlow object with updated SCFG.
and any subregions within it.
"""
scfg = deepcopy(self.scfg)
restructure_loop(scfg.region)
for region in _iter_subregions(scfg):
restructure_loop(self.scfg.region)
for region in _iter_subregions(self.scfg):
restructure_loop(region)
return ByteFlow(bc=self.bc, scfg=scfg)

def _restructure_branch(self) -> "ByteFlow":
def _restructure_branch(self) -> None:
"""Restructures the branches within the corresponding SCFG.

Creates a deep copy of the SCFG and performs the operation to
restructure branch constructs within the control flow. It applies
the restructuring operation to both the main SCFG and any
subregions within it. It returns a new ByteFlow object with
the updated SCFG.

Returns
-------
byteflow: ByteFlow
The new ByteFlow object with updated SCFG.
This method applies restructuring branch operation to both
the main SCFG and any subregions within it.
"""
scfg = deepcopy(self.scfg)
restructure_branch(scfg.region)
for region in _iter_subregions(scfg):
restructure_branch(self.scfg.region)
for region in _iter_subregions(self.scfg):
restructure_branch(region)
return ByteFlow(bc=self.bc, scfg=scfg)

def restructure(self) -> "ByteFlow":
def restructure(self) -> None:
"""Applies join_returns, restructure_loop and restructure_branch
in the respective order on the SCFG.

Creates a deep copy of the SCFG and applies a series of
restructuring operations to it. The operations include
joining return blocks, restructuring loop constructs, and
restructuring branch constructs. It returns a new ByteFlow
object with the updated SCFG.

Returns
-------
byteflow: ByteFlow
The new ByteFlow object with updated SCFG.
Applies a series of restructuring operations to given SCFG.
The operations include joining return blocks, restructuring
loop constructs, and restructuring branch constructs.
"""
scfg = deepcopy(self.scfg)
# close
scfg.join_returns()
self._join_returns()
# handle loop
restructure_loop(scfg.region)
for region in _iter_subregions(scfg):
restructure_loop(region)
self._restructure_loop()
# handle branch
restructure_branch(scfg.region)
for region in _iter_subregions(scfg):
restructure_branch(region)
return ByteFlow(bc=self.bc, scfg=scfg)
self._restructure_branch()


def _iter_subregions(scfg: SCFG) -> Generator[RegionBlock, SCFG, None]:
Expand Down
8 changes: 4 additions & 4 deletions numba_rvsdg/core/datastructures/flow_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,19 @@ def build_basicblocks(

for begin, end in zip(offsets, [*offsets[1:], end_offset]):
name = names[begin]
targets: Tuple[str, ...]
targets: list[str]
term_offset = _prev_inst_offset(end)
if term_offset not in self.jump_insts:
# implicit jump
targets = (names[end],)
targets = [names[end]]
else:
targets = tuple(names[o] for o in self.jump_insts[term_offset])
targets = [names[o] for o in self.jump_insts[term_offset]]
block = PythonBytecodeBlock(
name=name,
begin=begin,
end=end,
_jump_targets=targets,
backedges=(),
backedges=[],
)
scfg.add_block(block)
return scfg
Loading