diff --git a/Cargo.lock b/Cargo.lock index 47d17f65..0440b759 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -372,7 +372,7 @@ checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" [[package]] name = "elliptic-curve" version = "0.13.6" -source = "git+https://github.com/RustCrypto/traits.git#4a663341473273c73fd2ca7e37dadd9ba7013fd2" +source = "git+https://github.com/RustCrypto/traits.git#723a1a309c4d987ec54c4c1f7bfe3686dc8d5a29" dependencies = [ "base16ct", "base64ct", diff --git a/k256/src/arithmetic.rs b/k256/src/arithmetic.rs index d738c091..b6ba5673 100644 --- a/k256/src/arithmetic.rs +++ b/k256/src/arithmetic.rs @@ -12,7 +12,6 @@ pub(crate) mod scalar; mod dev; pub use field::FieldElement; -pub use mul::lincomb; use self::{affine::AffinePoint, projective::ProjectivePoint, scalar::Scalar}; use crate::Secp256k1; diff --git a/k256/src/arithmetic/mul.rs b/k256/src/arithmetic/mul.rs index c55c2383..6354a7d2 100644 --- a/k256/src/arithmetic/mul.rs +++ b/k256/src/arithmetic/mul.rs @@ -44,9 +44,11 @@ use crate::arithmetic::{ scalar::{Scalar, WideScalar}, ProjectivePoint, }; + use core::ops::{Mul, MulAssign}; +use elliptic_curve::ops::LinearCombinationExt as LinearCombination; use elliptic_curve::{ - ops::{LinearCombination, MulByGenerator}, + ops::MulByGenerator, scalar::IsHigh, subtle::{Choice, ConditionallySelectable, ConstantTimeEq}, }; @@ -278,93 +280,70 @@ impl Default for Radix16Decomposition { } } -/// Maps an array `x` to an array using the predicate `f`. -/// We can't use the standard `map()` because as of Rust 1.51 we cannot collect into arrays. -/// Consequently, since we cannot have an uninitialized array (without `unsafe`), -/// a default value needs to be provided. -fn static_map( - f: impl Fn(T) -> V, - x: &[T; N], - default: V, -) -> [V; N] { - let mut res = [default; N]; - for i in 0..N { - res[i] = f(x[i]); +impl LinearCombination<[(ProjectivePoint, Scalar); N]> for ProjectivePoint { + fn lincomb_ext(points_and_scalars: &[(ProjectivePoint, Scalar); N]) -> Self { + let mut tables = [(LookupTable::default(), LookupTable::default()); N]; + let mut digits = [( + Radix16Decomposition::<33>::default(), + Radix16Decomposition::<33>::default(), + ); N]; + + lincomb(points_and_scalars, &mut tables, &mut digits) } - res } -/// Maps two arrays `x` and `y` into an array using a predicate `f` that takes two arguments. -fn static_zip_map( - f: impl Fn(T, S) -> V, - x: &[T; N], - y: &[S; N], - default: V, -) -> [V; N] { - let mut res = [default; N]; - for i in 0..N { - res[i] = f(x[i], y[i]); +#[cfg(feature = "alloc")] +impl LinearCombination<[(ProjectivePoint, Scalar)]> for ProjectivePoint { + fn lincomb_ext(points_and_scalars: &[(ProjectivePoint, Scalar)]) -> Self { + let mut tables = + vec![(LookupTable::default(), LookupTable::default()); points_and_scalars.len()]; + let mut digits = vec![ + ( + Radix16Decomposition::<33>::default(), + Radix16Decomposition::<33>::default(), + ); + points_and_scalars.len() + ]; + + lincomb(points_and_scalars, &mut tables, &mut digits) } - res } -/// Calculates a linear combination `sum(x[i] * k[i])`, `i = 0..N` -#[inline(always)] -pub fn lincomb(xs: &[ProjectivePoint; N], ks: &[Scalar; N]) -> ProjectivePoint { - let rs = static_map( - |k| decompose_scalar(&k), - ks, - (Scalar::default(), Scalar::default()), - ); - let r1s = static_map(|(r1, _r2)| r1, &rs, Scalar::default()); - let r2s = static_map(|(_r1, r2)| r2, &rs, Scalar::default()); - - let xs_beta = static_map(|x| x.endomorphism(), xs, ProjectivePoint::default()); - - let r1_signs = static_map(|r| r.is_high(), &r1s, Choice::from(0u8)); - let r2_signs = static_map(|r| r.is_high(), &r2s, Choice::from(0u8)); - - let r1s_c = static_zip_map( - |r, r_sign| Scalar::conditional_select(&r, &-r, r_sign), - &r1s, - &r1_signs, - Scalar::default(), - ); - let r2s_c = static_zip_map( - |r, r_sign| Scalar::conditional_select(&r, &-r, r_sign), - &r2s, - &r2_signs, - Scalar::default(), - ); - - let tables1 = static_zip_map( - |x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)), - xs, - &r1_signs, - LookupTable::default(), - ); - let tables2 = static_zip_map( - |x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)), - &xs_beta, - &r2_signs, - LookupTable::default(), - ); - - let digits1 = static_map( - |r| Radix16Decomposition::<33>::new(&r), - &r1s_c, - Radix16Decomposition::<33>::default(), - ); - let digits2 = static_map( - |r| Radix16Decomposition::<33>::new(&r), - &r2s_c, - Radix16Decomposition::<33>::default(), - ); +fn lincomb( + xks: &[(ProjectivePoint, Scalar)], + tables: &mut [(LookupTable, LookupTable)], + digits: &mut [(Radix16Decomposition<33>, Radix16Decomposition<33>)], +) -> ProjectivePoint { + xks.iter().enumerate().for_each(|(i, (x, k))| { + let (r1, r2) = decompose_scalar(k); + let x_beta = x.endomorphism(); + let (r1_sign, r2_sign) = (r1.is_high(), r2.is_high()); + + let (r1_c, r2_c) = ( + Scalar::conditional_select(&r1, &-r1, r1_sign), + Scalar::conditional_select(&r2, &-r2, r2_sign), + ); + + tables[i] = ( + LookupTable::from(&ProjectivePoint::conditional_select(x, &-*x, r1_sign)), + LookupTable::from(&ProjectivePoint::conditional_select( + &x_beta, &-x_beta, r2_sign, + )), + ); + + digits[i] = ( + Radix16Decomposition::<33>::new(&r1_c), + Radix16Decomposition::<33>::new(&r2_c), + ) + }); let mut acc = ProjectivePoint::IDENTITY; - for component in 0..N { - acc += &tables1[component].select(digits1[component].0[32]); - acc += &tables2[component].select(digits2[component].0[32]); + for component in 0..xks.len() { + let (digit1, digit2) = digits[component]; + let (table1, table2) = tables[component]; + + acc += &table1.select(digit1.0[32]); + acc += &table2.select(digit2.0[32]); } for i in (0..32).rev() { @@ -372,9 +351,12 @@ pub fn lincomb(xs: &[ProjectivePoint; N], ks: &[Scalar; N]) -> P acc = acc.double(); } - for component in 0..N { - acc += &tables1[component].select(digits1[component].0[i]); - acc += &tables2[component].select(digits2[component].0[i]); + for component in 0..xks.len() { + let (digit1, digit2) = digits[component]; + let (table1, table2) = tables[component]; + + acc += &table1.select(digit1.0[i]); + acc += &table2.select(digit2.0[i]); } } acc @@ -429,18 +411,7 @@ impl MulByGenerator for ProjectivePoint { #[inline(always)] fn mul(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint { - lincomb(&[*x], &[*k]) -} - -impl LinearCombination for ProjectivePoint { - fn lincomb( - x: &ProjectivePoint, - k: &Scalar, - y: &ProjectivePoint, - l: &Scalar, - ) -> ProjectivePoint { - lincomb(&[*x, *y], &[*k, *l]) - } + ProjectivePoint::lincomb_ext(&[(*x, *k)]) } impl Mul for ProjectivePoint { @@ -481,9 +452,10 @@ impl MulAssign<&Scalar> for ProjectivePoint { #[cfg(test)] mod tests { + use super::*; use crate::arithmetic::{ProjectivePoint, Scalar}; use elliptic_curve::{ - ops::{LinearCombination, MulByGenerator}, + ops::{LinearCombination as _, MulByGenerator}, rand_core::OsRng, Field, Group, }; @@ -507,4 +479,19 @@ mod tests { let test = ProjectivePoint::mul_by_generator(&k); assert_eq!(reference, test); } + + #[cfg(feature = "alloc")] + #[test] + fn test_lincomb_slice() { + let x = ProjectivePoint::random(&mut OsRng); + let y = ProjectivePoint::random(&mut OsRng); + let k = Scalar::random(&mut OsRng); + let l = Scalar::random(&mut OsRng); + + let reference = &x * &k + &y * &l; + let points_and_scalars = vec![(x, k), (y, l)]; + + let test = ProjectivePoint::lincomb_ext(points_and_scalars.as_slice()); + assert_eq!(reference, test); + } } diff --git a/k256/src/lib.rs b/k256/src/lib.rs index f0afd5e1..f47f4882 100644 --- a/k256/src/lib.rs +++ b/k256/src/lib.rs @@ -50,7 +50,7 @@ pub mod test_vectors; pub use elliptic_curve::{self, bigint::U256}; #[cfg(feature = "arithmetic")] -pub use arithmetic::{affine::AffinePoint, lincomb, projective::ProjectivePoint, scalar::Scalar}; +pub use arithmetic::{affine::AffinePoint, projective::ProjectivePoint, scalar::Scalar}; #[cfg(feature = "expose-field")] pub use arithmetic::FieldElement;