| 
 | 1 | +/* SPDX-License-Identifier: MIT OR Apache-2.0 */  | 
 | 2 | + | 
 | 3 | +//! To keep the equations somewhat concise, the following conventions are used:  | 
 | 4 | +//!  - all integer operations are in the mathematical sense, without overflow  | 
 | 5 | +//!  - concatenation means multiplication: `2xq = 2 * x * q`  | 
 | 6 | +//!  - `R = (1 << U::BITS)` is the modulus of wrapping arithmetic in `U`  | 
 | 7 | +
  | 
 | 8 | +use crate::support::int_traits::NarrowingDiv;  | 
 | 9 | +use crate::support::{DInt, HInt, Int};  | 
 | 10 | + | 
 | 11 | +/// Compute the remainder `(x << e) % y` with unbounded integers.  | 
 | 12 | +/// Requires `x < 2y` and `y.leading_zeros() >= 2`  | 
 | 13 | +#[allow(dead_code)]  | 
 | 14 | +pub fn linear_mul_reduction<U>(x: U, mut e: u32, mut y: U) -> U  | 
 | 15 | +where  | 
 | 16 | +    U: HInt + Int<Unsigned = U>,  | 
 | 17 | +    U::D: NarrowingDiv,  | 
 | 18 | +{  | 
 | 19 | +    assert!(y <= U::MAX >> 2);  | 
 | 20 | +    assert!(x < (y << 1));  | 
 | 21 | +    let _0 = U::ZERO;  | 
 | 22 | +    let _1 = U::ONE;  | 
 | 23 | + | 
 | 24 | +    // power of two divisors  | 
 | 25 | +    if (y & (y - _1)).is_zero() {  | 
 | 26 | +        if e < U::BITS {  | 
 | 27 | +            // shift and only keep low bits  | 
 | 28 | +            return (x << e) & (y - _1);  | 
 | 29 | +        } else {  | 
 | 30 | +            // would shift out all the bits  | 
 | 31 | +            return _0;  | 
 | 32 | +        }  | 
 | 33 | +    }  | 
 | 34 | + | 
 | 35 | +    // Use the identity `(x << e) % y == ((x << (e + s)) % (y << s)) >> s`  | 
 | 36 | +    // to shift the divisor so it has exactly two leading zeros to satisfy  | 
 | 37 | +    // the precondition of `Reducer::new`  | 
 | 38 | +    let s = y.leading_zeros() - 2;  | 
 | 39 | +    e += s;  | 
 | 40 | +    y <<= s;  | 
 | 41 | + | 
 | 42 | +    // `m: Reducer` keeps track of the remainder `x` in a form that makes it  | 
 | 43 | +    //  very efficient to do `x <<= k` modulo `y` for integers `k < U::BITS`  | 
 | 44 | +    let mut m = Reducer::new(x, y);  | 
 | 45 | + | 
 | 46 | +    // Use the faster special case with constant `k == U::BITS - 1` while we can  | 
 | 47 | +    while e >= U::BITS - 1 {  | 
 | 48 | +        m.word_reduce();  | 
 | 49 | +        e -= U::BITS - 1;  | 
 | 50 | +    }  | 
 | 51 | +    // Finish with the variable shift operation  | 
 | 52 | +    m.shift_reduce(e);  | 
 | 53 | + | 
 | 54 | +    // The partial remainder is in `[0, 2y)` ...  | 
 | 55 | +    let r = m.partial_remainder();  | 
 | 56 | +    // ... so check and correct, and compensate for the earlier shift.  | 
 | 57 | +    r.checked_sub(y).unwrap_or(r) >> s  | 
 | 58 | +}  | 
 | 59 | + | 
 | 60 | +/// Helper type for computing the reductions. The implementation has a number  | 
 | 61 | +/// of seemingly weird choices, but everything is aimed at streamlining  | 
 | 62 | +/// `Reducer::word_reduce` into its current form.  | 
 | 63 | +///  | 
 | 64 | +/// Implicitly contains:  | 
 | 65 | +///  n in (R/8, R/4)  | 
 | 66 | +///  x in [0, 2n)  | 
 | 67 | +/// The value of `n` is fixed for a given `Reducer`,  | 
 | 68 | +/// but the value of `x` is modified by the methods.  | 
 | 69 | +#[derive(Debug, Clone, PartialEq, Eq)]  | 
 | 70 | +struct Reducer<U: HInt> {  | 
 | 71 | +    // m = 2n  | 
 | 72 | +    m: U,  | 
 | 73 | +    // q = (RR/2) / m  | 
 | 74 | +    // r = (RR/2) % m  | 
 | 75 | +    // Then RR/2 = qm + r, where `0 <= r < m`  | 
 | 76 | +    // The value `q` is only needed during construction, so isn't saved.  | 
 | 77 | +    r: U,  | 
 | 78 | +    // The value `x` is implicitly stored as `2 * q * x`:  | 
 | 79 | +    _2xq: U::D,  | 
 | 80 | +}  | 
 | 81 | + | 
 | 82 | +impl<U> Reducer<U>  | 
 | 83 | +where  | 
 | 84 | +    U: HInt,  | 
 | 85 | +    U: Int<Unsigned = U>,  | 
 | 86 | +{  | 
 | 87 | +    /// Construct a reducer for `(x << _) mod n`.  | 
 | 88 | +    ///  | 
 | 89 | +    /// Requires `R/8 < n < R/4` and `x < 2n`.  | 
 | 90 | +    fn new(x: U, n: U) -> Self  | 
 | 91 | +    where  | 
 | 92 | +        U::D: NarrowingDiv,  | 
 | 93 | +    {  | 
 | 94 | +        let _1 = U::ONE;  | 
 | 95 | +        assert!(n > (_1 << (U::BITS - 3)));  | 
 | 96 | +        assert!(n < (_1 << (U::BITS - 2)));  | 
 | 97 | +        let m = n << 1;  | 
 | 98 | +        assert!(x < m);  | 
 | 99 | + | 
 | 100 | +        // We need to compute the parameters  | 
 | 101 | +        // `q = (RR/2) / m`  | 
 | 102 | +        // `r = (RR/2) % m`  | 
 | 103 | + | 
 | 104 | +        // Since `m` is in `(R/4, R/2)`, the quotient `q` is in `[R, 2R)`, and  | 
 | 105 | +        // it would overflow in `U` if computed directly. Instead, we compute  | 
 | 106 | +        // `f = q - R`, which is in `[0, R)`. To do so, we simply subtract `Rm`  | 
 | 107 | +        // from the dividend, which doesn't change the remainder:  | 
 | 108 | +        // `f = R(R/2 - m) / m`  | 
 | 109 | +        // `r = R(R/2 - m) % m`  | 
 | 110 | +        let dividend = ((_1 << (U::BITS - 1)) - m).widen_hi();  | 
 | 111 | +        let (f, r) = dividend.checked_narrowing_div_rem(m).unwrap();  | 
 | 112 | + | 
 | 113 | +        // As `x < m`, `xq < qm <= RR/2`  | 
 | 114 | +        // Thus `2xq = 2xR + 2xf` does not overflow in `U::D`.  | 
 | 115 | +        let _2x = x + x;  | 
 | 116 | +        let _2xq = _2x.widen_hi() + _2x.widen_mul(f);  | 
 | 117 | +        Self { m, r, _2xq }  | 
 | 118 | +    }  | 
 | 119 | + | 
 | 120 | +    /// Extract the current remainder `x` in the range `[0, 2n)`  | 
 | 121 | +    fn partial_remainder(&self) -> U {  | 
 | 122 | +        // `RR/2 = qm + r`, where `0 <= r < m`  | 
 | 123 | +        // `2xq = uR + v`,  where `0 <= v < R`  | 
 | 124 | + | 
 | 125 | +        // The goal is to extract the current value of `x` from the value `2xq`  | 
 | 126 | +        // that we actually have. A bit simplified, we could multiply it by `m`  | 
 | 127 | +        // to obtain `2xqm == 2x(RR/2 - r) == xRR - 2xr`, where `2xr < RR`.  | 
 | 128 | +        // We could just round that up to the next multiple of `RR` to get `x`,  | 
 | 129 | +        // but we can avoid having to multiply the full double-wide `2xq` by  | 
 | 130 | +        // making a couple of adjustments:  | 
 | 131 | + | 
 | 132 | +        // First, let's only use the high half `u` for the product, and  | 
 | 133 | +        // include an additional error term due to the truncation:  | 
 | 134 | +        //  `mu = xR - (2xr + mv)/R`  | 
 | 135 | + | 
 | 136 | +        // Next, show bounds for the error term  | 
 | 137 | +        //  `0 <= mv < mR` follows from `0 <= v < R`  | 
 | 138 | +        //  `0 <= 2xr < mR` follows from `0 <= x < m < R/2` and `0 <= r < m`  | 
 | 139 | +        // Adding those together, we have:  | 
 | 140 | +        //  `0 <= (mv + 2xr)/R < 2m`  | 
 | 141 | +        // Which also implies:  | 
 | 142 | +        //  `0 < 2m - (mv + 2xr)/R <= 2m < R`  | 
 | 143 | + | 
 | 144 | +        // For that reason, we can use `u + 2` as the factor to obtain  | 
 | 145 | +        //  `m(u + 2) = xR + (2m - (mv + 2xr)/R)`  | 
 | 146 | +        // By the previous inequality, the second term fits neatly in the lower  | 
 | 147 | +        // half, so we get exactly `x` as the high half.  | 
 | 148 | +        let u = self._2xq.hi();  | 
 | 149 | +        let _2 = U::ONE + U::ONE;  | 
 | 150 | +        self.m.widen_mul(u + _2).hi()  | 
 | 151 | + | 
 | 152 | +        // Additionally, we should ensure that `u + 2` cannot overflow:  | 
 | 153 | +        // Since `x < m` and `2qm <= RR`,  | 
 | 154 | +        //  `2xq <= 2q(m-1) <= RR - 2q`  | 
 | 155 | +        // As we also have `q > R`,  | 
 | 156 | +        //  `2xq < RR - 2R`  | 
 | 157 | +        // which is sufficient.  | 
 | 158 | +    }  | 
 | 159 | + | 
 | 160 | +    /// Replace the remainder `x` with `(x << k) - un`,  | 
 | 161 | +    /// for a suitable quotient `u`, which is returned.  | 
 | 162 | +    ///  | 
 | 163 | +    /// Requires that `k < U::BITS`.  | 
 | 164 | +    fn shift_reduce(&mut self, k: u32) -> U {  | 
 | 165 | +        assert!(k < U::BITS);  | 
 | 166 | + | 
 | 167 | +        // First, split the shifted value:  | 
 | 168 | +        // `2xq << k = aRR/2 + b`, where `0 <= b < RR/2`  | 
 | 169 | +        let a = self._2xq.hi() >> (U::BITS - 1 - k);  | 
 | 170 | +        let (low, high) = (self._2xq << k).lo_hi();  | 
 | 171 | +        let b = U::D::from_lo_hi(low, high & (U::MAX >> 1));  | 
 | 172 | + | 
 | 173 | +        // Then, subtract `2anq = aqm`:  | 
 | 174 | +        // ```  | 
 | 175 | +        // (2xq << k) - aqm  | 
 | 176 | +        // = aRR/2 + b - aqm  | 
 | 177 | +        // = a(RR/2 - qm) + b  | 
 | 178 | +        // = ar + b  | 
 | 179 | +        // ```  | 
 | 180 | +        self._2xq = a.widen_mul(self.r) + b;  | 
 | 181 | +        a  | 
 | 182 | + | 
 | 183 | +        // Since `a` is at most the high half of `2xq`, we have  | 
 | 184 | +        //  `a + 2 < R` (shown above, in `partial_remainder`)  | 
 | 185 | +        // Using that together with `b < RR/2` and `r < m < R/2`,  | 
 | 186 | +        // we get `(a + 2)r + b < RR`, so  | 
 | 187 | +        //  `ar + b < RR - 2r = 2mq`  | 
 | 188 | +        // which shows that the new remainder still satisfies `x < m`.  | 
 | 189 | +    }  | 
 | 190 | + | 
 | 191 | +    // NB: `word_reduce()` is just the special case `shift_reduce(U::BITS - 1)`  | 
 | 192 | +    // that optimizes especially well. The correspondence is that `a == u` and  | 
 | 193 | +    //  `b == (v >> 1).widen_hi()`  | 
 | 194 | +    //  | 
 | 195 | +    /// Replace the remainder `x` with `x(R/2) - un`,  | 
 | 196 | +    /// for a suitable quotient `u`, which is returned.  | 
 | 197 | +    fn word_reduce(&mut self) -> U {  | 
 | 198 | +        // To do so, we replace `2xq = uR + v` with  | 
 | 199 | +        // ```  | 
 | 200 | +        // 2 * (x(R/2) - un) * q  | 
 | 201 | +        // = xqR - 2unq  | 
 | 202 | +        // = xqR - uqm  | 
 | 203 | +        // = uRR/2 + vR/2 - uRR/2 + ur  | 
 | 204 | +        // = ur + (v/2)R  | 
 | 205 | +        // ```  | 
 | 206 | +        let (v, u) = self._2xq.lo_hi();  | 
 | 207 | +        self._2xq = u.widen_mul(self.r) + U::widen_hi(v >> 1);  | 
 | 208 | +        u  | 
 | 209 | + | 
 | 210 | +        // Additional notes:  | 
 | 211 | +        //  1. As `v` is the low bits of `2xq`, it is even and can be halved.  | 
 | 212 | +        //  2. The new remainder is `(xr + mv/2) / R` (see below)  | 
 | 213 | +        //      and since `v < R`, `r < m`, `x < m < R/2`,  | 
 | 214 | +        //      that is also strictly less than `m`.  | 
 | 215 | +        // ```  | 
 | 216 | +        // (x(R/2) - un)R  | 
 | 217 | +        //      = xRR/2 - (m/2)uR  | 
 | 218 | +        //      = x(qm + r) - (m/2)(2xq - v)  | 
 | 219 | +        //      = xqm + xr - xqm + mv/2  | 
 | 220 | +        //      = xr + mv/2  | 
 | 221 | +        // ```  | 
 | 222 | +    }  | 
 | 223 | +}  | 
 | 224 | + | 
 | 225 | +#[cfg(test)]  | 
 | 226 | +mod test {  | 
 | 227 | +    use crate::support::linear_mul_reduction;  | 
 | 228 | +    use crate::support::modular::Reducer;  | 
 | 229 | + | 
 | 230 | +    #[test]  | 
 | 231 | +    fn reducer_ops() {  | 
 | 232 | +        for n in 33..=63_u8 {  | 
 | 233 | +            for x in 0..2 * n {  | 
 | 234 | +                let temp = Reducer::new(x, n);  | 
 | 235 | +                let n = n as u32;  | 
 | 236 | +                let x0 = temp.partial_remainder() as u32;  | 
 | 237 | +                assert_eq!(x as u32, x0);  | 
 | 238 | +                for k in 0..=7 {  | 
 | 239 | +                    let mut red = temp.clone();  | 
 | 240 | +                    let u = red.shift_reduce(k) as u32;  | 
 | 241 | +                    let x1 = red.partial_remainder() as u32;  | 
 | 242 | +                    assert_eq!(x1, (x0 << k) - u * n);  | 
 | 243 | +                    assert!(x1 < 2 * n);  | 
 | 244 | +                    assert!((red._2xq as u32).is_multiple_of(2 * x1));  | 
 | 245 | + | 
 | 246 | +                    // `word_reduce` is equivalent to  | 
 | 247 | +                    // `shift_reduce(U::BITS - 1)`  | 
 | 248 | +                    if k == 7 {  | 
 | 249 | +                        let mut alt = temp.clone();  | 
 | 250 | +                        let w = alt.word_reduce();  | 
 | 251 | +                        assert_eq!(u, w as u32);  | 
 | 252 | +                        assert_eq!(alt, red);  | 
 | 253 | +                    }  | 
 | 254 | +                }  | 
 | 255 | +            }  | 
 | 256 | +        }  | 
 | 257 | +    }  | 
 | 258 | +    #[test]  | 
 | 259 | +    fn reduction_u8() {  | 
 | 260 | +        for y in 1..64u8 {  | 
 | 261 | +            for x in 0..2 * y {  | 
 | 262 | +                let mut r = x % y;  | 
 | 263 | +                for e in 0..100 {  | 
 | 264 | +                    assert_eq!(r, linear_mul_reduction(x, e, y));  | 
 | 265 | +                    // maintain the correct expected remainder  | 
 | 266 | +                    r <<= 1;  | 
 | 267 | +                    if r >= y {  | 
 | 268 | +                        r -= y;  | 
 | 269 | +                    }  | 
 | 270 | +                }  | 
 | 271 | +            }  | 
 | 272 | +        }  | 
 | 273 | +    }  | 
 | 274 | +    #[test]  | 
 | 275 | +    fn reduction_u128() {  | 
 | 276 | +        assert_eq!(  | 
 | 277 | +            linear_mul_reduction::<u128>(17, 100, 123456789),  | 
 | 278 | +            (17 << 100) % 123456789  | 
 | 279 | +        );  | 
 | 280 | + | 
 | 281 | +        // power-of-two divisor  | 
 | 282 | +        assert_eq!(  | 
 | 283 | +            linear_mul_reduction(0xdead_beef, 100, 1_u128 << 116),  | 
 | 284 | +            0xbeef << 100  | 
 | 285 | +        );  | 
 | 286 | + | 
 | 287 | +        let x = 10_u128.pow(37);  | 
 | 288 | +        let y = 11_u128.pow(36);  | 
 | 289 | +        assert!(x < y);  | 
 | 290 | +        let mut r = x;  | 
 | 291 | +        for e in 0..1000 {  | 
 | 292 | +            assert_eq!(r, linear_mul_reduction(x, e, y));  | 
 | 293 | +            // maintain the correct expected remainder  | 
 | 294 | +            r <<= 1;  | 
 | 295 | +            if r >= y {  | 
 | 296 | +                r -= y;  | 
 | 297 | +            }  | 
 | 298 | +            assert!(r != 0);  | 
 | 299 | +        }  | 
 | 300 | +    }  | 
 | 301 | +}  | 
0 commit comments