From bf85502f3dcee1ac7c82d1d528789d0bf5578201 Mon Sep 17 00:00:00 2001
From: quaternic <57393910+quaternic@users.noreply.github.com>
Date: Thu, 1 May 2025 18:06:31 +0300
Subject: [PATCH 1/2] Implement Barrett reduction for modular multiplication

---
 libm/src/math/support/int_traits.rs         |   3 +
 libm/src/math/support/int_traits/mod_mul.rs | 225 ++++++++++++++++++++
 libm/src/math/support/mod.rs                |   1 +
 3 files changed, 229 insertions(+)
 create mode 100644 libm/src/math/support/int_traits/mod_mul.rs

diff --git a/libm/src/math/support/int_traits.rs b/libm/src/math/support/int_traits.rs
index 3ec1faba1..1ef84d52a 100644
--- a/libm/src/math/support/int_traits.rs
+++ b/libm/src/math/support/int_traits.rs
@@ -1,5 +1,8 @@
 use core::{cmp, fmt, ops};
 
+mod mod_mul;
+pub(crate) use mod_mul::Reducer;
+
 /// Minimal integer implementations needed on all integer types, including wide integers.
 pub trait MinInt:
     Copy
diff --git a/libm/src/math/support/int_traits/mod_mul.rs b/libm/src/math/support/int_traits/mod_mul.rs
new file mode 100644
index 000000000..af8f4a398
--- /dev/null
+++ b/libm/src/math/support/int_traits/mod_mul.rs
@@ -0,0 +1,225 @@
+use super::{DInt, HInt, Int};
+
+/// Barrett reduction using the constant `R == (1 << K) == (1 << U::BITS)`
+///
+/// More specifically, implements single-word [Barrett multiplication]
+/// (https://en.wikipedia.org/wiki/Barrett_reduction#Single-word_Barrett_multiplication)
+/// and [division]
+/// (https://en.wikipedia.org/wiki/Barrett_reduction#Barrett_Division)
+/// for unsigned integers.
+///
+/// After constructing as `Reducer::new(b, n)`,
+/// provides operations to efficiently compute
+///  - `(a * b) / n` and `(a * b) % n`
+///  - `Reducer::new((a * b * b) % n, n)`, as long as `a * (n - 1) < R`
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub(crate) struct Reducer<U> {
+    // the multiplying factor `b in 0..n`
+    num: U,
+    // the modulus `n in 1..=R/2`
+    div: U,
+    // the precomputed quotient, `q = (b << K) / n`
+    quo: U,
+    // the remainder of that division, `r = (b << K) % n`,
+    // (could always be recomputed as `(b << K) - q * n`,
+    // but it is convenient to save)
+    rem: U,
+}
+
+impl<U> Reducer<U>
+where
+    U: Int + HInt,
+    U::D: core::ops::Div<Output = U::D>,
+    U::D: core::ops::Rem<Output = U::D>,
+{
+    /// Requires `num < div <= R/2`, will panic otherwise
+    #[inline]
+    pub fn new(num: U, div: U) -> Self {
+        let _0 = U::ZERO;
+        let _1 = U::ONE;
+
+        assert!(num < div);
+        assert!(div.wrapping_sub(_1).leading_zeros() >= 1);
+
+        let bk = num.widen_hi();
+        let n = div.widen();
+        let quo = (bk / n).lo();
+        let rem = (bk % n).lo();
+
+        Self { num, div, quo, rem }
+    }
+}
+
+impl<U> Reducer<U>
+where
+    U: Int + HInt,
+    U::D: Int,
+{
+    /// Return the unique pair `(quotient, remainder)`
+    /// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < n`
+    #[inline]
+    pub fn mul_into_div_rem(&self, a: U) -> (U, U) {
+        let (q, mut r) = self.mul_into_unnormalized_div_rem(a);
+        // The unnormalized remainder is still guaranteed to be less than `2n`, so
+        // one checked subtraction is sufficient.
+        (q + U::cast_from(self.fixup(&mut r) as u8), r)
+    }
+
+    #[inline(always)]
+    pub fn fixup(&self, x: &mut U) -> bool {
+        x.checked_sub(self.div).map(|r| *x = r).is_some()
+    }
+
+    /// Return some pair `(quotient, remainder)`
+    /// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < 2n`
+    #[inline]
+    pub fn mul_into_unnormalized_div_rem(&self, a: U) -> (U, U) {
+        // General idea: Estimate the quotient `quotient = t in 0..a` s.t.
+        // the remainder `ab - tn` is close to zero, so `t ~= ab / n`
+
+        // Note: we use `R == 1 << U::BITS`, which means that
+        //  - wrapping arithmetic with `U` is modulo `R`
+        //  - all inputs are less than `R`
+
+        // Range analysis:
+        //
+        // Using the definition of euclidean division on the two divisions done:
+        // ```
+        // bR = qn + r,  with 0 <= r < n
+        // aq = tR + s,  with 0 <= s < R
+        // ```
+        let (_s, t) = a.widen_mul(self.quo).lo_hi();
+        // Then
+        // ```
+        // (ab - tn)R
+        // = abR - ntR
+        // = a(qn + r) - n(aq - s)
+        // = ar + ns
+        // ```
+        #[cfg(debug_assertions)]
+        {
+            assert!(t < a || (a == t && t.is_zero()));
+            let ab_tn = a.widen_mul(self.num) - t.widen_mul(self.div);
+            let ar_ns = a.widen_mul(self.rem) + _s.widen_mul(self.div);
+            assert!(ab_tn.hi().is_zero());
+            assert!(ar_ns.lo().is_zero());
+            assert_eq!(ab_tn.lo(), ar_ns.hi());
+        }
+        // Since `s < R` and `r < n`,
+        // ```
+        // 0 <= ns < nR
+        // 0 <= ar < an
+        // 0 <= (ab - tn) == (ar + ns)/R < n(1 + a/R)
+        // ```
+        // Since `a < R` and we check on construction that `n <= R/2`, the result
+        // is `0 <= ab - tn < R`, so it can be computed modulo `R`
+        // even though the intermediate terms generally wrap.
+        let ab = a.wrapping_mul(self.num);
+        let tn = t.wrapping_mul(self.div);
+        (t, ab.wrapping_sub(tn))
+    }
+
+    /// Constructs a new reducer with `b` set to `(ab * b) % n`
+    ///
+    /// Requires `r * ab == ra * b`, where `r = bR % n`.
+    #[inline(always)]
+    fn with_scaled_num_rem(&self, ab: U, ra: U) -> Self {
+        debug_assert_eq!(ab.widen_mul(self.rem), ra.widen_mul(self.num));
+        // The new factor `v = abb mod n`:
+        let (_, v) = self.mul_into_div_rem(ab);
+
+        // `rab = cn + d`, where `0 <= d < n`
+        let (c, d) = self.mul_into_div_rem(ra);
+
+        // We need `abbR = Xn + Y`:
+        // abbR
+        // = ab(qn + r)
+        // = abqn + rab
+        // = abqn + cn + d
+        // = (abq + c)n + d
+
+        Self {
+            num: v,
+            div: self.div,
+            quo: self.quo.wrapping_mul(ab).wrapping_add(c),
+            rem: d,
+        }
+    }
+
+    /// Computes the reducer with the factor `b` set to `(a * b * b) % n`
+    /// Requires that `a * (n - 1)` does not overflow.
+    #[allow(dead_code)]
+    #[inline]
+    pub fn squared_with_scale(&self, a: U) -> Self {
+        debug_assert!(a.widen_mul(self.div - U::ONE).hi().is_zero());
+        self.with_scaled_num_rem(a * self.num, a * self.rem)
+    }
+
+    /// Computes the reducer with the factor `b` set to `(b * b << s) % n`
+    /// Requires that `(n - 1) << s` does not overflow.
+    #[inline]
+    pub fn squared_with_shift(&self, s: u32) -> Self {
+        debug_assert!((self.div - U::ONE).leading_zeros() >= s);
+        self.with_scaled_num_rem(self.num << s, self.rem << s)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::Reducer;
+
+    #[test]
+    fn u8_all() {
+        for y in 1..=128_u8 {
+            for r in 0..y {
+                let m = Reducer::new(r, y);
+                assert_eq!(m.quo, ((r as f32 * 256.0) / (y as f32)) as u8);
+                for x in 0..=u8::MAX {
+                    let (quo, rem) = m.mul_into_div_rem(x);
+
+                    let q0 = x as u32 * r as u32 / y as u32;
+                    let r0 = x as u32 * r as u32 % y as u32;
+                    assert_eq!(
+                        (quo as u32, rem as u32),
+                        (q0, r0),
+                        "\n\
+                        {x} * {r} = {xr}\n\
+                        expected: = {q0} * {y} + {r0}\n\
+                        returned: = {quo} * {y} + {rem} (== {})\n",
+                        quo as u32 * y as u32 + rem as u32,
+                        xr = x as u32 * r as u32,
+                    );
+                }
+                for s in 0..=y.leading_zeros() {
+                    assert_eq!(
+                        m.squared_with_shift(s),
+                        Reducer::new(((r << s) as u32 * r as u32 % y as u32) as u8, y)
+                    );
+                }
+                for a in 0..=u8::MAX {
+                    if a.checked_mul(y).is_some() {
+                        let abb = a as u32 * r as u32 * r as u32;
+                        assert_eq!(
+                            m.squared_with_scale(a),
+                            Reducer::new((abb % y as u32) as u8, y)
+                        );
+                    } else {
+                        break;
+                    }
+                }
+                for x0 in 0..=u8::MAX {
+                    if m.num == 0 || x0 as u32 * m.rem as u32 % m.num as u32 != 0 {
+                        continue;
+                    }
+                    let y0 = x0 as u32 * m.rem as u32 / m.num as u32;
+                    let Ok(y0) = u8::try_from(y0) else { continue };
+
+                    assert_eq!(
+                        m.with_scaled_num_rem(x0, y0),
+                        Reducer::new((x0 as u32 * m.num as u32 % y as u32) as u8, y)
+                    );
+                }
+            }
+        }
+    }
+}
diff --git a/libm/src/math/support/mod.rs b/libm/src/math/support/mod.rs
index ee3f2bbdf..330217a7a 100644
--- a/libm/src/math/support/mod.rs
+++ b/libm/src/math/support/mod.rs
@@ -20,6 +20,7 @@ pub use hex_float::hf16;
 pub use hex_float::hf128;
 #[allow(unused_imports)]
 pub use hex_float::{Hexf, hf32, hf64};
+pub(crate) use int_traits::Reducer;
 pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt};
 
 /// Hint to the compiler that the current path is cold.

From c811592b46797aae0769852d93fb1f9a90a8c9cb Mon Sep 17 00:00:00 2001
From: quaternic <57393910+quaternic@users.noreply.github.com>
Date: Thu, 1 May 2025 18:07:48 +0300
Subject: [PATCH 2/2] Optimize performance of fmod with Barrett multiplication

---
 libm/src/math/generic/fmod.rs               | 93 +++++++++++++++++++++
 libm/src/math/support/int_traits/mod_mul.rs | 13 ++-
 2 files changed, 98 insertions(+), 8 deletions(-)

diff --git a/libm/src/math/generic/fmod.rs b/libm/src/math/generic/fmod.rs
index e9898012f..183250c03 100644
--- a/libm/src/math/generic/fmod.rs
+++ b/libm/src/math/generic/fmod.rs
@@ -1,5 +1,6 @@
 /* SPDX-License-Identifier: MIT OR Apache-2.0 */
 use super::super::{CastFrom, Float, Int, MinInt};
+use crate::support::{DInt, HInt, Reducer};
 
 #[inline]
 pub fn fmod<F: Float>(x: F, y: F) -> F {
@@ -59,6 +60,33 @@ fn into_sig_exp<F: Float>(mut bits: F::Int) -> (F::Int, u32) {
 
 /// Compute the remainder `(x * 2.pow(e)) % y` without overflow.
 fn reduction<I: Int>(mut x: I, e: u32, y: I) -> I {
+    // FIXME: This is a temporary hack to get around the lack of `u256 / u256`.
+    // Actually, the algorithm only needs the operation `(x << I::BITS) / y`
+    // where `x < y`. That is, a division `u256 / u128` where the quotient must
+    // not overflow `u128` would be sufficient for `f128`.
+    unsafe {
+        use core::mem::transmute_copy;
+        if I::BITS == 64 {
+            let x = transmute_copy::<I, u64>(&x);
+            let y = transmute_copy::<I, u64>(&y);
+            let r = fast_reduction::<f64, u64>(x, e, y);
+            return transmute_copy::<u64, I>(&r);
+        }
+        if I::BITS == 32 {
+            let x = transmute_copy::<I, u32>(&x);
+            let y = transmute_copy::<I, u32>(&y);
+            let r = fast_reduction::<f32, u32>(x, e, y);
+            return transmute_copy::<u32, I>(&r);
+        }
+        #[cfg(f16_enabled)]
+        if I::BITS == 16 {
+            let x = transmute_copy::<I, u16>(&x);
+            let y = transmute_copy::<I, u16>(&y);
+            let r = fast_reduction::<f16, u16>(x, e, y);
+            return transmute_copy::<u16, I>(&r);
+        }
+    }
+
     x %= y;
     for _ in 0..e {
         x <<= 1;
@@ -66,3 +94,68 @@ fn reduction<I: Int>(mut x: I, e: u32, y: I) -> I {
     }
     x
 }
+
+trait SafeShift: Float {
+    // How many guaranteed leading zeros do the values have?
+    // A normalized floating point mantissa has `EXP_BITS` guaranteed leading
+    // zeros (exludes the implicit bit, but includes the now-zeroed sign bit)
+    // `-1` because we want to shift by either `BASE_SHIFT` or `BASE_SHIFT + 1`
+    const BASE_SHIFT: u32 = Self::EXP_BITS - 1;
+}
+impl<F: Float> SafeShift for F {}
+
+fn fast_reduction<F, I>(x: I, e: u32, y: I) -> I
+where
+    F: Float<Int = I>,
+    I: Int + HInt,
+    I::D: Int + DInt<H = I>,
+{
+    let _0 = I::ZERO;
+    let _1 = I::ONE;
+
+    if y == _1 {
+        return _0;
+    }
+
+    if e <= F::BASE_SHIFT {
+        return (x << e) % y;
+    }
+
+    // Find least depth s.t. `(e >> depth) < I::BITS`
+    let depth = (I::BITS - 1)
+        .leading_zeros()
+        .saturating_sub(e.leading_zeros());
+
+    let initial = (e >> depth) - F::BASE_SHIFT;
+
+    let max_rem = y.wrapping_sub(_1);
+    let max_ilog2 = max_rem.ilog2();
+    let mut pow2 = _1 << max_ilog2.min(initial);
+    for _ in max_ilog2..initial {
+        pow2 <<= 1;
+        pow2 = pow2.checked_sub(y).unwrap_or(pow2);
+    }
+
+    // At each step `k in [depth, ..., 0]`,
+    // `p` is `(e >> k) - BASE_SHIFT`
+    // `m` is `(1 << p) % y`
+    let mut k = depth;
+    let mut p = initial;
+    let mut m = Reducer::new(pow2, y);
+
+    while k > 0 {
+        k -= 1;
+        p = p + p + F::BASE_SHIFT;
+        if e & (1 << k) != 0 {
+            m = m.squared_with_shift(F::BASE_SHIFT + 1);
+            p += 1;
+        } else {
+            m = m.squared_with_shift(F::BASE_SHIFT);
+        };
+
+        debug_assert!(p == (e >> k) - F::BASE_SHIFT);
+    }
+
+    // (x << BASE_SHIFT) * (1 << p) == x << e
+    m.mul_into_div_rem(x << F::BASE_SHIFT).1
+}
diff --git a/libm/src/math/support/int_traits/mod_mul.rs b/libm/src/math/support/int_traits/mod_mul.rs
index af8f4a398..c770a122b 100644
--- a/libm/src/math/support/int_traits/mod_mul.rs
+++ b/libm/src/math/support/int_traits/mod_mul.rs
@@ -2,14 +2,11 @@ use super::{DInt, HInt, Int};
 
 /// Barrett reduction using the constant `R == (1 << K) == (1 << U::BITS)`
 ///
-/// More specifically, implements single-word [Barrett multiplication]
-/// (https://en.wikipedia.org/wiki/Barrett_reduction#Single-word_Barrett_multiplication)
-/// and [division]
-/// (https://en.wikipedia.org/wiki/Barrett_reduction#Barrett_Division)
-/// for unsigned integers.
+/// For a more detailed description, see
+/// <https://en.wikipedia.org/wiki/Barrett_reduction>.
 ///
 /// After constructing as `Reducer::new(b, n)`,
-/// provides operations to efficiently compute
+/// has operations to efficiently compute
 ///  - `(a * b) / n` and `(a * b) % n`
 ///  - `Reducer::new((a * b * b) % n, n)`, as long as `a * (n - 1) < R`
 #[derive(Clone, Copy, PartialEq, Eq, Debug)]
@@ -103,7 +100,7 @@ where
             let ar_ns = a.widen_mul(self.rem) + _s.widen_mul(self.div);
             assert!(ab_tn.hi().is_zero());
             assert!(ar_ns.lo().is_zero());
-            assert_eq!(ab_tn.lo(), ar_ns.hi());
+            assert!(ab_tn.lo() == ar_ns.hi());
         }
         // Since `s < R` and `r < n`,
         // ```
@@ -124,7 +121,7 @@ where
     /// Requires `r * ab == ra * b`, where `r = bR % n`.
     #[inline(always)]
     fn with_scaled_num_rem(&self, ab: U, ra: U) -> Self {
-        debug_assert_eq!(ab.widen_mul(self.rem), ra.widen_mul(self.num));
+        debug_assert!(ab.widen_mul(self.rem) == ra.widen_mul(self.num));
         // The new factor `v = abb mod n`:
         let (_, v) = self.mul_into_div_rem(ab);