Skip to content

Commit d24830e

Browse files
committed
Implement accelerated computation of (x << e) % y in unsigned integers
1 parent 0f55279 commit d24830e

File tree

3 files changed

+304
-1
lines changed

3 files changed

+304
-1
lines changed

libm/src/math/support/int_traits/narrowing_div.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256};
77
/// This is the inverse of widening multiplication:
88
/// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
99
/// - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`,
10-
#[allow(dead_code)]
1110
pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
1211
/// Computes `(self / n, self % n))`
1312
///

libm/src/math/support/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub(crate) mod feature_detect;
88
mod float_traits;
99
pub mod hex_float;
1010
mod int_traits;
11+
mod modular;
1112

1213
#[allow(unused_imports)]
1314
pub use big::{i256, u256};
@@ -30,6 +31,8 @@ pub use hex_float::hf128;
3031
pub use hex_float::{hf32, hf64};
3132
#[allow(unused_imports)]
3233
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt, NarrowingDiv};
34+
#[allow(unused_imports)]
35+
pub use modular::linear_mul_reduction;
3336

3437
/// Hint to the compiler that the current path is cold.
3538
pub fn cold_path() {

libm/src/math/support/modular.rs

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
/* SPDX-License-Identifier: MIT OR Apache-2.0 */
2+
3+
//! To keep the equations somewhat concise, the following conventions are used:
4+
//! - all integer operations are in the mathematical sense, without overflow
5+
//! - concatenation means multiplication: `2xq = 2 * x * q`
6+
//! - `R = (1 << U::BITS)` is the modulus of wrapping arithmetic in `U`
7+
8+
use crate::support::int_traits::NarrowingDiv;
9+
use crate::support::{DInt, HInt, Int};
10+
11+
/// Compute the remainder `(x << e) % y` with unbounded integers.
12+
/// Requires `x < 2y` and `y.leading_zeros() >= 2`
13+
#[allow(dead_code)]
14+
pub fn linear_mul_reduction<U>(x: U, mut e: u32, mut y: U) -> U
15+
where
16+
U: HInt + Int<Unsigned = U>,
17+
U::D: NarrowingDiv,
18+
{
19+
assert!(y <= U::MAX >> 2);
20+
assert!(x < (y << 1));
21+
let _0 = U::ZERO;
22+
let _1 = U::ONE;
23+
24+
// power of two divisors
25+
if (y & (y - _1)).is_zero() {
26+
if e < U::BITS {
27+
// shift and only keep low bits
28+
return (x << e) & (y - _1);
29+
} else {
30+
// would shift out all the bits
31+
return _0;
32+
}
33+
}
34+
35+
// Use the identity `(x << e) % y == ((x << (e + s)) % (y << s)) >> s`
36+
// to shift the divisor so it has exactly two leading zeros to satisfy
37+
// the precondition of `Reducer::new`
38+
let s = y.leading_zeros() - 2;
39+
e += s;
40+
y <<= s;
41+
42+
// `m: Reducer` keeps track of the remainder `x` in a form that makes it
43+
// very efficient to do `x <<= k` modulo `y` for integers `k < U::BITS`
44+
let mut m = Reducer::new(x, y);
45+
46+
// Use the faster special case with constant `k == U::BITS - 1` while we can
47+
while e >= U::BITS - 1 {
48+
m.word_reduce();
49+
e -= U::BITS - 1;
50+
}
51+
// Finish with the variable shift operation
52+
m.shift_reduce(e);
53+
54+
// The partial remainder is in `[0, 2y)` ...
55+
let r = m.partial_remainder();
56+
// ... so check and correct, and compensate for the earlier shift.
57+
r.checked_sub(y).unwrap_or(r) >> s
58+
}
59+
60+
/// Helper type for computing the reductions. The implementation has a number
61+
/// of seemingly weird choices, but everything is aimed at streamlining
62+
/// `Reducer::word_reduce` into its current form.
63+
///
64+
/// Implicitly contains:
65+
/// n in (R/8, R/4)
66+
/// x in [0, 2n)
67+
/// The value of `n` is fixed for a given `Reducer`,
68+
/// but the value of `x` is modified by the methods.
69+
#[derive(Debug, Clone, PartialEq, Eq)]
70+
struct Reducer<U: HInt> {
71+
// m = 2n
72+
m: U,
73+
// q = (RR/2) / m
74+
// r = (RR/2) % m
75+
// Then RR/2 = qm + r, where `0 <= r < m`
76+
// The value `q` is only needed during construction, so isn't saved.
77+
r: U,
78+
// The value `x` is implicitly stored as `2 * q * x`:
79+
_2xq: U::D,
80+
}
81+
82+
impl<U> Reducer<U>
83+
where
84+
U: HInt,
85+
U: Int<Unsigned = U>,
86+
{
87+
/// Construct a reducer for `(x << _) mod n`.
88+
///
89+
/// Requires `R/8 < n < R/4` and `x < 2n`.
90+
fn new(x: U, n: U) -> Self
91+
where
92+
U::D: NarrowingDiv,
93+
{
94+
let _1 = U::ONE;
95+
assert!(n > (_1 << (U::BITS - 3)));
96+
assert!(n < (_1 << (U::BITS - 2)));
97+
let m = n << 1;
98+
assert!(x < m);
99+
100+
// We need to compute the parameters
101+
// `q = (RR/2) / m`
102+
// `r = (RR/2) % m`
103+
104+
// Since `m` is in `(R/4, R/2)`, the quotient `q` is in `[R, 2R)`, and
105+
// it would overflow in `U` if computed directly. Instead, we compute
106+
// `f = q - R`, which is in `[0, R)`. To do so, we simply subtract `Rm`
107+
// from the dividend, which doesn't change the remainder:
108+
// `f = R(R/2 - m) / m`
109+
// `r = R(R/2 - m) % m`
110+
let dividend = ((_1 << (U::BITS - 1)) - m).widen_hi();
111+
let (f, r) = dividend.checked_narrowing_div_rem(m).unwrap();
112+
113+
// As `x < m`, `xq < qm <= RR/2`
114+
// Thus `2xq = 2xR + 2xf` does not overflow in `U::D`.
115+
let _2x = x + x;
116+
let _2xq = _2x.widen_hi() + _2x.widen_mul(f);
117+
Self { m, r, _2xq }
118+
}
119+
120+
/// Extract the current remainder `x` in the range `[0, 2n)`
121+
fn partial_remainder(&self) -> U {
122+
// `RR/2 = qm + r`, where `0 <= r < m`
123+
// `2xq = uR + v`, where `0 <= v < R`
124+
125+
// The goal is to extract the current value of `x` from the value `2xq`
126+
// that we actually have. A bit simplified, we could multiply it by `m`
127+
// to obtain `2xqm == 2x(RR/2 - r) == xRR - 2xr`, where `2xr < RR`.
128+
// We could just round that up to the next multiple of `RR` to get `x`,
129+
// but we can avoid having to multiply the full double-wide `2xq` by
130+
// making a couple of adjustments:
131+
132+
// First, let's only use the high half `u` for the product, and
133+
// include an additional error term due to the truncation:
134+
// `mu = xR - (2xr + mv)/R`
135+
136+
// Next, show bounds for the error term
137+
// `0 <= mv < mR` follows from `0 <= v < R`
138+
// `0 <= 2xr < mR` follows from `0 <= x < m < R/2` and `0 <= r < m`
139+
// Adding those together, we have:
140+
// `0 <= (mv + 2xr)/R < 2m`
141+
// Which also implies:
142+
// `0 < 2m - (mv + 2xr)/R <= 2m < R`
143+
144+
// For that reason, we can use `u + 2` as the factor to obtain
145+
// `m(u + 2) = xR + (2m - (mv + 2xr)/R)`
146+
// By the previous inequality, the second term fits neatly in the lower
147+
// half, so we get exactly `x` as the high half.
148+
let u = self._2xq.hi();
149+
let _2 = U::ONE + U::ONE;
150+
self.m.widen_mul(u + _2).hi()
151+
152+
// Additionally, we should ensure that `u + 2` cannot overflow:
153+
// Since `x < m` and `2qm <= RR`,
154+
// `2xq <= 2q(m-1) <= RR - 2q`
155+
// As we also have `q > R`,
156+
// `2xq < RR - 2R`
157+
// which is sufficient.
158+
}
159+
160+
/// Replace the remainder `x` with `(x << k) - un`,
161+
/// for a suitable quotient `u`, which is returned.
162+
///
163+
/// Requires that `k < U::BITS`.
164+
fn shift_reduce(&mut self, k: u32) -> U {
165+
assert!(k < U::BITS);
166+
167+
// First, split the shifted value:
168+
// `2xq << k = aRR/2 + b`, where `0 <= b < RR/2`
169+
let a = self._2xq.hi() >> (U::BITS - 1 - k);
170+
let (low, high) = (self._2xq << k).lo_hi();
171+
let b = U::D::from_lo_hi(low, high & (U::MAX >> 1));
172+
173+
// Then, subtract `2anq = aqm`:
174+
// ```
175+
// (2xq << k) - aqm
176+
// = aRR/2 + b - aqm
177+
// = a(RR/2 - qm) + b
178+
// = ar + b
179+
// ```
180+
self._2xq = a.widen_mul(self.r) + b;
181+
a
182+
183+
// Since `a` is at most the high half of `2xq`, we have
184+
// `a + 2 < R` (shown above, in `partial_remainder`)
185+
// Using that together with `b < RR/2` and `r < m < R/2`,
186+
// we get `(a + 2)r + b < RR`, so
187+
// `ar + b < RR - 2r = 2mq`
188+
// which shows that the new remainder still satisfies `x < m`.
189+
}
190+
191+
// NB: `word_reduce()` is just the special case `shift_reduce(U::BITS - 1)`
192+
// that optimizes especially well. The correspondence is that `a == u` and
193+
// `b == (v >> 1).widen_hi()`
194+
//
195+
/// Replace the remainder `x` with `x(R/2) - un`,
196+
/// for a suitable quotient `u`, which is returned.
197+
fn word_reduce(&mut self) -> U {
198+
// To do so, we replace `2xq = uR + v` with
199+
// ```
200+
// 2 * (x(R/2) - un) * q
201+
// = xqR - 2unq
202+
// = xqR - uqm
203+
// = uRR/2 + vR/2 - uRR/2 + ur
204+
// = ur + (v/2)R
205+
// ```
206+
let (v, u) = self._2xq.lo_hi();
207+
self._2xq = u.widen_mul(self.r) + U::widen_hi(v >> 1);
208+
u
209+
210+
// Additional notes:
211+
// 1. As `v` is the low bits of `2xq`, it is even and can be halved.
212+
// 2. The new remainder is `(xr + mv/2) / R` (see below)
213+
// and since `v < R`, `r < m`, `x < m < R/2`,
214+
// that is also strictly less than `m`.
215+
// ```
216+
// (x(R/2) - un)R
217+
// = xRR/2 - (m/2)uR
218+
// = x(qm + r) - (m/2)(2xq - v)
219+
// = xqm + xr - xqm + mv/2
220+
// = xr + mv/2
221+
// ```
222+
}
223+
}
224+
225+
#[cfg(test)]
226+
mod test {
227+
use crate::support::linear_mul_reduction;
228+
use crate::support::modular::Reducer;
229+
230+
#[test]
231+
fn reducer_ops() {
232+
for n in 33..=63_u8 {
233+
for x in 0..2 * n {
234+
let temp = Reducer::new(x, n);
235+
let n = n as u32;
236+
let x0 = temp.partial_remainder() as u32;
237+
assert_eq!(x as u32, x0);
238+
for k in 0..=7 {
239+
let mut red = temp.clone();
240+
let u = red.shift_reduce(k) as u32;
241+
let x1 = red.partial_remainder() as u32;
242+
assert_eq!(x1, (x0 << k) - u * n);
243+
assert!(x1 < 2 * n);
244+
assert!((red._2xq as u32).is_multiple_of(2 * x1));
245+
246+
// `word_reduce` is equivalent to
247+
// `shift_reduce(U::BITS - 1)`
248+
if k == 7 {
249+
let mut alt = temp.clone();
250+
let w = alt.word_reduce();
251+
assert_eq!(u, w as u32);
252+
assert_eq!(alt, red);
253+
}
254+
}
255+
}
256+
}
257+
}
258+
#[test]
259+
fn reduction_u8() {
260+
for y in 1..64u8 {
261+
for x in 0..2 * y {
262+
let mut r = x % y;
263+
for e in 0..100 {
264+
assert_eq!(r, linear_mul_reduction(x, e, y));
265+
// maintain the correct expected remainder
266+
r <<= 1;
267+
if r >= y {
268+
r -= y;
269+
}
270+
}
271+
}
272+
}
273+
}
274+
#[test]
275+
fn reduction_u128() {
276+
assert_eq!(
277+
linear_mul_reduction::<u128>(17, 100, 123456789),
278+
(17 << 100) % 123456789
279+
);
280+
281+
// power-of-two divisor
282+
assert_eq!(
283+
linear_mul_reduction(0xdead_beef, 100, 1_u128 << 116),
284+
0xbeef << 100
285+
);
286+
287+
let x = 10_u128.pow(37);
288+
let y = 11_u128.pow(36);
289+
assert!(x < y);
290+
let mut r = x;
291+
for e in 0..1000 {
292+
assert_eq!(r, linear_mul_reduction(x, e, y));
293+
// maintain the correct expected remainder
294+
r <<= 1;
295+
if r >= y {
296+
r -= y;
297+
}
298+
assert!(r != 0);
299+
}
300+
}
301+
}

0 commit comments

Comments
 (0)