diff --git a/xdsl_smt/semantics/transfer_semantics.py b/xdsl_smt/semantics/transfer_semantics.py index 1936844e..61159722 100644 --- a/xdsl_smt/semantics/transfer_semantics.py +++ b/xdsl_smt/semantics/transfer_semantics.py @@ -1,14 +1,11 @@ from dataclasses import dataclass -from xdsl.pattern_rewriter import ( - PatternRewriter, -) +from typing import Mapping, Sequence +from xdsl.pattern_rewriter import PatternRewriter from xdsl_smt.dialects import smt_bitvector_dialect as smt_bv from xdsl_smt.dialects import smt_dialect as smt from xdsl_smt.dialects import transfer -from xdsl_smt.passes.lower_to_smt.smt_lowerer import ( - SMTLowerer, -) +from xdsl_smt.passes.lower_to_smt.smt_lowerer import SMTLowerer from xdsl_smt.dialects.smt_utils_dialect import ( AnyPairType, PairType, @@ -19,13 +16,10 @@ from xdsl_smt.dialects.smt_dialect import BoolType from xdsl_smt.semantics.semantics import OperationSemantics, TypeSemantics from xdsl.ir import Operation, SSAValue, Attribute -from typing import Mapping, Sequence from xdsl.utils.hints import isa from xdsl.dialects.builtin import IntegerAttr, IntegerType from xdsl_smt.utils.transfer_to_smt_util import ( get_low_bits, - set_high_bits, - set_low_bits, count_lzeros, count_rzeros, count_lones, @@ -34,8 +28,6 @@ is_non_negative, is_negative, get_high_bits, - clear_high_bits, - clear_low_bits, ) @@ -43,6 +35,14 @@ class AbstractValueTypeSemantics(TypeSemantics): """Lower all types in an abstract value to SMT types But the last element is useless, this makes GetOp easier""" + def lower_type(self, ty: Attribute) -> Attribute: + """ + If the input type is already a smt type, skip lowering + """ + if ty.name.startswith("smt"): + return ty + return SMTLowerer.lower_type(ty) + def get_semantics(self, type: Attribute) -> Attribute: assert isinstance(type, transfer.AbstractValueType) or isinstance( type, transfer.TupleType @@ -266,6 +266,7 @@ def get_semantics( bv_res, ops = smt_bool_to_bv1(umul_overflow.res) poison_op = smt.ConstantBoolOp(False) + res = PairOp(bv_res, poison_op.result) rewriter.insert_op_before_matched_op([umul_overflow] + ops + [poison_op, res]) return ((res.res,), effect_state) @@ -639,9 +640,27 @@ def get_semantics( effect_state: SSAValue | None, rewriter: PatternRewriter, ) -> tuple[Sequence[SSAValue], SSAValue | None]: - result = set_high_bits(operands[0], operands[1]) - rewriter.insert_op_before_matched_op(result) - return ((result[-1].results[0],), effect_state) + arg = operands[0] + count = operands[1] + assert isinstance(bv_type := arg.type, smt_bv.BitVectorType) + + const_bw = smt_bv.ConstantOp(bv_type.width, bv_type.width) + const_one = smt_bv.ConstantOp(1, bv_type.width) + + umin = smt_bv.UltOp(count, const_bw.res) + clamped_count = smt.IteOp(umin.res, count, const_bw.res) + + sub = smt_bv.SubOp(const_bw.res, clamped_count.res) + shl = smt_bv.ShlOp(const_one.res, clamped_count.res) + sub2 = smt_bv.SubOp(shl.res, const_one.res) + shl2 = smt_bv.ShlOp(sub2.res, sub.res) + or_op = smt_bv.OrOp(arg, shl2.res) + + rewriter.insert_op_before_matched_op( + [const_bw, const_one, umin, clamped_count, sub, shl, sub2, shl2, or_op] + ) + + return ((or_op.res,), effect_state) class SetLowBitsOpSemantics(OperationSemantics): @@ -653,9 +672,19 @@ def get_semantics( effect_state: SSAValue | None, rewriter: PatternRewriter, ) -> tuple[Sequence[SSAValue], SSAValue | None]: - result = set_low_bits(operands[0], operands[1]) - rewriter.insert_op_before_matched_op(result) - return ((result[-1].results[0],), effect_state) + arg = operands[0] + count = operands[1] + assert isinstance(bv_type := arg.type, smt_bv.BitVectorType) + + const_one = smt_bv.ConstantOp(1, bv_type.width) + + shl = smt_bv.ShlOp(const_one.res, count) + sub = smt_bv.SubOp(shl.res, const_one.res) + or_op = smt_bv.OrOp(arg, sub.res) + + rewriter.insert_op_before_matched_op([const_one, shl, sub, or_op]) + + return ((or_op.res,), effect_state) class SetSignBitOpSemantics(OperationSemantics): @@ -692,7 +721,7 @@ def get_semantics( operand_type = operand.type assert isinstance(operand_type, smt_bv.BitVectorType) width = operand_type.width.data - signed_max_value = smt_bv.ConstantOp(1 << (width - 1) - 1, width) + signed_max_value = smt_bv.ConstantOp((1 << (width - 1)) - 1, width) and_op = smt_bv.AndOp(signed_max_value.res, operand) result = [signed_max_value, and_op] @@ -737,9 +766,27 @@ def get_semantics( effect_state: SSAValue | None, rewriter: PatternRewriter, ) -> tuple[Sequence[SSAValue], SSAValue | None]: - result = clear_high_bits(operands[0], operands[1]) - rewriter.insert_op_before_matched_op(result) - return ((result[-1].results[0],), effect_state) + arg = operands[0] + count = operands[1] + assert isinstance(bv_type := arg.type, smt_bv.BitVectorType) + + const_bw = smt_bv.ConstantOp(bv_type.width, bv_type.width) + one = smt_bv.ConstantOp(1, bv_type.width) + + umin = smt_bv.UltOp(count, const_bw.res) + new_count = smt.IteOp(umin.res, count, const_bw.res) + + # mask = (1 << (width - count)) - 1 + sub = smt_bv.SubOp(const_bw.res, new_count.res) + shl = smt_bv.ShlOp(one.res, sub.res) + mask = smt_bv.SubOp(shl.res, one.res) + masked = smt_bv.AndOp(arg, mask.res) + + rewriter.insert_op_before_matched_op( + [const_bw, one, umin, new_count, sub, shl, mask, masked] + ) + + return ((masked.res,), effect_state) class ClearLowBitsOpSemantics(OperationSemantics): @@ -751,9 +798,21 @@ def get_semantics( effect_state: SSAValue | None, rewriter: PatternRewriter, ) -> tuple[Sequence[SSAValue], SSAValue | None]: - result = clear_low_bits(operands[0], operands[1]) - rewriter.insert_op_before_matched_op(result) - return ((result[-1].results[0],), effect_state) + arg = operands[0] + count = operands[1] + assert isinstance(bv_type := arg.type, smt_bv.BitVectorType) + + const_one = smt_bv.ConstantOp(1, bv_type.width) + + # mask = ~((1 << count) - 1) + shl = smt_bv.ShlOp(const_one.res, count) + sub = smt_bv.SubOp(shl.res, const_one.res) + not_mask = smt_bv.NotOp(sub.res) + masked = smt_bv.AndOp(arg, not_mask.res) + + rewriter.insert_op_before_matched_op([const_one, shl, sub, not_mask, masked]) + + return ((masked.res,), effect_state) class SMinOpSemantics(OperationSemantics): diff --git a/xdsl_smt/utils/transfer_to_smt_util.py b/xdsl_smt/utils/transfer_to_smt_util.py index bb348d41..6b8a561e 100644 --- a/xdsl_smt/utils/transfer_to_smt_util.py +++ b/xdsl_smt/utils/transfer_to_smt_util.py @@ -73,26 +73,6 @@ def get_high_bits_constant(high_bits: SSAValue) -> list[Operation]: return result + get_bits_constant(result[-1].results[0], width_constant) -def clear_low_bits(b: SSAValue, low_bits: SSAValue) -> list[Operation]: - """ - clear_low_bits(x, low_bits) -> x & ~(get_low_bits_constant(low_bits)) - """ - result = get_low_bits_constant(low_bits) - result.append(smt_bv.NotOp(result[-1].results[0])) - result.append(smt_bv.AndOp(result[-1].results[0], b)) - return result - - -def clear_high_bits(b: SSAValue, low_bits: SSAValue) -> list[Operation]: - """ - clear_high_bits(x, high_bits) -> x & ~(get_high_bits_constant(high_bits)) - """ - result = get_high_bits_constant(low_bits) - result.append(smt_bv.NotOp(result[-1].results[0])) - result.append(smt_bv.AndOp(result[-1].results[0], b)) - return result - - def get_low_bits(b: SSAValue, low_bits: SSAValue) -> list[Operation]: """ get_low_bits(x, low_bits) -> x & (get_low_bits_constant(low_bits)) @@ -111,24 +91,6 @@ def get_high_bits(b: SSAValue, low_bits: SSAValue) -> list[Operation]: return result -def set_high_bits(b: SSAValue, high_bits: SSAValue) -> list[Operation]: - """ - set_high_bits(x, high_bits) -> x | (get_high_bits_constant(high_bits)) - """ - result = get_high_bits_constant(high_bits) - result.append(smt_bv.OrOp(result[-1].results[0], b)) - return result - - -def set_low_bits(b: SSAValue, low_bits: SSAValue) -> list[Operation]: - """ - set_low_bits(x, low_bits) -> x | (get_low_bits_constant(low_bits)) - """ - result = get_low_bits_constant(low_bits) - result.append(smt_bv.OrOp(result[-1].results[0], b)) - return result - - def count_ones(b: SSAValue) -> list[Operation]: assert isinstance(b.type, smt_bv.BitVectorType) n = b.type.width.data