From bf85502f3dcee1ac7c82d1d528789d0bf5578201 Mon Sep 17 00:00:00 2001
From: quaternic <57393910+quaternic@users.noreply.github.com>
Date: Thu, 1 May 2025 18:06:31 +0300
Subject: [PATCH 1/2] Implement Barrett reduction for modular multiplication
---
libm/src/math/support/int_traits.rs | 3 +
libm/src/math/support/int_traits/mod_mul.rs | 225 ++++++++++++++++++++
libm/src/math/support/mod.rs | 1 +
3 files changed, 229 insertions(+)
create mode 100644 libm/src/math/support/int_traits/mod_mul.rs
diff --git a/libm/src/math/support/int_traits.rs b/libm/src/math/support/int_traits.rs
index 3ec1faba1..1ef84d52a 100644
--- a/libm/src/math/support/int_traits.rs
+++ b/libm/src/math/support/int_traits.rs
@@ -1,5 +1,8 @@
use core::{cmp, fmt, ops};
+mod mod_mul;
+pub(crate) use mod_mul::Reducer;
+
/// Minimal integer implementations needed on all integer types, including wide integers.
pub trait MinInt:
Copy
diff --git a/libm/src/math/support/int_traits/mod_mul.rs b/libm/src/math/support/int_traits/mod_mul.rs
new file mode 100644
index 000000000..af8f4a398
--- /dev/null
+++ b/libm/src/math/support/int_traits/mod_mul.rs
@@ -0,0 +1,225 @@
+use super::{DInt, HInt, Int};
+
+/// Barrett reduction using the constant `R == (1 << K) == (1 << U::BITS)`
+///
+/// More specifically, implements single-word [Barrett multiplication]
+/// (https://en.wikipedia.org/wiki/Barrett_reduction#Single-word_Barrett_multiplication)
+/// and [division]
+/// (https://en.wikipedia.org/wiki/Barrett_reduction#Barrett_Division)
+/// for unsigned integers.
+///
+/// After constructing as `Reducer::new(b, n)`,
+/// provides operations to efficiently compute
+/// - `(a * b) / n` and `(a * b) % n`
+/// - `Reducer::new((a * b * b) % n, n)`, as long as `a * (n - 1) < R`
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub(crate) struct Reducer<U> {
+ // the multiplying factor `b in 0..n`
+ num: U,
+ // the modulus `n in 1..=R/2`
+ div: U,
+ // the precomputed quotient, `q = (b << K) / n`
+ quo: U,
+ // the remainder of that division, `r = (b << K) % n`,
+ // (could always be recomputed as `(b << K) - q * n`,
+ // but it is convenient to save)
+ rem: U,
+}
+
+impl<U> Reducer<U>
+where
+ U: Int + HInt,
+ U::D: core::ops::Div<Output = U::D>,
+ U::D: core::ops::Rem<Output = U::D>,
+{
+ /// Requires `num < div <= R/2`, will panic otherwise
+ #[inline]
+ pub fn new(num: U, div: U) -> Self {
+ let _0 = U::ZERO;
+ let _1 = U::ONE;
+
+ assert!(num < div);
+ assert!(div.wrapping_sub(_1).leading_zeros() >= 1);
+
+ let bk = num.widen_hi();
+ let n = div.widen();
+ let quo = (bk / n).lo();
+ let rem = (bk % n).lo();
+
+ Self { num, div, quo, rem }
+ }
+}
+
+impl<U> Reducer<U>
+where
+ U: Int + HInt,
+ U::D: Int,
+{
+ /// Return the unique pair `(quotient, remainder)`
+ /// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < n`
+ #[inline]
+ pub fn mul_into_div_rem(&self, a: U) -> (U, U) {
+ let (q, mut r) = self.mul_into_unnormalized_div_rem(a);
+ // The unnormalized remainder is still guaranteed to be less than `2n`, so
+ // one checked subtraction is sufficient.
+ (q + U::cast_from(self.fixup(&mut r) as u8), r)
+ }
+
+ #[inline(always)]
+ pub fn fixup(&self, x: &mut U) -> bool {
+ x.checked_sub(self.div).map(|r| *x = r).is_some()
+ }
+
+ /// Return some pair `(quotient, remainder)`
+ /// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < 2n`
+ #[inline]
+ pub fn mul_into_unnormalized_div_rem(&self, a: U) -> (U, U) {
+ // General idea: Estimate the quotient `quotient = t in 0..a` s.t.
+ // the remainder `ab - tn` is close to zero, so `t ~= ab / n`
+
+ // Note: we use `R == 1 << U::BITS`, which means that
+ // - wrapping arithmetic with `U` is modulo `R`
+ // - all inputs are less than `R`
+
+ // Range analysis:
+ //
+ // Using the definition of euclidean division on the two divisions done:
+ // ```
+ // bR = qn + r, with 0 <= r < n
+ // aq = tR + s, with 0 <= s < R
+ // ```
+ let (_s, t) = a.widen_mul(self.quo).lo_hi();
+ // Then
+ // ```
+ // (ab - tn)R
+ // = abR - ntR
+ // = a(qn + r) - n(aq - s)
+ // = ar + ns
+ // ```
+ #[cfg(debug_assertions)]
+ {
+ assert!(t < a || (a == t && t.is_zero()));
+ let ab_tn = a.widen_mul(self.num) - t.widen_mul(self.div);
+ let ar_ns = a.widen_mul(self.rem) + _s.widen_mul(self.div);
+ assert!(ab_tn.hi().is_zero());
+ assert!(ar_ns.lo().is_zero());
+ assert_eq!(ab_tn.lo(), ar_ns.hi());
+ }
+ // Since `s < R` and `r < n`,
+ // ```
+ // 0 <= ns < nR
+ // 0 <= ar < an
+ // 0 <= (ab - tn) == (ar + ns)/R < n(1 + a/R)
+ // ```
+ // Since `a < R` and we check on construction that `n <= R/2`, the result
+ // is `0 <= ab - tn < R`, so it can be computed modulo `R`
+ // even though the intermediate terms generally wrap.
+ let ab = a.wrapping_mul(self.num);
+ let tn = t.wrapping_mul(self.div);
+ (t, ab.wrapping_sub(tn))
+ }
+
+ /// Constructs a new reducer with `b` set to `(ab * b) % n`
+ ///
+ /// Requires `r * ab == ra * b`, where `r = bR % n`.
+ #[inline(always)]
+ fn with_scaled_num_rem(&self, ab: U, ra: U) -> Self {
+ debug_assert_eq!(ab.widen_mul(self.rem), ra.widen_mul(self.num));
+ // The new factor `v = abb mod n`:
+ let (_, v) = self.mul_into_div_rem(ab);
+
+ // `rab = cn + d`, where `0 <= d < n`
+ let (c, d) = self.mul_into_div_rem(ra);
+
+ // We need `abbR = Xn + Y`:
+ // abbR
+ // = ab(qn + r)
+ // = abqn + rab
+ // = abqn + cn + d
+ // = (abq + c)n + d
+
+ Self {
+ num: v,
+ div: self.div,
+ quo: self.quo.wrapping_mul(ab).wrapping_add(c),
+ rem: d,
+ }
+ }
+
+ /// Computes the reducer with the factor `b` set to `(a * b * b) % n`
+ /// Requires that `a * (n - 1)` does not overflow.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn squared_with_scale(&self, a: U) -> Self {
+ debug_assert!(a.widen_mul(self.div - U::ONE).hi().is_zero());
+ self.with_scaled_num_rem(a * self.num, a * self.rem)
+ }
+
+ /// Computes the reducer with the factor `b` set to `(b * b << s) % n`
+ /// Requires that `(n - 1) << s` does not overflow.
+ #[inline]
+ pub fn squared_with_shift(&self, s: u32) -> Self {
+ debug_assert!((self.div - U::ONE).leading_zeros() >= s);
+ self.with_scaled_num_rem(self.num << s, self.rem << s)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::Reducer;
+
+ #[test]
+ fn u8_all() {
+ for y in 1..=128_u8 {
+ for r in 0..y {
+ let m = Reducer::new(r, y);
+ assert_eq!(m.quo, ((r as f32 * 256.0) / (y as f32)) as u8);
+ for x in 0..=u8::MAX {
+ let (quo, rem) = m.mul_into_div_rem(x);
+
+ let q0 = x as u32 * r as u32 / y as u32;
+ let r0 = x as u32 * r as u32 % y as u32;
+ assert_eq!(
+ (quo as u32, rem as u32),
+ (q0, r0),
+ "\n\
+ {x} * {r} = {xr}\n\
+ expected: = {q0} * {y} + {r0}\n\
+ returned: = {quo} * {y} + {rem} (== {})\n",
+ quo as u32 * y as u32 + rem as u32,
+ xr = x as u32 * r as u32,
+ );
+ }
+ for s in 0..=y.leading_zeros() {
+ assert_eq!(
+ m.squared_with_shift(s),
+ Reducer::new(((r << s) as u32 * r as u32 % y as u32) as u8, y)
+ );
+ }
+ for a in 0..=u8::MAX {
+ if a.checked_mul(y).is_some() {
+ let abb = a as u32 * r as u32 * r as u32;
+ assert_eq!(
+ m.squared_with_scale(a),
+ Reducer::new((abb % y as u32) as u8, y)
+ );
+ } else {
+ break;
+ }
+ }
+ for x0 in 0..=u8::MAX {
+ if m.num == 0 || x0 as u32 * m.rem as u32 % m.num as u32 != 0 {
+ continue;
+ }
+ let y0 = x0 as u32 * m.rem as u32 / m.num as u32;
+ let Ok(y0) = u8::try_from(y0) else { continue };
+
+ assert_eq!(
+ m.with_scaled_num_rem(x0, y0),
+ Reducer::new((x0 as u32 * m.num as u32 % y as u32) as u8, y)
+ );
+ }
+ }
+ }
+ }
+}
diff --git a/libm/src/math/support/mod.rs b/libm/src/math/support/mod.rs
index ee3f2bbdf..330217a7a 100644
--- a/libm/src/math/support/mod.rs
+++ b/libm/src/math/support/mod.rs
@@ -20,6 +20,7 @@ pub use hex_float::hf16;
pub use hex_float::hf128;
#[allow(unused_imports)]
pub use hex_float::{Hexf, hf32, hf64};
+pub(crate) use int_traits::Reducer;
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt};
/// Hint to the compiler that the current path is cold.
From c811592b46797aae0769852d93fb1f9a90a8c9cb Mon Sep 17 00:00:00 2001
From: quaternic <57393910+quaternic@users.noreply.github.com>
Date: Thu, 1 May 2025 18:07:48 +0300
Subject: [PATCH 2/2] Optimize performance of fmod with Barrett multiplication
---
libm/src/math/generic/fmod.rs | 93 +++++++++++++++++++++
libm/src/math/support/int_traits/mod_mul.rs | 13 ++-
2 files changed, 98 insertions(+), 8 deletions(-)
diff --git a/libm/src/math/generic/fmod.rs b/libm/src/math/generic/fmod.rs
index e9898012f..183250c03 100644
--- a/libm/src/math/generic/fmod.rs
+++ b/libm/src/math/generic/fmod.rs
@@ -1,5 +1,6 @@
/* SPDX-License-Identifier: MIT OR Apache-2.0 */
use super::super::{CastFrom, Float, Int, MinInt};
+use crate::support::{DInt, HInt, Reducer};
#[inline]
pub fn fmod<F: Float>(x: F, y: F) -> F {
@@ -59,6 +60,33 @@ fn into_sig_exp<F: Float>(mut bits: F::Int) -> (F::Int, u32) {
/// Compute the remainder `(x * 2.pow(e)) % y` without overflow.
fn reduction<I: Int>(mut x: I, e: u32, y: I) -> I {
+ // FIXME: This is a temporary hack to get around the lack of `u256 / u256`.
+ // Actually, the algorithm only needs the operation `(x << I::BITS) / y`
+ // where `x < y`. That is, a division `u256 / u128` where the quotient must
+ // not overflow `u128` would be sufficient for `f128`.
+ unsafe {
+ use core::mem::transmute_copy;
+ if I::BITS == 64 {
+ let x = transmute_copy::<I, u64>(&x);
+ let y = transmute_copy::<I, u64>(&y);
+ let r = fast_reduction::<f64, u64>(x, e, y);
+ return transmute_copy::<u64, I>(&r);
+ }
+ if I::BITS == 32 {
+ let x = transmute_copy::<I, u32>(&x);
+ let y = transmute_copy::<I, u32>(&y);
+ let r = fast_reduction::<f32, u32>(x, e, y);
+ return transmute_copy::<u32, I>(&r);
+ }
+ #[cfg(f16_enabled)]
+ if I::BITS == 16 {
+ let x = transmute_copy::<I, u16>(&x);
+ let y = transmute_copy::<I, u16>(&y);
+ let r = fast_reduction::<f16, u16>(x, e, y);
+ return transmute_copy::<u16, I>(&r);
+ }
+ }
+
x %= y;
for _ in 0..e {
x <<= 1;
@@ -66,3 +94,68 @@ fn reduction<I: Int>(mut x: I, e: u32, y: I) -> I {
}
x
}
+
+trait SafeShift: Float {
+ // How many guaranteed leading zeros do the values have?
+ // A normalized floating point mantissa has `EXP_BITS` guaranteed leading
+ // zeros (exludes the implicit bit, but includes the now-zeroed sign bit)
+ // `-1` because we want to shift by either `BASE_SHIFT` or `BASE_SHIFT + 1`
+ const BASE_SHIFT: u32 = Self::EXP_BITS - 1;
+}
+impl<F: Float> SafeShift for F {}
+
+fn fast_reduction<F, I>(x: I, e: u32, y: I) -> I
+where
+ F: Float<Int = I>,
+ I: Int + HInt,
+ I::D: Int + DInt<H = I>,
+{
+ let _0 = I::ZERO;
+ let _1 = I::ONE;
+
+ if y == _1 {
+ return _0;
+ }
+
+ if e <= F::BASE_SHIFT {
+ return (x << e) % y;
+ }
+
+ // Find least depth s.t. `(e >> depth) < I::BITS`
+ let depth = (I::BITS - 1)
+ .leading_zeros()
+ .saturating_sub(e.leading_zeros());
+
+ let initial = (e >> depth) - F::BASE_SHIFT;
+
+ let max_rem = y.wrapping_sub(_1);
+ let max_ilog2 = max_rem.ilog2();
+ let mut pow2 = _1 << max_ilog2.min(initial);
+ for _ in max_ilog2..initial {
+ pow2 <<= 1;
+ pow2 = pow2.checked_sub(y).unwrap_or(pow2);
+ }
+
+ // At each step `k in [depth, ..., 0]`,
+ // `p` is `(e >> k) - BASE_SHIFT`
+ // `m` is `(1 << p) % y`
+ let mut k = depth;
+ let mut p = initial;
+ let mut m = Reducer::new(pow2, y);
+
+ while k > 0 {
+ k -= 1;
+ p = p + p + F::BASE_SHIFT;
+ if e & (1 << k) != 0 {
+ m = m.squared_with_shift(F::BASE_SHIFT + 1);
+ p += 1;
+ } else {
+ m = m.squared_with_shift(F::BASE_SHIFT);
+ };
+
+ debug_assert!(p == (e >> k) - F::BASE_SHIFT);
+ }
+
+ // (x << BASE_SHIFT) * (1 << p) == x << e
+ m.mul_into_div_rem(x << F::BASE_SHIFT).1
+}
diff --git a/libm/src/math/support/int_traits/mod_mul.rs b/libm/src/math/support/int_traits/mod_mul.rs
index af8f4a398..c770a122b 100644
--- a/libm/src/math/support/int_traits/mod_mul.rs
+++ b/libm/src/math/support/int_traits/mod_mul.rs
@@ -2,14 +2,11 @@ use super::{DInt, HInt, Int};
/// Barrett reduction using the constant `R == (1 << K) == (1 << U::BITS)`
///
-/// More specifically, implements single-word [Barrett multiplication]
-/// (https://en.wikipedia.org/wiki/Barrett_reduction#Single-word_Barrett_multiplication)
-/// and [division]
-/// (https://en.wikipedia.org/wiki/Barrett_reduction#Barrett_Division)
-/// for unsigned integers.
+/// For a more detailed description, see
+/// <https://en.wikipedia.org/wiki/Barrett_reduction>.
///
/// After constructing as `Reducer::new(b, n)`,
-/// provides operations to efficiently compute
+/// has operations to efficiently compute
/// - `(a * b) / n` and `(a * b) % n`
/// - `Reducer::new((a * b * b) % n, n)`, as long as `a * (n - 1) < R`
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
@@ -103,7 +100,7 @@ where
let ar_ns = a.widen_mul(self.rem) + _s.widen_mul(self.div);
assert!(ab_tn.hi().is_zero());
assert!(ar_ns.lo().is_zero());
- assert_eq!(ab_tn.lo(), ar_ns.hi());
+ assert!(ab_tn.lo() == ar_ns.hi());
}
// Since `s < R` and `r < n`,
// ```
@@ -124,7 +121,7 @@ where
/// Requires `r * ab == ra * b`, where `r = bR % n`.
#[inline(always)]
fn with_scaled_num_rem(&self, ab: U, ra: U) -> Self {
- debug_assert_eq!(ab.widen_mul(self.rem), ra.widen_mul(self.num));
+ debug_assert!(ab.widen_mul(self.rem) == ra.widen_mul(self.num));
// The new factor `v = abb mod n`:
let (_, v) = self.mul_into_div_rem(ab);