From 661ec8e7523b5c894f42d13041799b2d810120be Mon Sep 17 00:00:00 2001
From: Dumi Loghin
Date: Tue, 5 Mar 2024 14:17:09 +0800
Subject: [PATCH] add AVX implementation for Poseidon2 hash

---
 plonky2/src/hash/arch/x86_64/mod.rs                |   2 +
 .../arch/x86_64/poseidon2_goldilocks_avx2.rs       | 244 ++++++++++++++++++
 plonky2/src/hash/poseidon2.rs                      |  84 +++---
 3 files changed, 296 insertions(+), 34 deletions(-)
 create mode 100644 plonky2/src/hash/arch/x86_64/poseidon2_goldilocks_avx2.rs

diff --git a/plonky2/src/hash/arch/x86_64/mod.rs b/plonky2/src/hash/arch/x86_64/mod.rs
index 0730b62614..20a6860f29 100644
--- a/plonky2/src/hash/arch/x86_64/mod.rs
+++ b/plonky2/src/hash/arch/x86_64/mod.rs
@@ -3,3 +3,5 @@
 // // - BMI2 (for MULX and SHRX)
 // #[cfg(all(target_feature = "avx2", target_feature = "bmi2"))]
 // pub(crate) mod poseidon_goldilocks_avx2_bmi2;
+#[cfg(target_feature = "avx2")]
+pub mod poseidon2_goldilocks_avx2;
\ No newline at end of file
diff --git a/plonky2/src/hash/arch/x86_64/poseidon2_goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/poseidon2_goldilocks_avx2.rs
new file mode 100644
index 0000000000..82595773af
--- /dev/null
+++ b/plonky2/src/hash/arch/x86_64/poseidon2_goldilocks_avx2.rs
@@ -0,0 +1,244 @@
+//! AVX2 implementation of Goldilocks field arithmetic for the Poseidon2 hash.
+//! Code taken and adapted from:
+//! https://github.com/0xPolygonHermez/goldilocks/blob/master/src/goldilocks_base_field_avx.hpp
+//!
+//! The width-12 state is processed as three __m256i vectors of four 64-bit
+//! lanes each; these routines rely on field elements being stored in memory
+//! as a single u64.
+use core::arch::x86_64::*;
+
+use crate::hash::hash_types::RichField;
+use crate::hash::poseidon2::{apply_m_4, SPONGE_WIDTH};
+
+const MSB_: i64 = 0x8000000000000000u64 as i64;
+const P_: i64 = 0xFFFFFFFF00000001u64 as i64;
+const P_s_: i64 = 0x7FFFFFFF00000001u64 as i64;
+const P_n_: i64 = 0xFFFFFFFF;
+
+/// Move `a` into the "shifted" domain by XOR-ing each lane with 2^63, so that
+/// unsigned comparisons can be done with AVX2's signed compare instructions.
+/// XOR with 2^63 is an involution, so this also maps shifted values back.
+#[inline]
+fn shift_avx(a_s: &mut __m256i, a: &__m256i) {
+    unsafe {
+        let msb = _mm256_set_epi64x(MSB_, MSB_, MSB_, MSB_);
+        *a_s = _mm256_xor_si256(*a, msb);
+    }
+}
+
+/// Canonicalize a shifted value: where a lane is >= p, subtract p (adding
+/// 2^32 - 1 is the same as subtracting p modulo 2^64).
+#[inline]
+fn to_canonical_avx_s(a_sc: &mut __m256i, a_s: &__m256i) {
+    unsafe {
+        let p_s = _mm256_set_epi64x(P_s_, P_s_, P_s_, P_s_);
+        let p_n = _mm256_set_epi64x(P_n_, P_n_, P_n_, P_n_);
+        let mask1_ = _mm256_cmpgt_epi64(p_s, *a_s);
+        let corr1_ = _mm256_andnot_si256(mask1_, p_n);
+        *a_sc = _mm256_add_epi64(*a_s, corr1_);
+    }
+}
+
+/// c = a + b mod p, where `a_sc` is already shifted and canonical. The sum is
+/// computed in the shifted domain and shifted back at the end.
+#[inline]
+fn add_avx_a_sc(c: &mut __m256i, a_sc: &__m256i, b: &__m256i) {
+    unsafe {
+        let c0_s = _mm256_add_epi64(*a_sc, *b);
+        let p_n = _mm256_set_epi64x(P_n_, P_n_, P_n_, P_n_);
+        // If the addition wrapped around 2^64, add 2^32 - 1 (= -p mod 2^64).
+        let mask_ = _mm256_cmpgt_epi64(*a_sc, c0_s);
+        let corr_ = _mm256_and_si256(mask_, p_n);
+        let c_s = _mm256_add_epi64(c0_s, corr_);
+        shift_avx(c, &c_s);
+    }
+}
+
+/// c = a + b mod p. The result may be non-canonical, but is always congruent
+/// to a + b modulo p.
+#[inline]
+fn add_avx(c: &mut __m256i, a: &__m256i, b: &__m256i) {
+    unsafe {
+        let mut a_s = _mm256_setzero_si256();
+        let mut a_sc = _mm256_setzero_si256();
+        shift_avx(&mut a_s, a);
+        to_canonical_avx_s(&mut a_sc, &a_s);
+        add_avx_a_sc(c, &a_sc, b);
+    }
+}
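+
+// Reference semantics, added here for exposition (not part of the original
+// Hermez code): add_avx computes plain Goldilocks addition, i.e. the scalar
+// function below, on four lanes at once. The shifted domain exists only
+// because AVX2 offers signed 64-bit comparisons but no unsigned ones.
+#[allow(dead_code)]
+fn add_scalar_reference(a: u64, b: u64) -> u64 {
+    const P: u128 = 0xFFFFFFFF00000001;
+    ((a as u128 + b as u128) % P) as u64
+}
+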
+/// c_s = a_s + b_small in the shifted domain. `b_small` is assumed to stay
+/// below 2^64 - 2^32, so that a carry always changes the high 32 bits.
+#[inline]
+fn add_avx_s_b_small(c_s: &mut __m256i, a_s: &__m256i, b_small: &__m256i) {
+    unsafe {
+        let c0_s = _mm256_add_epi64(*a_s, *b_small);
+        // Carry detection on the high 32-bit halves; in the shifted domain
+        // the signed compare plays the role of an unsigned one. Shifting the
+        // mask right by 32 yields the 2^32 - 1 correction in wrapped lanes.
+        let mask_ = _mm256_cmpgt_epi32(*a_s, c0_s);
+        let corr_ = _mm256_srli_epi64(mask_, 32);
+        *c_s = _mm256_add_epi64(c0_s, corr_);
+    }
+}
+
+/// c_s = a_s - b in the shifted domain, where `b` has zero high 32 bits; a
+/// borrow is folded back by subtracting 2^32 - 1 (= adding p mod 2^64).
+#[inline]
+fn sub_avx_s_b_small(c_s: &mut __m256i, a_s: &__m256i, b: &__m256i) {
+    unsafe {
+        let c0_s = _mm256_sub_epi64(*a_s, *b);
+        let mask_ = _mm256_cmpgt_epi32(c0_s, *a_s);
+        let corr_ = _mm256_srli_epi64(mask_, 32);
+        *c_s = _mm256_sub_epi64(c0_s, corr_);
+    }
+}
+
+/// Reduce a 128-bit product (c_h:c_l) to a 64-bit value congruent to it
+/// mod p, using 2^64 = 2^32 - 1 and 2^96 = -1 (mod p).
+#[inline]
+fn reduce_avx_128_64(c: &mut __m256i, c_h: &__m256i, c_l: &__m256i) {
+    unsafe {
+        let c_hh = _mm256_srli_epi64(*c_h, 32);
+        let mut c1_s = c_hh;
+        let mut c_ls = c_hh;
+        let mut c_s = c_hh;
+        shift_avx(&mut c_ls, c_l);
+        // c1 = c_l - c_hh (the -1 * c_hh term from 2^96 = -1 mod p).
+        sub_avx_s_b_small(&mut c1_s, &c_ls, &c_hh);
+        let p_n = _mm256_set_epi64x(P_n_, P_n_, P_n_, P_n_);
+        // c2 = c_hl * (2^32 - 1) (the 2^64 = 2^32 - 1 term; mul_epu32 uses
+        // the low 32 bits of each lane of c_h).
+        let c2 = _mm256_mul_epu32(*c_h, p_n);
+        add_avx_s_b_small(&mut c_s, &c1_s, &c2);
+        shift_avx(c, &c_s);
+    }
+}
+
+/// Full 64x64 -> 128-bit lane-wise product, built from 32-bit partial
+/// products (AVX2 has no 64x64-bit multiply).
+#[inline]
+fn mult_avx_128(c_h: &mut __m256i, c_l: &mut __m256i, a: &__m256i, b: &__m256i) {
+    unsafe {
+        // Duplicate the high 32-bit half of each lane into the low half.
+        let a_h = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(*a)));
+        let b_h = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(*b)));
+        let c_hh = _mm256_mul_epu32(a_h, b_h);
+        let c_hl = _mm256_mul_epu32(a_h, *b);
+        let c_lh = _mm256_mul_epu32(*a, b_h);
+        let c_ll = _mm256_mul_epu32(*a, *b);
+        // Combine the four partial products with 32-bit shifts and adds.
+        let c_ll_h = _mm256_srli_epi64(c_ll, 32);
+        let r0 = _mm256_add_epi64(c_hl, c_ll_h);
+        let p_n = _mm256_set_epi64x(P_n_, P_n_, P_n_, P_n_);
+        let r0_l = _mm256_and_si256(r0, p_n);
+        let r1 = _mm256_add_epi64(c_lh, r0_l);
+        let r1_l = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(r1)));
+        *c_l = _mm256_blend_epi32(c_ll, r1_l, 0xaa);
+        let r0_h = _mm256_srli_epi64(r0, 32);
+        let r2 = _mm256_add_epi64(c_hh, r0_h);
+        let r1_h = _mm256_srli_epi64(r1, 32);
+        *c_h = _mm256_add_epi64(r2, r1_h);
+    }
+}
+
+/// c = a * b mod p.
+#[inline]
+fn mult_avx(c: &mut __m256i, a: &__m256i, b: &__m256i) {
+    let mut c_h = *b;
+    let mut c_l = *b;
+    mult_avx_128(&mut c_h, &mut c_l, a, b);
+    reduce_avx_128_64(c, &c_h, &c_l);
+}
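+
+// Reference semantics, added for exposition: reduce_avx_128_64 computes, per
+// lane, the remainder below. Writing x = c_hh * 2^96 + c_hl * 2^64 + c_l and
+// using 2^64 = 2^32 - 1 (mod p) and 2^96 = -1 (mod p) gives
+//     x = c_l + c_hl * (2^32 - 1) - c_hh  (mod p),
+// which is exactly the c1 + c2 computation above.
+#[allow(dead_code)]
+fn reduce128_reference(c_h: u64, c_l: u64) -> u64 {
+    const P: u128 = 0xFFFFFFFF00000001;
+    let x = ((c_h as u128) << 64) | (c_l as u128);
+    (x % P) as u64
+}
+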
+/// state[i] += rc[i] across the full width-12 state.
+pub fn add_rc_avx<F>(state: &mut [F; SPONGE_WIDTH], rc: &[u64; SPONGE_WIDTH])
+where
+    F: RichField,
+{
+    unsafe {
+        let s0 = _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>());
+        let s1 = _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>());
+        let s2 = _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>());
+        let rc0 = _mm256_loadu_si256((&rc[0..4]).as_ptr().cast::<__m256i>());
+        let rc1 = _mm256_loadu_si256((&rc[4..8]).as_ptr().cast::<__m256i>());
+        let rc2 = _mm256_loadu_si256((&rc[8..12]).as_ptr().cast::<__m256i>());
+        let p0 = state[0..4].as_mut_ptr().cast::<__m256i>();
+        let p1 = state[4..8].as_mut_ptr().cast::<__m256i>();
+        let p2 = state[8..12].as_mut_ptr().cast::<__m256i>();
+        add_avx(&mut *p0, &s0, &rc0);
+        add_avx(&mut *p1, &s1, &rc1);
+        add_avx(&mut *p2, &s2, &rc2);
+    }
+}
+
+/// Apply the S-box x -> x^7 to every state element.
+pub fn sbox_avx<F>(state: &mut [F; SPONGE_WIDTH])
+where
+    F: RichField,
+{
+    unsafe {
+        let mut s0 = _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>());
+        let mut s1 = _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>());
+        let mut s2 = _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>());
+        let mut p10 = s0;
+        let mut p11 = s0;
+        let mut p12 = s0;
+        let mut p20 = s0;
+        let mut p21 = s0;
+        let mut p22 = s0;
+        // x^2
+        mult_avx(&mut p10, &s0, &s0);
+        mult_avx(&mut p11, &s1, &s1);
+        mult_avx(&mut p12, &s2, &s2);
+        // x^3 = x^2 * x
+        mult_avx(&mut p20, &p10, &s0);
+        mult_avx(&mut p21, &p11, &s1);
+        mult_avx(&mut p22, &p12, &s2);
+        // x^4 = x^2 * x^2
+        mult_avx(&mut s0, &p10, &p10);
+        mult_avx(&mut s1, &p11, &p11);
+        mult_avx(&mut s2, &p12, &p12);
+        // x^7 = x^4 * x^3, stored straight back into the state.
+        let p0 = state[0..4].as_mut_ptr().cast::<__m256i>();
+        let p1 = state[4..8].as_mut_ptr().cast::<__m256i>();
+        let p2 = state[8..12].as_mut_ptr().cast::<__m256i>();
+        mult_avx(&mut *p0, &s0, &p20);
+        mult_avx(&mut *p1, &s1, &p21);
+        mult_avx(&mut *p2, &s2, &p22);
+    }
+}
+
+/// Internal (partial-round) linear layer:
+/// state[i] = state[i] * mat_internal_diag_m_1[i] + sum(state).
+pub fn matmul_internal_avx<F>(
+    state: &mut [F; SPONGE_WIDTH],
+    mat_internal_diag_m_1: [u64; SPONGE_WIDTH],
+) where
+    F: RichField,
+{
+    // The diffusion sum is a single scalar; compute it in the field and
+    // broadcast its canonical representative to all four lanes.
+    let mut sum = state[0];
+    for i in 1..SPONGE_WIDTH {
+        sum = sum + state[i];
+    }
+    let si64: i64 = sum.to_canonical_u64() as i64;
+    unsafe {
+        let s0 = _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>());
+        let s1 = _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>());
+        let s2 = _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>());
+        let m0 = _mm256_loadu_si256((&mat_internal_diag_m_1[0..4]).as_ptr().cast::<__m256i>());
+        let m1 = _mm256_loadu_si256((&mat_internal_diag_m_1[4..8]).as_ptr().cast::<__m256i>());
+        let m2 = _mm256_loadu_si256((&mat_internal_diag_m_1[8..12]).as_ptr().cast::<__m256i>());
+        let ss = _mm256_set_epi64x(si64, si64, si64, si64);
+        let mut p10 = s0;
+        let mut p11 = s0;
+        let mut p12 = s0;
+        mult_avx(&mut p10, &s0, &m0);
+        mult_avx(&mut p11, &s1, &m1);
+        mult_avx(&mut p12, &s2, &m2);
+        let p0 = state[0..4].as_mut_ptr().cast::<__m256i>();
+        let p1 = state[4..8].as_mut_ptr().cast::<__m256i>();
+        let p2 = state[8..12].as_mut_ptr().cast::<__m256i>();
+        add_avx(&mut *p0, &p10, &ss);
+        add_avx(&mut *p1, &p11, &ss);
+        add_avx(&mut *p2, &p12, &ss);
+    }
+}
+
+/// External (full-round) linear layer M_E.
+#[inline]
+pub fn permute_mut_avx<F>(state: &mut [F; SPONGE_WIDTH])
+where
+    F: RichField,
+{
+    // First, we apply M_4 to each consecutive four elements of the state.
+    // In Appendix B's terminology, this replaces each x_i with x_i'.
+    for i in (0..SPONGE_WIDTH).step_by(4) {
+        apply_m_4(&mut state[i..i + 4]);
+    }
+
+    // Then each y_i is x_i' plus the sum of the x_j' with j = i (mod 4);
+    // s = s0 + s1 + s2 holds those four column sums in its lanes.
+    unsafe {
+        let s0 = _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>());
+        let s1 = _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>());
+        let s2 = _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>());
+        let mut s3 = s0;
+        let mut s = s0;
+        add_avx(&mut s3, &s0, &s1);
+        add_avx(&mut s, &s2, &s3);
+        let p0 = state[0..4].as_mut_ptr().cast::<__m256i>();
+        let p1 = state[4..8].as_mut_ptr().cast::<__m256i>();
+        let p2 = state[8..12].as_mut_ptr().cast::<__m256i>();
+        add_avx(&mut *p0, &s0, &s);
+        add_avx(&mut *p1, &s1, &s);
+        add_avx(&mut *p2, &s2, &s);
+    }
+}
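+
+// Sanity-check sketch, added for exposition and not part of the original
+// patch: the vector routines should agree with u128 arithmetic mod p.
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    const P: u128 = 0xFFFFFFFF00000001;
+
+    #[test]
+    fn add_and_mult_match_u128() {
+        let a: [u64; 4] = [0, 1, 0xFFFFFFFF00000000, 0x123456789ABCDEF0];
+        let b: [u64; 4] = [0xFFFFFFFF00000000, 5, 0xFFFFFFFE00000002, 0x0FEDCBA987654321];
+        let mut add = [0u64; 4];
+        let mut mul = [0u64; 4];
+        unsafe {
+            let va = _mm256_loadu_si256(a.as_ptr().cast::<__m256i>());
+            let vb = _mm256_loadu_si256(b.as_ptr().cast::<__m256i>());
+            let mut vc = _mm256_setzero_si256();
+            add_avx(&mut vc, &va, &vb);
+            _mm256_storeu_si256(add.as_mut_ptr().cast::<__m256i>(), vc);
+            mult_avx(&mut vc, &va, &vb);
+            _mm256_storeu_si256(mul.as_mut_ptr().cast::<__m256i>(), vc);
+        }
+        for i in 0..4 {
+            // Results may be non-canonical (in [0, 2^64)), so compare mod p.
+            assert_eq!(add[i] as u128 % P, (a[i] as u128 + b[i] as u128) % P);
+            assert_eq!(mul[i] as u128 % P, (a[i] as u128 * b[i] as u128) % P);
+        }
+    }
+}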
\ No newline at end of file
diff --git a/plonky2/src/hash/poseidon2.rs b/plonky2/src/hash/poseidon2.rs
index 95338ba5b9..3bd52ac9cb 100644
--- a/plonky2/src/hash/poseidon2.rs
+++ b/plonky2/src/hash/poseidon2.rs
@@ -15,6 +15,8 @@ use crate::iop::target::{BoolTarget, Target};
 use crate::plonk::circuit_builder::CircuitBuilder;
 use crate::plonk::config::{AlgebraicHasher, Hasher, HasherType};
 
+#[cfg(target_feature = "avx2")]
+use super::arch::x86_64::poseidon2_goldilocks_avx2::{add_rc_avx, sbox_avx, matmul_internal_avx, permute_mut_avx};
 use super::hash_types::NUM_HASH_OUT_ELTS;
 
 pub const SPONGE_RATE: usize = 8;
@@ -78,11 +80,11 @@ extern crate alloc;
 // [ 1 1 4 6 ].
 // The permutation calculation is based on Appendix B from the Poseidon2 paper.
 #[derive(Copy, Clone, Default)]
-pub struct Poseidon2MEMatrix<const WIDTH: usize>;
+pub struct Poseidon2MEMatrix;
 
 // Multiply a 4-element vector x by M_4, in place.
 // This uses the formula from the start of Appendix B, with multiplications unrolled into additions.
-fn apply_m_4<F>(x: &mut [F])
+pub fn apply_m_4<F>(x: &mut [F])
 where
     F: RichField,
 {
@@ -109,14 +111,15 @@ trait P2Permutation<T>: Clone + Sync {
     fn permute_mut(&self, input: &mut T);
 }
 
-impl<F, const WIDTH: usize> P2Permutation<[F; WIDTH]> for Poseidon2MEMatrix<WIDTH>
+impl<F> P2Permutation<[F; SPONGE_WIDTH]> for Poseidon2MEMatrix
 where
     F: RichField,
 {
-    fn permute_mut(&self, state: &mut [F; WIDTH]) {
+    #[cfg(not(target_feature = "avx2"))]
+    fn permute_mut(&self, state: &mut [F; SPONGE_WIDTH]) {
         // First, we apply M_4 to each consecutive four elements of the state.
         // In Appendix B's terminology, this replaces each x_i with x_i'.
-        for i in (0..WIDTH).step_by(4) {
+        for i in (0..SPONGE_WIDTH).step_by(4) {
             apply_m_4(&mut state[i..i + 4]);
         }
 
@@ -124,7 +127,7 @@ where
         // We first precompute the four sums of every four elements.
         let sums: [F; 4] = core::array::from_fn(|k| {
-            (0..WIDTH)
+            (0..SPONGE_WIDTH)
                 .step_by(4)
                 .map(|j| state[j + k].clone())
                 .sum::<F>()
@@ -132,29 +135,41 @@ where
 
         // The formula for each y_i involves a 2x_i' term and an x_j' term for each j that equals i mod 4.
         // In other words, we can add a single copy of x_i' to the appropriate one of our precomputed sums.
-        for i in 0..WIDTH {
+        for i in 0..SPONGE_WIDTH {
             state[i] += sums[i % 4].clone();
         }
     }
+
+    #[cfg(target_feature = "avx2")]
+    fn permute_mut(&self, state: &mut [F; SPONGE_WIDTH]) {
+        permute_mut_avx(state);
+    }
 }
 
 #[derive(Debug, Clone, Default)]
 struct DiffusionMatrixGoldilocks;
 
-pub fn matmul_internal<F: RichField, const WIDTH: usize>(
-    state: &mut [F; WIDTH],
-    mat_internal_diag_m_1: [u64; WIDTH],
+pub fn matmul_internal<F: RichField>(
+    state: &mut [F; SPONGE_WIDTH],
+    mat_internal_diag_m_1: [u64; SPONGE_WIDTH],
 ) {
+    // if no AVX
+    #[cfg(not(target_feature = "avx2"))]
     let sum: F = state.iter().cloned().sum();
-    for i in 0..WIDTH {
+    // if no AVX
+    #[cfg(not(target_feature = "avx2"))]
+    for i in 0..SPONGE_WIDTH {
         state[i] *= F::from_canonical_u64(mat_internal_diag_m_1[i]);
         state[i] += sum.clone();
     }
+    // if AVX
+    #[cfg(target_feature = "avx2")]
+    matmul_internal_avx(state, mat_internal_diag_m_1);
 }
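+
+// Note added for exposition: with d = mat_internal_diag_m_1, matmul_internal
+// sends state to (J + diag(d)) * state, where J is the all-ones 12x12 matrix.
+// Each output is d_i * state[i] + sum(state), so the applied matrix has 1s
+// off the diagonal and d_i + 1 on it (hence the "minus 1" in the name).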
 
 impl<F: RichField> P2Permutation<[F; 12]> for DiffusionMatrixGoldilocks {
     fn permute_mut(&self, state: &mut [F; 12]) {
-        matmul_internal::<F, 12>(state, MATRIX_DIAG_12_GOLDILOCKS);
+        matmul_internal::<F>(state, MATRIX_DIAG_12_GOLDILOCKS);
     }
 }
 
@@ -165,29 +180,18 @@ pub trait Poseidon2: RichField {
     const ROUNDS_P: usize = 22;
 
     #[inline]
-    fn matmul_internal<F>(
-        state: &mut [F; SPONGE_WIDTH],
-        mat_internal_diag_m_1: [u64; SPONGE_WIDTH],
-    )
+    fn add_rc<F>(state: &mut [F; SPONGE_WIDTH], rc: &[u64; SPONGE_WIDTH])
     where
         F: RichField,
     {
-        let sum: F = state.iter().cloned().sum();
+        // if no AVX
+        #[cfg(not(target_feature = "avx2"))]
        for i in 0..SPONGE_WIDTH {
-            state[i] *= F::from_canonical_u64(mat_internal_diag_m_1[i]);
-            state[i] += sum.clone();
+            state[i] = state[i] + F::from_canonical_u64(rc[i]);
        }
-    }
-
-    #[inline]
-    fn add_rc<F>(state: &mut [F; SPONGE_WIDTH], rc: &[u64; SPONGE_WIDTH])
-    where
-        F: RichField,
-    {
-        state
-            .iter_mut()
-            .zip(rc)
-            .for_each(|(a, b)| *a += F::from_canonical_u64(*b));
+        // if AVX
+        #[cfg(target_feature = "avx2")]
+        add_rc_avx(state, rc);
     }
 
     #[inline]
@@ -195,7 +199,12 @@ fn sbox_p<F>(input: &F) -> F
     where
         F: RichField,
     {
-        input.exp_u64(7)
+        // input.exp_u64(7) is inefficient; use an explicit addition chain
+        // instead: x^2, x^4 = (x^2)^2, x^3 = x^2 * x, then x^7 = x^3 * x^4.
+        let x2 = (*input) * (*input);
+        let x4 = x2 * x2;
+        let x3 = x2 * (*input);
+        x3 * x4
     }
 
     #[inline]
@@ -203,12 +212,19 @@ fn sbox<F>(state: &mut [F; SPONGE_WIDTH])
     where
         F: RichField,
     {
-        state.iter_mut().for_each(|a| *a = Self::sbox_p(a));
+        // if no AVX
+        #[cfg(not(target_feature = "avx2"))]
+        for i in 0..SPONGE_WIDTH {
+            state[i] = Self::sbox_p(&state[i]);
+        }
+        // if AVX
+        #[cfg(target_feature = "avx2")]
+        sbox_avx(state);
     }
 
     #[inline]
     fn poseidon2(state: &mut [Self; SPONGE_WIDTH]) {
-        let external_linear_layer = Poseidon2MEMatrix::<SPONGE_WIDTH>;
+        let external_linear_layer = Poseidon2MEMatrix;
 
         // The initial linear layer.
         external_linear_layer.permute_mut(state);
@@ -227,7 +243,7 @@
         for r in rounds_f_beginning..p_end {
             state[0] += Self::from_canonical_u64(RC12[r][0]);
             state[0] = Self::sbox_p(&state[0]);
-            Self::matmul_internal(state, MATRIX_DIAG_12_GOLDILOCKS);
+            matmul_internal(state, MATRIX_DIAG_12_GOLDILOCKS);
         }
 
         // The second half of the external rounds.