From 7ffeacee9706a421000354723e74d226028654d2 Mon Sep 17 00:00:00 2001 From: Juan Munoz Date: Fri, 14 Jun 2024 17:38:14 -0300 Subject: [PATCH] fix mul not following spec --- programs/mul/mul.zasm | 4 ++-- programs/mul/mul_big.zasm | 2 +- src/op_handlers/div.rs | 10 ++++++++++ src/op_handlers/mul.rs | 32 ++++++++++++++++++++------------ tests/integration_test.rs | 18 +++++++++--------- 5 files changed, 42 insertions(+), 24 deletions(-) diff --git a/programs/mul/mul.zasm b/programs/mul/mul.zasm index d72e4c06..a70fc6ac 100644 --- a/programs/mul/mul.zasm +++ b/programs/mul/mul.zasm @@ -6,8 +6,8 @@ __entry: add 3, r0, r1 add 2, r0, r2 mul r1, r2, r1, r0 - mul 1, r1, r1, r0 - sstore r0, r1 + mul 1, r1, r3, r4 + sstore r0, r3 ret .func_end0: diff --git a/programs/mul/mul_big.zasm b/programs/mul/mul_big.zasm index a9e67b70..35d65eb8 100644 --- a/programs/mul/mul_big.zasm +++ b/programs/mul/mul_big.zasm @@ -3,7 +3,7 @@ .globl __entry __entry: .func_begin0: - ; test sets r1 = 2**(256) - 1, r2 = 1 + ; test sets r1 = 2**(256) - 1, r2 = 2**(256) mul r1, r2, r3, r4 sstore r0, r1 ret diff --git a/src/op_handlers/div.rs b/src/op_handlers/div.rs index 3b2f5532..59d3ea47 100644 --- a/src/op_handlers/div.rs +++ b/src/op_handlers/div.rs @@ -4,5 +4,15 @@ use crate::{opcode::Opcode, state::VMState}; pub fn _div(vm: &mut VMState, opcode: &Opcode) { let (src0, src1) = address_operands_read(vm, &opcode); let (quotient, remainder) = src0.div_mod(src1); + if opcode.alters_vm_flags { + // If overflow, set the flag. + // otherwise keep the current value. + // vm.flag_lt_of |= overflow; + // Set eq if res == 0 + // vm.flag_eq |= quotient.is_zero(); + // Gt is set if both of lt_of and eq are cleared. + vm.flag_gt |= !vm.flag_lt_of && !vm.flag_eq; + } + address_operands_store(vm, &opcode, (quotient, Some(remainder))); } diff --git a/src/op_handlers/mul.rs b/src/op_handlers/mul.rs index 3a1db6de..5d6a4221 100644 --- a/src/op_handlers/mul.rs +++ b/src/op_handlers/mul.rs @@ -1,27 +1,35 @@ -use u256::U256; +use u256::{U256, U512}; use crate::address_operands::{address_operands_read, address_operands_store}; use crate::{opcode::Opcode, state::VMState}; pub fn _mul(vm: &mut VMState, opcode: &Opcode) { let (src0, src1) = address_operands_read(vm, &opcode); + let src0 = U512::from(src0); + let src1 = U512::from(src1); + let res = src0 * src1; - let max = U256::max_value(); - let low_mask = U256::from(max.low_u128()); - let high_mask = !low_mask & max; + let u256_mask = U512::from(U256::MAX); + let low_bits = res & u256_mask; + let high_bits = res >> 256 & u256_mask; - let (res, overflow) = src0.overflowing_mul(src1); if opcode.alters_vm_flags { - // If overflow, set the flag. - // otherwise keep the current value. + // Lt overflow, is set if + // src0 * src1 >= 2^256 + let overflow = res >= U512::from(U256::MAX); vm.flag_lt_of |= overflow; - // Set eq if res == 0 - vm.flag_eq |= res.is_zero(); + // Eq is set if res_low == 0 + vm.flag_eq |= low_bits.is_zero(); // Gt is set if both of lt_of and eq are cleared. vm.flag_gt |= !vm.flag_lt_of && !vm.flag_eq; } - let low_bits = res & low_mask; - let high_bits = res & high_mask; - address_operands_store(vm, &opcode, (low_bits, Some(high_bits))); + address_operands_store( + vm, + &opcode, + ( + low_bits.try_into().unwrap(), + Some(high_bits.try_into().unwrap()), + ), + ); } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 679c1dd0..aa6e4eb4 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -259,15 +259,19 @@ fn test_sub_and_add() { #[test] fn test_mul_asm() { let bin_path = make_bin_path_asm("mul"); - let (result, _) = run_program(&bin_path); - assert_eq!(result, U256::from_dec_str("6").unwrap()); + let (_, vm) = run_program(&bin_path); + let low = vm.get_register(3); + let high = vm.get_register(4); + + assert_eq!(low, U256::from_dec_str("6").unwrap()); + assert_eq!(high, U256::zero()); } #[test] fn test_mul_big_asm() { let bin_path = make_bin_path_asm("mul_big"); let r1 = U256::MAX; - let r2 = U256::from(1); + let r2 = U256::from(2); let mut registers: [U256; 15] = [U256::zero(); 15]; registers[0] = r1; registers[1] = r2; @@ -278,12 +282,8 @@ fn test_mul_big_asm() { let low = vm.get_register(3); let high = vm.get_register(4); - let max = U256::max_value(); - let expected_low = U256::from(max.low_u128()); - let expected_high = !low & max; - - assert_eq!(low, expected_low); - assert_eq!(high, expected_high); + assert_eq!(low, U256::MAX - 1); + assert_eq!(high, U256::from(1)); // multiply by 2 == shift left by 1 } #[test]