Skip to content

Commit 8305b46

Browse files
committed
rename: safe arith
1 parent cd94fec commit 8305b46

File tree

4 files changed

+321
-2
lines changed

4 files changed

+321
-2
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[workspace]
2-
members = [ "alloy-helper","reakula", "safe_arith", "specs"]
2+
members = [ "alloy-helper","reakula", "safe-arith", "specs"]
33
default-members = ["reakula"]
44
resolver = "2"
55

@@ -16,7 +16,7 @@ tokio = { version = "1.41.1", features = ["full"] }
1616
eyre = { version = "0.6.12" }
1717
sszb = { package = "sszb", git = "https://github.com/ghiliweld/sszb.git" }
1818
sszb_derive = { package = "sszb_derive", git = "https://github.com/ghiliweld/sszb.git" }
19-
safe_arith = { path = "safe_arith" }
19+
safe_arith = { path = "safe-arith" }
2020
alloy-primitives = { version = "0.8.15" }
2121
serde = { version = "1.0.216", features = ["derive"] }
2222
bytes = { version = "1.9.0", features = ["serde"] }

safe-arith/Cargo.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[package]
2+
name = "safe_arith"
3+
description.workspace = true
4+
version.workspace = true
5+
authors.workspace = true
6+
edition.workspace = true
7+
rust-version.workspace = true
8+
9+
[dependencies]

safe-arith/src/iter.rs

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
use crate::SafeArith;
2+
3+
/// Error representing the failure of an arithmetic operation.
4+
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
5+
pub enum ArithError {
6+
Overflow,
7+
DivisionByZero,
8+
}
9+
10+
pub type Result<T, E = crate::ArithError> = core::result::Result<T, E>;
11+
12+
/// Extension trait for iterators, providing a safe replacement for `sum`.
13+
pub trait SafeArithIter<T> {
14+
fn safe_sum(self) -> Result<T>;
15+
}
16+
17+
impl<I, T> SafeArithIter<T> for I
18+
where
19+
I: Iterator<Item = T> + Sized,
20+
T: SafeArith,
21+
{
22+
fn safe_sum(mut self) -> Result<T> {
23+
self.try_fold(T::ZERO, |acc, x| acc.safe_add(x))
24+
}
25+
}
26+
27+
#[cfg(test)]
28+
mod test {
29+
use super::*;
30+
use crate::ArithError;
31+
32+
#[test]
33+
fn empty_sum() {
34+
let v: Vec<u64> = vec![];
35+
assert_eq!(v.into_iter().safe_sum(), Ok(0));
36+
}
37+
38+
#[test]
39+
fn unsigned_sum_small() {
40+
let arr = [500u64, 501, 502, 503, 504, 505, 506];
41+
assert_eq!(
42+
arr.iter().copied().safe_sum().unwrap(),
43+
arr.iter().copied().sum()
44+
);
45+
46+
// Additional case with different values
47+
let arr = [10u64, 20, 30, 40, 50, 60, 70, 80, 90, 100];
48+
assert_eq!(arr.iter().copied().safe_sum().unwrap(), arr.iter().copied().sum());
49+
}
50+
51+
#[test]
52+
fn unsigned_sum_overflow() {
53+
let v = vec![u64::MAX, 2];
54+
assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow));
55+
56+
// Additional edge case with close-to-limit values
57+
let v = vec![u64::MAX - 1, 2];
58+
assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow));
59+
}
60+
61+
#[test]
62+
fn signed_sum_small() {
63+
let v = vec![-2i64, -3i64, -4i64, 4, 3, 2];
64+
assert_eq!(v.into_iter().safe_sum(), Ok(0));
65+
66+
// Additional case with alternating signs
67+
let v = vec![-15i64, 25, -35, 45, -55, 65];
68+
assert_eq!(v.into_iter().safe_sum(), Ok(30));
69+
}
70+
71+
#[test]
72+
fn signed_sum_overflow_above() {
73+
let v = vec![1, 2, 3, 4, i16::MAX, 5, 6, 7];
74+
assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow));
75+
76+
// Additional edge case with large positive numbers
77+
let v = vec![i64::MAX - 1, 2];
78+
assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow));
79+
}
80+
81+
#[test]
82+
fn signed_sum_overflow_below() {
83+
let v = vec![i16::MIN + 1, -2];
84+
assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow));
85+
86+
// Additional edge case with large negative numbers
87+
let v = vec![i64::MIN + 1, -2];
88+
assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow));
89+
}
90+
91+
#[test]
92+
fn signed_sum_almost_overflow() {
93+
let arr = [i64::MIN + 1, 2, -2i64, i64::MAX - 1, i64::MAX, -1];
94+
assert_eq!(
95+
arr.iter().copied().safe_sum().unwrap(),
96+
arr.iter().copied().sum()
97+
);
98+
99+
// Additional case with values close to the limits
100+
let arr = [i64::MAX / 3, i64::MAX / 3, i64::MAX / 3];
101+
assert_eq!(arr.iter().copied().safe_sum().unwrap(), arr.iter().copied().sum());
102+
}
103+
}

safe-arith/src/lib.rs

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
mod iter;
2+
3+
pub use iter::*;
4+
5+
pub trait SafeArith<Rhs = Self>: Sized + Copy {
6+
const ZERO: Self;
7+
const ONE: Self;
8+
9+
/// Safe variant of `+` that guards against overflow.
10+
fn safe_add(&self, other: Rhs) -> Result<Self>;
11+
12+
/// Safe variant of `-` that guards against overflow.
13+
fn safe_sub(&self, other: Rhs) -> Result<Self>;
14+
15+
/// Safe variant of `*` that guards against overflow.
16+
fn safe_mul(&self, other: Rhs) -> Result<Self>;
17+
18+
/// Safe variant of `/` that guards against division by 0.
19+
fn safe_div(&self, other: Rhs) -> Result<Self>;
20+
21+
/// Safe variant of `%` that guards against division by 0.
22+
fn safe_rem(&self, other: Rhs) -> Result<Self>;
23+
24+
/// Safe variant of `<<` that guards against overflow.
25+
fn safe_shl(&self, other: u32) -> Result<Self>;
26+
27+
/// Safe variant of `>>` that guards against overflow.
28+
fn safe_shr(&self, other: u32) -> Result<Self>;
29+
30+
/// Safe variant of `+=` that guards against overflow.
31+
fn safe_add_assign(&mut self, other: Rhs) -> Result<()> {
32+
*self = self.safe_add(other)?;
33+
Ok(())
34+
}
35+
36+
/// Safe variant of `-=` that guards against overflow.
37+
fn safe_sub_assign(&mut self, other: Rhs) -> Result<()> {
38+
*self = self.safe_sub(other)?;
39+
Ok(())
40+
}
41+
42+
/// Safe variant of `*=` that guards against overflow.
43+
fn safe_mul_assign(&mut self, other: Rhs) -> Result<()> {
44+
*self = self.safe_mul(other)?;
45+
Ok(())
46+
}
47+
48+
/// Safe variant of `/=` that guards against division by 0.
49+
fn safe_div_assign(&mut self, other: Rhs) -> Result<()> {
50+
*self = self.safe_div(other)?;
51+
Ok(())
52+
}
53+
54+
/// Safe variant of `%=` that guards against division by 0.
55+
fn safe_rem_assign(&mut self, other: Rhs) -> Result<()> {
56+
*self = self.safe_rem(other)?;
57+
Ok(())
58+
}
59+
60+
/// Safe variant of `<<=` that guards against overflow.
61+
fn safe_shl_assign(&mut self, other: u32) -> Result<()> {
62+
*self = self.safe_shl(other)?;
63+
Ok(())
64+
}
65+
66+
/// Safe variant of `>>=` that guards against overflow.
67+
fn safe_shr_assign(&mut self, other: u32) -> Result<()> {
68+
*self = self.safe_shr(other)?;
69+
Ok(())
70+
}
71+
}
72+
73+
macro_rules! impl_safe_arith {
74+
($typ:ty) => {
75+
impl SafeArith for $typ {
76+
const ZERO: Self = 0;
77+
const ONE: Self = 1;
78+
79+
#[inline]
80+
fn safe_add(&self, other: Self) -> Result<Self> {
81+
self.checked_add(other).ok_or(ArithError::Overflow)
82+
}
83+
84+
#[inline]
85+
fn safe_sub(&self, other: Self) -> Result<Self> {
86+
self.checked_sub(other).ok_or(ArithError::Overflow)
87+
}
88+
89+
#[inline]
90+
fn safe_mul(&self, other: Self) -> Result<Self> {
91+
self.checked_mul(other).ok_or(ArithError::Overflow)
92+
}
93+
94+
#[inline]
95+
fn safe_div(&self, other: Self) -> Result<Self> {
96+
self.checked_div(other).ok_or(ArithError::DivisionByZero)
97+
}
98+
99+
#[inline]
100+
fn safe_rem(&self, other: Self) -> Result<Self> {
101+
self.checked_rem(other).ok_or(ArithError::DivisionByZero)
102+
}
103+
104+
#[inline]
105+
fn safe_shl(&self, other: u32) -> Result<Self> {
106+
self.checked_shl(other).ok_or(ArithError::Overflow)
107+
}
108+
109+
#[inline]
110+
fn safe_shr(&self, other: u32) -> Result<Self> {
111+
self.checked_shr(other).ok_or(ArithError::Overflow)
112+
}
113+
}
114+
};
115+
}
116+
117+
impl_safe_arith!(u8);
118+
impl_safe_arith!(u16);
119+
impl_safe_arith!(u32);
120+
impl_safe_arith!(u64);
121+
impl_safe_arith!(usize);
122+
impl_safe_arith!(i8);
123+
impl_safe_arith!(i16);
124+
impl_safe_arith!(i32);
125+
impl_safe_arith!(i64);
126+
impl_safe_arith!(isize);
127+
128+
#[cfg(test)]
129+
mod test {
130+
use super::*;
131+
132+
#[test]
133+
fn basic() {
134+
// Test with u8
135+
let x = 10u8;
136+
let y = 5u8;
137+
assert_eq!(x.safe_add(y), Ok(x + y));
138+
assert_eq!(x.safe_sub(y), Ok(x - y));
139+
assert_eq!(x.safe_mul(y), Ok(x * y));
140+
assert_eq!(x.safe_div(y), Ok(x / y));
141+
assert_eq!(x.safe_rem(y), Ok(x % y));
142+
143+
assert_eq!(x.safe_shl(2), Ok(x << 2));
144+
assert_eq!(x.safe_shr(1), Ok(x >> 1));
145+
146+
// Test with u16
147+
let x = 100u16;
148+
let y = 25u16;
149+
assert_eq!(x.safe_add(y), Ok(x + y));
150+
assert_eq!(x.safe_sub(y), Ok(x - y));
151+
assert_eq!(x.safe_mul(y), Ok(x * y));
152+
assert_eq!(x.safe_div(y), Ok(x / y));
153+
assert_eq!(x.safe_rem(y), Ok(x % y));
154+
155+
assert_eq!(x.safe_shl(3), Ok(x << 3));
156+
assert_eq!(x.safe_shr(2), Ok(x >> 2));
157+
}
158+
159+
#[test]
160+
fn mutate() {
161+
// Test with edge case values
162+
let mut x = 1u8;
163+
x.safe_add_assign(254).unwrap();
164+
assert_eq!(x, 255);
165+
x.safe_sub_assign(255).unwrap();
166+
assert_eq!(x, 0);
167+
x.safe_add_assign(1).unwrap();
168+
x.safe_shl_assign(7).unwrap();
169+
assert_eq!(x, 128);
170+
x.safe_shr_assign(7).unwrap();
171+
assert_eq!(x, 1);
172+
173+
// Test with larger integer types
174+
let mut y = 100u16;
175+
y.safe_mul_assign(2).unwrap();
176+
assert_eq!(y, 200);
177+
y.safe_div_assign(4).unwrap();
178+
assert_eq!(y, 50);
179+
y.safe_add_assign(1000).unwrap();
180+
assert_eq!(y, 1050);
181+
}
182+
183+
#[test]
184+
fn errors() {
185+
// Overflow and underflow for u32
186+
assert!(u32::MAX.safe_add(1).is_err());
187+
assert!(u32::MIN.safe_sub(1).is_err());
188+
assert!(u32::MAX.safe_mul(2).is_err());
189+
190+
// Division by zero
191+
assert!(10u32.safe_div(0).is_err());
192+
assert!(10u32.safe_rem(0).is_err());
193+
194+
// Shift overflow
195+
assert!(u32::MAX.safe_shl(33).is_err());
196+
assert!(u32::MAX.safe_shr(33).is_err());
197+
198+
// Edge cases for smaller types
199+
assert!(u8::MAX.safe_add(1).is_err());
200+
assert!(u8::MIN.safe_sub(1).is_err());
201+
assert!(u8::MAX.safe_mul(2).is_err());
202+
203+
// Shifting too far
204+
assert!(u8::MAX.safe_shl(9).is_err());
205+
assert!(u8::MAX.safe_shr(9).is_err());
206+
}
207+
}

0 commit comments

Comments
 (0)