diff --git a/tests/filecheck/lower-to-smt/transfer-to-smt/popcount.mlir b/tests/filecheck/lower-to-smt/transfer-to-smt/popcount.mlir new file mode 100644 index 00000000..e35745cc --- /dev/null +++ b/tests/filecheck/lower-to-smt/transfer-to-smt/popcount.mlir @@ -0,0 +1,42 @@ +// RUN: xdsl-smt %s -p=lower-to-smt,canonicalize,dce | filecheck %s +// RUN: xdsl-smt %s -p=lower-to-smt,lower-effects,canonicalize,dce,merge-func-results -t=smt | z3 -in + +"builtin.module"() ({ + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.popcount"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "test", "function_type" = (!transfer.integer) -> !transfer.integer, "sym_visibility" = "private"} : () -> () +}) : () -> () + +// CHECK: builtin.module { +// CHECK-NEXT: %0 = "smt.define_fun"() ({ +// CHECK-NEXT: ^0(%x : !smt.bv<8>, %1 : !effect.state): +// CHECK-NEXT: %2 = smt.bv.constant #smt.bv<0> : !smt.bv<8> +// CHECK-NEXT: %3 = "smt.bv.extract"(%x) {start = #builtin.int<0>, end = #builtin.int<0>} : (!smt.bv<8>) -> !smt.bv<1> +// CHECK-NEXT: %4 = "smt.bv.zero_extend"(%3) : (!smt.bv<1>) -> !smt.bv<8> +// CHECK-NEXT: %5 = "smt.bv.add"(%2, %4) : (!smt.bv<8>, !smt.bv<8>) -> !smt.bv<8> +// CHECK-NEXT: %6 = "smt.bv.extract"(%x) {start = #builtin.int<1>, end = #builtin.int<1>} : (!smt.bv<8>) -> !smt.bv<1> +// CHECK-NEXT: %7 = "smt.bv.zero_extend"(%6) : (!smt.bv<1>) -> !smt.bv<8> +// CHECK-NEXT: %8 = "smt.bv.add"(%5, %7) : (!smt.bv<8>, !smt.bv<8>) -> !smt.bv<8> +// CHECK-NEXT: %9 = "smt.bv.extract"(%x) {start = #builtin.int<2>, end = #builtin.int<2>} : (!smt.bv<8>) -> !smt.bv<1> +// CHECK-NEXT: %10 = "smt.bv.zero_extend"(%9) : (!smt.bv<1>) -> !smt.bv<8> +// CHECK-NEXT: %11 = "smt.bv.add"(%8, %10) : (!smt.bv<8>, !smt.bv<8>) -> !smt.bv<8> +// CHECK-NEXT: %12 = "smt.bv.extract"(%x) {start = #builtin.int<3>, end = #builtin.int<3>} : (!smt.bv<8>) -> !smt.bv<1> +// CHECK-NEXT: %13 = "smt.bv.zero_extend"(%12) : (!smt.bv<1>) -> !smt.bv<8> +// CHECK-NEXT: %14 = "smt.bv.add"(%11, %13) : (!smt.bv<8>, !smt.bv<8>) -> !smt.bv<8> +// CHECK-NEXT: %15 = "smt.bv.extract"(%x) {start = #builtin.int<4>, end = #builtin.int<4>} : (!smt.bv<8>) -> !smt.bv<1> +// CHECK-NEXT: %16 = "smt.bv.zero_extend"(%15) : (!smt.bv<1>) -> !smt.bv<8> +// CHECK-NEXT: %17 = "smt.bv.add"(%14, %16) : (!smt.bv<8>, !smt.bv<8>) -> !smt.bv<8> +// CHECK-NEXT: %18 = "smt.bv.extract"(%x) {start = #builtin.int<5>, end = #builtin.int<5>} : (!smt.bv<8>) -> !smt.bv<1> +// CHECK-NEXT: %19 = "smt.bv.zero_extend"(%18) : (!smt.bv<1>) -> !smt.bv<8> +// CHECK-NEXT: %20 = "smt.bv.add"(%17, %19) : (!smt.bv<8>, !smt.bv<8>) -> !smt.bv<8> +// CHECK-NEXT: %21 = "smt.bv.extract"(%x) {start = #builtin.int<6>, end = #builtin.int<6>} : (!smt.bv<8>) -> !smt.bv<1> +// CHECK-NEXT: %22 = "smt.bv.zero_extend"(%21) : (!smt.bv<1>) -> !smt.bv<8> +// CHECK-NEXT: %23 = "smt.bv.add"(%20, %22) : (!smt.bv<8>, !smt.bv<8>) -> !smt.bv<8> +// CHECK-NEXT: %24 = "smt.bv.extract"(%x) {start = #builtin.int<7>, end = #builtin.int<7>} : (!smt.bv<8>) -> !smt.bv<1> +// CHECK-NEXT: %25 = "smt.bv.zero_extend"(%24) : (!smt.bv<1>) -> !smt.bv<8> +// CHECK-NEXT: %r = "smt.bv.add"(%23, %25) : (!smt.bv<8>, !smt.bv<8>) -> !smt.bv<8> +// CHECK-NEXT: "smt.return"(%r, %1) : (!smt.bv<8>, !effect.state) -> () +// CHECK-NEXT: }) {fun_name = "test"} : () -> ((!smt.bv<8>, !effect.state) -> (!smt.bv<8>, !effect.state)) +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-smt/transfer-to-smt/ssub_overflow.mlir b/tests/filecheck/lower-to-smt/transfer-to-smt/ssub_overflow.mlir new file mode 100644 index 00000000..55d30517 --- /dev/null +++ b/tests/filecheck/lower-to-smt/transfer-to-smt/ssub_overflow.mlir @@ -0,0 +1,23 @@ +// RUN: xdsl-smt %s -p=lower-to-smt,canonicalize,dce | filecheck %s +// RUN: xdsl-smt %s -p=lower-to-smt,lower-effects,canonicalize,dce,merge-func-results -t=smt | z3 -in + +"builtin.module"() ({ + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.ssub_overflow"(%x, %y) : (!transfer.integer,!transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "test", "function_type" = (!transfer.integer, !transfer.integer) -> i1, "sym_visibility" = "private"} : () -> () +}) : () -> () + +// CHECK: builtin.module { +// CHECK-NEXT: %0 = "smt.define_fun"() ({ +// CHECK-NEXT: ^0(%x : !smt.bv<8>, %y : !smt.bv<8>, %1 : !effect.state): +// CHECK-NEXT: %2 = "smt.bv.ssubo"(%x, %y) : (!smt.bv<8>, !smt.bv<8>) -> !smt.bool +// CHECK-NEXT: %3 = smt.bv.constant #smt.bv<1> : !smt.bv<1> +// CHECK-NEXT: %4 = smt.bv.constant #smt.bv<0> : !smt.bv<1> +// CHECK-NEXT: %5 = "smt.ite"(%2, %3, %4) : (!smt.bool, !smt.bv<1>, !smt.bv<1>) -> !smt.bv<1> +// CHECK-NEXT: %6 = smt.constant false +// CHECK-NEXT: %r = "smt.utils.pair"(%5, %6) : (!smt.bv<1>, !smt.bool) -> !smt.utils.pair, !smt.bool> +// CHECK-NEXT: "smt.return"(%r, %1) : (!smt.utils.pair, !smt.bool>, !effect.state) -> () +// CHECK-NEXT: }) {fun_name = "test"} : () -> ((!smt.bv<8>, !smt.bv<8>, !effect.state) -> (!smt.utils.pair, !smt.bool>, !effect.state)) +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-smt/transfer-to-smt/usub_overflow.mlir b/tests/filecheck/lower-to-smt/transfer-to-smt/usub_overflow.mlir new file mode 100644 index 00000000..33269f24 --- /dev/null +++ b/tests/filecheck/lower-to-smt/transfer-to-smt/usub_overflow.mlir @@ -0,0 +1,23 @@ +// RUN: xdsl-smt %s -p=lower-to-smt,canonicalize,dce | filecheck %s +// RUN: xdsl-smt %s -p=lower-to-smt,lower-effects,canonicalize,dce,merge-func-results -t=smt | z3 -in + +"builtin.module"() ({ + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.usub_overflow"(%x, %y) : (!transfer.integer,!transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "test", "function_type" = (!transfer.integer, !transfer.integer) -> i1, "sym_visibility" = "private"} : () -> () +}) : () -> () + +// CHECK: builtin.module { +// CHECK-NEXT: %0 = "smt.define_fun"() ({ +// CHECK-NEXT: ^0(%x : !smt.bv<8>, %y : !smt.bv<8>, %1 : !effect.state): +// CHECK-NEXT: %2 = "smt.bv.usubo"(%x, %y) : (!smt.bv<8>, !smt.bv<8>) -> !smt.bool +// CHECK-NEXT: %3 = smt.bv.constant #smt.bv<1> : !smt.bv<1> +// CHECK-NEXT: %4 = smt.bv.constant #smt.bv<0> : !smt.bv<1> +// CHECK-NEXT: %5 = "smt.ite"(%2, %3, %4) : (!smt.bool, !smt.bv<1>, !smt.bv<1>) -> !smt.bv<1> +// CHECK-NEXT: %6 = smt.constant false +// CHECK-NEXT: %r = "smt.utils.pair"(%5, %6) : (!smt.bv<1>, !smt.bool) -> !smt.utils.pair, !smt.bool> +// CHECK-NEXT: "smt.return"(%r, %1) : (!smt.utils.pair, !smt.bool>, !effect.state) -> () +// CHECK-NEXT: }) {fun_name = "test"} : () -> ((!smt.bv<8>, !smt.bv<8>, !effect.state) -> (!smt.utils.pair, !smt.bool>, !effect.state)) +// CHECK-NEXT: } diff --git a/xdsl_smt/dialects/smt_bitvector_dialect.py b/xdsl_smt/dialects/smt_bitvector_dialect.py index 624745d7..8b41ec0d 100644 --- a/xdsl_smt/dialects/smt_bitvector_dialect.py +++ b/xdsl_smt/dialects/smt_bitvector_dialect.py @@ -707,6 +707,26 @@ def op_name(self) -> str: return "bvsaddo" +@irdl_op_definition +class UsubOverflowOp(BinaryPredBVOp, SimpleSMTLibOp): + name = "smt.bv.usubo" + + traits = traits_def(traits.Pure()) + + def op_name(self) -> str: + return "bvusubo" + + +@irdl_op_definition +class SsubOverflowOp(BinaryPredBVOp, SimpleSMTLibOp): + name = "smt.bv.ssubo" + + traits = traits_def(traits.Pure()) + + def op_name(self) -> str: + return "bvssubo" + + @irdl_op_definition class UmulOverflowOp(BinaryPredBVOp, SimpleSMTLibOp): """ @@ -909,6 +929,8 @@ def print_expr_to_smtlib(self, stream: IO[str], ctx: SMTConversionCtx) -> None: NegOverflowOp, UaddOverflowOp, SaddOverflowOp, + UsubOverflowOp, + SsubOverflowOp, UmulOverflowOp, SmulOverflowOp, UmulNoOverflowOp, diff --git a/xdsl_smt/dialects/transfer.py b/xdsl_smt/dialects/transfer.py index 28c89f68..18f5ff56 100644 --- a/xdsl_smt/dialects/transfer.py +++ b/xdsl_smt/dialects/transfer.py @@ -287,6 +287,16 @@ class SAddOverflowOp(PredicateOp): name = "transfer.sadd_overflow" +@irdl_op_definition +class USubOverflowOp(PredicateOp): + name = "transfer.usub_overflow" + + +@irdl_op_definition +class SSubOverflowOp(PredicateOp): + name = "transfer.ssub_overflow" + + @irdl_op_definition class AndOp(BinOp): name = "transfer.and" @@ -337,6 +347,11 @@ class CountROneOp(UnaryOp): name = "transfer.countr_one" +@irdl_op_definition +class PopCountOp(UnaryOp): + name = "transfer.popcount" + + @irdl_op_definition class SMinOp(BinOp): name = "transfer.smin" @@ -810,6 +825,7 @@ class GetSignedMinValueOp(UnaryOp): CountLZeroOp, CountROneOp, CountRZeroOp, + PopCountOp, SetHighBitsOp, SetLowBitsOp, SetSignBitOp, @@ -827,6 +843,8 @@ class GetSignedMinValueOp(UnaryOp): SMulOverflowOp, UAddOverflowOp, SAddOverflowOp, + USubOverflowOp, + SSubOverflowOp, UShlOverflowOp, SShlOverflowOp, SelectOp, diff --git a/xdsl_smt/semantics/transfer_semantics.py b/xdsl_smt/semantics/transfer_semantics.py index 1936844e..c9e8d845 100644 --- a/xdsl_smt/semantics/transfer_semantics.py +++ b/xdsl_smt/semantics/transfer_semantics.py @@ -325,6 +325,42 @@ def get_semantics( return ((res.res,), effect_state) +class USubOverflowOpSemantics(OperationSemantics): + def get_semantics( + self, + operands: Sequence[SSAValue], + results: Sequence[Attribute], + attributes: Mapping[str, Attribute | SSAValue], + effect_state: SSAValue | None, + rewriter: PatternRewriter, + ) -> tuple[Sequence[SSAValue], SSAValue | None]: + usub_overflow = smt_bv.UsubOverflowOp(operands[0], operands[1]) + bv_res, ops = smt_bool_to_bv1(usub_overflow.res) + + poison_op = smt.ConstantBoolOp(False) + res = PairOp(bv_res, poison_op.result) + rewriter.insert_op_before_matched_op([usub_overflow] + ops + [poison_op, res]) + return ((res.res,), effect_state) + + +class SSubOverflowOpSemantics(OperationSemantics): + def get_semantics( + self, + operands: Sequence[SSAValue], + results: Sequence[Attribute], + attributes: Mapping[str, Attribute | SSAValue], + effect_state: SSAValue | None, + rewriter: PatternRewriter, + ) -> tuple[Sequence[SSAValue], SSAValue | None]: + ssub_overflow = smt_bv.SsubOverflowOp(operands[0], operands[1]) + bv_res, ops = smt_bool_to_bv1(ssub_overflow.res) + + poison_op = smt.ConstantBoolOp(False) + res = PairOp(bv_res, poison_op.result) + rewriter.insert_op_before_matched_op([ssub_overflow] + ops + [poison_op, res]) + return ((res.res,), effect_state) + + class UShlOverflowOpSemantics(OperationSemantics): def get_semantics( self, @@ -630,6 +666,36 @@ def get_semantics( return ((resList[-1].results[0],), effect_state) +class PopCountOpSemantics(OperationSemantics): + def get_semantics( + self, + operands: Sequence[SSAValue], + results: Sequence[Attribute], + attributes: Mapping[str, Attribute | SSAValue], + effect_state: SSAValue | None, + rewriter: PatternRewriter, + ) -> tuple[Sequence[SSAValue], SSAValue | None]: + operand = operands[0] + assert isinstance(bv_type := operand.type, smt_bv.BitVectorType) + width = bv_type.width + + const_zero = smt_bv.ConstantOp(0, width) + ops: list[Operation] = [const_zero] + acc = const_zero.res + + for i in range(width.data): + extract = smt_bv.ExtractOp(operand, i, i) + ops.append(extract) + zext = smt_bv.ZeroExtendOp(extract.res, bv_type) + ops.append(zext) + add = smt_bv.AddOp(acc, zext.res) + ops.append(add) + acc = add.res + + rewriter.insert_op_before_matched_op(ops) + return ((acc,), effect_state) + + class SetHighBitsOpSemantics(OperationSemantics): def get_semantics( self, @@ -971,6 +1037,8 @@ def get_semantics( transfer.SMulOverflowOp: SMulOverflowOpSemantics(), transfer.UAddOverflowOp: UAddOverflowOpSemantics(), transfer.SAddOverflowOp: SAddOverflowOpSemantics(), + transfer.USubOverflowOp: USubOverflowOpSemantics(), + transfer.SSubOverflowOp: SSubOverflowOpSemantics(), transfer.UShlOverflowOp: UShlOverflowOpSemantics(), transfer.SShlOverflowOp: SShlOverflowOpSemantics(), transfer.CmpOp: CmpOpSemantics(), @@ -981,6 +1049,7 @@ def get_semantics( transfer.CountLZeroOp: CountLZeroOpSemantics(), transfer.CountROneOp: CountROneOpSemantics(), transfer.CountRZeroOp: CountRZeroOpSemantics(), + transfer.PopCountOp: PopCountOpSemantics(), transfer.SMaxOp: SMaxOpSemantics(), transfer.SMinOp: SMinOpSemantics(), transfer.UMaxOp: UMaxOpSemantics(),