Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenfeizhang committed Jun 18, 2024
1 parent ba69d14 commit 8636b2c
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 129 deletions.
235 changes: 109 additions & 126 deletions bi-kzg/src/bi_fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ fn bitreverse(mut n: usize, l: usize) -> usize {

/// Swaps the contents of chunk `k` with the contents of chunk `rk`, element by element.
///
/// Only the data the two inner slices point at is exchanged; the outer slice of
/// references is left untouched. Calling this with `rk == k` is a no-op.
///
/// # Panics
///
/// Panics if either index is out of bounds for `a`, or if the two chunks have
/// different lengths.
#[inline]
fn deep_swap_chunks<F: Clone + Copy>(a: &mut [&mut [F]], rk: usize, k: usize) {
    if rk == k {
        return;
    }
    let (lo, hi) = if rk < k { (rk, k) } else { (k, rk) };
    // Splitting at `hi` yields two disjoint mutable views, so both chunks can be
    // borrowed mutably at once without cloning and without `unsafe`.
    let (head, tail) = a.split_at_mut(hi);
    head[lo].swap_with_slice(&mut *tail[0]);
}

/// Copies the elements of `b` into `a`, position by position.
///
/// Only the overlapping prefix is written: iteration stops at the shorter of
/// the two slices.
#[inline]
fn assign_vec<F: Field>(a: &mut [F], b: &[F]) {
    for (dst, src) in a.iter_mut().zip(b) {
        *dst = *src;
    }
}
Expand Down Expand Up @@ -58,9 +61,9 @@ fn mul_assign_vec<F: Field>(a: &mut [F], b: &F) {
/// by $n$.
///
/// This will use multithreading if beneficial.
pub fn best_fft_vec<F: PrimeField>(a: &mut [F], omega: F, log_n: u32, log_m: u32) {
pub fn best_fft_vec_in_place<F: PrimeField>(a: &mut [F], omega: F, log_n: u32, log_m: u32) {
let threads = rayon::current_num_threads();
let _log_threads = threads.ilog2();
let log_threads = threads.ilog2();
let mn = a.len();
let m = 1 << log_m;
let n = 1 << log_n;
Expand All @@ -86,113 +89,105 @@ pub fn best_fft_vec<F: PrimeField>(a: &mut [F], omega: F, log_n: u32, log_m: u32
})
.collect();

// if log_n <= log_threads {
let mut chunk = 2_usize;
let mut twiddle_chunk = m / 2;
for _ in 0..log_m {
a_vec_ptrs.chunks_mut(chunk).for_each(|coeffs| {
let (left, right) = coeffs.split_at_mut(chunk / 2);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0].to_vec();

// b[0] = a[0];
// a[0] += &t;
// b[0] -= &t;
assign_vec(b[0], a[0]);
add_assign_vec(a[0], &t);
sub_assign_vec(b[0], &t);

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = b.to_vec();

// t *= &twiddles[(i + 1) * twiddle_chunk];
// *b = *a;
// *a += &t;
// *b -= &t;

mul_assign_vec(&mut t, &twiddles[(i + 1) * twiddle_chunk]);
assign_vec(b, a);
add_assign_vec(a, &t);
sub_assign_vec(b, &t);
});
});
chunk *= 2;
twiddle_chunk /= 2;
if log_n <= log_threads {
let mut chunk = 2_usize;
let mut twiddle_chunk = m / 2;
for _ in 0..log_m {
a_vec_ptrs.chunks_mut(chunk).for_each(|coeffs| {
let (left, right) = coeffs.split_at_mut(chunk / 2);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0].to_vec();

// compute the following in vectors
// b[0] = a[0];
// a[0] += &t;
// b[0] -= &t;
assign_vec(b[0], a[0]);
add_assign_vec(a[0], &t);
sub_assign_vec(b[0], &t);

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = b.to_vec();

// compute the following in vectors
// t *= &twiddles[(i + 1) * twiddle_chunk];
// *b = *a;
// *a += &t;
// *b -= &t;
mul_assign_vec(&mut t, &twiddles[(i + 1) * twiddle_chunk]);
assign_vec(b, a);
add_assign_vec(a, &t);
sub_assign_vec(b, &t);
});
});
chunk *= 2;
twiddle_chunk /= 2;
}
} else {
recursive_butterfly_arithmetic(&mut a_vec_ptrs, m, n, 1, &twiddles)
}
// // } else {
// // recursive_butterfly_arithmetic(a, n, 1, &twiddles)
// // }
// println!("before is correct: {:?}", a_vec_ptrs);

// for k in 0..m {
// let rk = bitreverse(k, log_m as usize);

// if k < rk {
// println!("k: {}, rk: {}", k, rk);
// for a in a_vec_ptrs.iter().enumerate() {
// println!("{}: {:?}", a.0, a.1);
// }

// a_vec_ptrs.swap(rk, k);
// // swap_chunks(a_vec_ptrs[k], log_n);
// // swap_chunks(a_vec_ptrs[rk], log_n);

// for a in a_vec_ptrs.iter().enumerate() {
// println!("{}: {:?}", a.0, a.1);
// }

// println!();
// }
// }

// println!("after: {:?}", a_vec_ptrs);
}

// /// This perform recursive butterfly arithmetic
// pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
// a: &mut [G],
// n: usize,
// twiddle_chunk: usize,
// twiddles: &[Scalar],
// ) {
// if n == 2 {
// let t = a[1];
// a[1] = a[0];
// a[0] += &t;
// a[1] -= &t;
// } else {
// let (left, right) = a.split_at_mut(n / 2);
// rayon::join(
// || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
// || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
// );

// // case when twiddle factor is one
// let (a, left) = left.split_at_mut(1);
// let (b, right) = right.split_at_mut(1);
// let t = b[0];
// b[0] = a[0];
// a[0] += &t;
// b[0] -= &t;
/// Performs the recursive butterfly arithmetic over `m` rows, where every row
/// is itself a vector of field elements and each scalar operation is applied
/// to a whole row at a time.
///
/// * `a` - the rows to transform in place; the body indexes it as if it holds `m` rows.
/// * `m` - number of rows handled by this call; it is halved on every recursion level.
/// * `n` - forwarded unchanged to the recursive calls; not otherwise read here.
/// * `twiddle_chunk` - stride into `twiddles` for this level; doubled on every
///   recursion level.
/// * `twiddles` - twiddle factors, read at index `(i + 1) * twiddle_chunk`.
///
/// The two halves are processed through `rayon::join`.
pub fn recursive_butterfly_arithmetic<F: PrimeField>(
    a: &mut [&mut [F]],
    m: usize,
    n: usize,
    twiddle_chunk: usize,
    twiddles: &[F],
) {
    // Fewer than two rows means there is nothing to combine. Without this guard
    // `m / 2 == 0` would make the recursion below call itself forever.
    if m < 2 {
        return;
    }

    if m == 2 {
        let (lo, hi) = a.split_at_mut(1);
        let t = hi[0].to_vec();
        // compute the following in vectors
        // a[1] = a[0];
        // a[0] += &t;
        // a[1] -= &t;
        assign_vec(hi[0], lo[0]);
        add_assign_vec(lo[0], &t);
        sub_assign_vec(hi[0], &t);
    } else {
        let (left, right) = a.split_at_mut(m / 2);
        rayon::join(
            || recursive_butterfly_arithmetic(left, m / 2, n, twiddle_chunk * 2, twiddles),
            || recursive_butterfly_arithmetic(right, m / 2, n, twiddle_chunk * 2, twiddles),
        );

        // case when twiddle factor is one
        let (a, left) = left.split_at_mut(1);
        let (b, right) = right.split_at_mut(1);
        let t = b[0].to_vec();
        // compute the following in vectors
        // b[0] = a[0];
        // a[0] += &t;
        // b[0] -= &t;
        assign_vec(b[0], a[0]);
        add_assign_vec(a[0], &t);
        sub_assign_vec(b[0], &t);

        left.iter_mut()
            .zip(right.iter_mut())
            .enumerate()
            .for_each(|(i, (a, b))| {
                let mut t = b.to_vec();
                // compute the following in vectors
                // t *= &twiddles[(i + 1) * twiddle_chunk];
                // *b = *a;
                // *a += &t;
                // *b -= &t;
                mul_assign_vec(&mut t, &twiddles[(i + 1) * twiddle_chunk]);
                assign_vec(b, a);
                add_assign_vec(a, &t);
                sub_assign_vec(b, &t);
            });
    }
}

pub(crate) fn bi_fft_in_place<F: PrimeField>(coeffs: &mut [F], degree_n: usize, degree_m: usize) {
// roots of unity for supported_n and supported_m
Expand All @@ -216,9 +211,7 @@ pub(crate) fn bi_fft_in_place<F: PrimeField>(coeffs: &mut [F], degree_n: usize,
.chunks_exact_mut(degree_n)
.for_each(|chunk| best_fft(chunk, omega_0, log2(degree_n)));

// println!("before: {:?}", coeffs);
best_fft_vec(coeffs, omega_1, log2(degree_n), log2(degree_m));
// println!("after: {:?}", coeffs);
best_fft_vec_in_place(coeffs, omega_1, log2(degree_n), log2(degree_m));
}

#[cfg(test)]
Expand Down Expand Up @@ -260,27 +253,17 @@ mod tests {
let mut poly_lag2 = poly.coefficients.clone();
let poly_lag = poly.interpolate();
bi_fft_in_place(&mut poly_lag2, n, m);

// for (i, (a, b)) in poly_lag.iter().zip(poly_lag2.iter()).enumerate() {
// println!("{}: {:?} == {:?}", i, a, b);
// }

// println!("correct one {:?}", poly_lag);
// println!();
// println!("{:?}", poly_lag2);
// println!();
assert_eq!(poly_lag, poly_lag2);
}

let mut rng = test_rng();

for m in [2, 4, 8, 16].iter() {
for n in [2, 4, 8, 16].iter() {
for m in [2, 4, 8, 16, 32, 64].iter() {
for n in [2, 4, 8, 16, 32, 64].iter() {
let poly = BivariatePolynomial::<Fr>::random(&mut rng, *n, *m);
let mut poly_lag2 = poly.coefficients.clone();
let poly_lag = poly.interpolate();
let poly_lag = poly.evaluate_at_roots();
bi_fft_in_place(&mut poly_lag2, *n, *m);
// println!("m = {}, n = {}: {}", m, n, poly_lag == poly_lag2);
assert_eq!(poly_lag, poly_lag2);
}
}
Expand Down
40 changes: 37 additions & 3 deletions bi-kzg/src/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,43 @@ impl<F: PrimeField> BivariatePolynomial<F> {
f_x_b
}

///
// TODO: this is super slow. Implement FFT for bivariate polynomials.
/// Evaluates the polynomial at every pair of roots of unity, one `evaluate`
/// call per point. Same as `interpolate` but slower.
///
/// The output is ordered with the powers of `omega_1` as the outer index and
/// the powers of `omega_0` as the inner index.
pub fn evaluate_at_roots(&self) -> Vec<F> {
    let timer = start_timer!(|| format!(
        "Lagrange coefficients of degree {} {}",
        self.degree_0, self.degree_1
    ));

    // roots of unity for supported_n and supported_m, derived from the
    // field's root of unity of order 2^S
    let two_adic_order: u64 = 1 << F::S;
    let generator = F::ROOT_OF_UNITY;
    let omega_0 = generator.pow_vartime(&[two_adic_order / self.degree_0 as u64]);
    let omega_1 = generator.pow_vartime(&[two_adic_order / self.degree_1 as u64]);

    assert!(
        omega_0.pow_vartime(&[self.degree_0 as u64]) == F::ONE,
        "omega_0 is not root of unity for supported_n"
    );
    assert!(
        omega_1.pow_vartime(&[self.degree_1 as u64]) == F::ONE,
        "omega_1 is not root of unity for supported_m"
    );

    let xs = powers_of_field_elements(&omega_0, self.degree_0);
    let ys = powers_of_field_elements(&omega_1, self.degree_1);

    // outer loop over the second variable, inner loop over the first
    let res: Vec<F> = ys
        .iter()
        .flat_map(|y| xs.iter().map(move |x| self.evaluate(x, y)))
        .collect();

    end_timer!(timer);
    res
}

/// interpolate the polynomial over the roots via bi-variate FFT
pub fn interpolate(&self) -> Vec<F> {
let timer = start_timer!(|| format!(
"Lagrange coefficients of degree {} {}",
Expand All @@ -71,7 +106,6 @@ impl<F: PrimeField> BivariatePolynomial<F> {
bi_fft_in_place(&mut coeff, self.degree_0, self.degree_1);
end_timer!(timer);
coeff

}
}

Expand Down

0 comments on commit 8636b2c

Please sign in to comment.