Skip to content

Commit

Permalink
add rule to flip lt/gt ops
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Nov 12, 2024
1 parent af25137 commit 1ebc876
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
14 changes: 14 additions & 0 deletions vyper/venom/basicblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@

COMMUTATIVE_INSTRUCTIONS = frozenset(["add", "mul", "smul", "or", "xor", "and", "eq"])

FLIPPABLE_INSTRUCTIONS = ("gt", "lt", "sgt", "slt")

if TYPE_CHECKING:
from vyper.venom.function import IRFunction

Expand Down Expand Up @@ -230,6 +232,10 @@ def is_volatile(self) -> bool:
def is_commutative(self) -> bool:
return self.opcode in COMMUTATIVE_INSTRUCTIONS

@property
def is_flippable(self) -> bool:
return self.opcode in FLIPPABLE_INSTRUCTIONS

@property
def is_bb_terminator(self) -> bool:
return self.opcode in BB_TERMINATORS
Expand Down Expand Up @@ -282,6 +288,14 @@ def get_outputs(self) -> list[IROperand]:
"""
return [self.output] if self.output else []

def flip_operands(self):
assert self.is_flippable
if self.opcode in ("gt", "sgt"):
self.opcode = self.opcode.replace("g", "l")
else:
self.opcode = self.opcode.replace("l", "g")
self.operands.reverse()

def replace_operands(self, replacements: dict) -> None:
"""
Update operands with replacements.
Expand Down
6 changes: 4 additions & 2 deletions vyper/venom/passes/dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def run_pass(self) -> None:
self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet()

self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
basic_blocks = list(self.function.get_basic_blocks())

for bb in self.function.get_basic_blocks():
self._process_basic_block(bb)
Expand Down Expand Up @@ -66,7 +65,7 @@ def _process_instruction_r(self, instructions: list[IRInstruction], inst: IRInst

def key(x):
cost = 0
if x.output in inst.operands and not inst.is_commutative:
if x.output in inst.operands and not inst.is_commutative and not inst.is_flippable:
cost = inst.operands.index(x.output)
return cost - len(self.inst_offspring[x]) * 0.5

Expand All @@ -76,6 +75,9 @@ def key(x):
if inst.is_commutative and children != list(self.ida[inst]):
inst.operands.reverse()

if inst.is_flippable and children != list(self.ida[inst]):
inst.flip_operands()

for dep_inst in children:
self._process_instruction_r(instructions, dep_inst)

Expand Down

0 comments on commit 1ebc876

Please sign in to comment.