Skip to content

Commit

Permalink
k256: refactored lincomb to use LinearCombinationExt trait (#974)
Browse files Browse the repository at this point in the history
Co-authored-by: Tony Arcieri <[email protected]>
  • Loading branch information
ycscaly and tarcieri authored Nov 15, 2023
1 parent 2bb8aa8 commit 676e728
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 99 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion k256/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ pub(crate) mod scalar;
mod dev;

pub use field::FieldElement;
pub use mul::lincomb;

use self::{affine::AffinePoint, projective::ProjectivePoint, scalar::Scalar};
use crate::Secp256k1;
Expand Down
179 changes: 83 additions & 96 deletions k256/src/arithmetic/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ use crate::arithmetic::{
scalar::{Scalar, WideScalar},
ProjectivePoint,
};

use core::ops::{Mul, MulAssign};
use elliptic_curve::ops::LinearCombinationExt as LinearCombination;
use elliptic_curve::{
ops::{LinearCombination, MulByGenerator},
ops::MulByGenerator,
scalar::IsHigh,
subtle::{Choice, ConditionallySelectable, ConstantTimeEq},
};
Expand Down Expand Up @@ -278,103 +280,83 @@ impl<const D: usize> Default for Radix16Decomposition<D> {
}
}

/// Maps an array `x` to an array using the predicate `f`.
/// We can't use the standard `map()` because as of Rust 1.51 we cannot collect into arrays.
/// Consequently, since we cannot have an uninitialized array (without `unsafe`),
/// a default value needs to be provided.
fn static_map<T: Copy, V: Copy, const N: usize>(
f: impl Fn(T) -> V,
x: &[T; N],
default: V,
) -> [V; N] {
let mut res = [default; N];
for i in 0..N {
res[i] = f(x[i]);
impl<const N: usize> LinearCombination<[(ProjectivePoint, Scalar); N]> for ProjectivePoint {
fn lincomb_ext(points_and_scalars: &[(ProjectivePoint, Scalar); N]) -> Self {
let mut tables = [(LookupTable::default(), LookupTable::default()); N];
let mut digits = [(
Radix16Decomposition::<33>::default(),
Radix16Decomposition::<33>::default(),
); N];

lincomb(points_and_scalars, &mut tables, &mut digits)
}
res
}

/// Maps two arrays `x` and `y` into an array using a predicate `f` that takes two arguments.
fn static_zip_map<T: Copy, S: Copy, V: Copy, const N: usize>(
f: impl Fn(T, S) -> V,
x: &[T; N],
y: &[S; N],
default: V,
) -> [V; N] {
let mut res = [default; N];
for i in 0..N {
res[i] = f(x[i], y[i]);
#[cfg(feature = "alloc")]
impl LinearCombination<[(ProjectivePoint, Scalar)]> for ProjectivePoint {
fn lincomb_ext(points_and_scalars: &[(ProjectivePoint, Scalar)]) -> Self {
let mut tables =
vec![(LookupTable::default(), LookupTable::default()); points_and_scalars.len()];
let mut digits = vec![
(
Radix16Decomposition::<33>::default(),
Radix16Decomposition::<33>::default(),
);
points_and_scalars.len()
];

lincomb(points_and_scalars, &mut tables, &mut digits)
}
res
}

/// Calculates a linear combination `sum(x[i] * k[i])`, `i = 0..N`
#[inline(always)]
pub fn lincomb<const N: usize>(xs: &[ProjectivePoint; N], ks: &[Scalar; N]) -> ProjectivePoint {
let rs = static_map(
|k| decompose_scalar(&k),
ks,
(Scalar::default(), Scalar::default()),
);
let r1s = static_map(|(r1, _r2)| r1, &rs, Scalar::default());
let r2s = static_map(|(_r1, r2)| r2, &rs, Scalar::default());

let xs_beta = static_map(|x| x.endomorphism(), xs, ProjectivePoint::default());

let r1_signs = static_map(|r| r.is_high(), &r1s, Choice::from(0u8));
let r2_signs = static_map(|r| r.is_high(), &r2s, Choice::from(0u8));

let r1s_c = static_zip_map(
|r, r_sign| Scalar::conditional_select(&r, &-r, r_sign),
&r1s,
&r1_signs,
Scalar::default(),
);
let r2s_c = static_zip_map(
|r, r_sign| Scalar::conditional_select(&r, &-r, r_sign),
&r2s,
&r2_signs,
Scalar::default(),
);

let tables1 = static_zip_map(
|x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)),
xs,
&r1_signs,
LookupTable::default(),
);
let tables2 = static_zip_map(
|x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)),
&xs_beta,
&r2_signs,
LookupTable::default(),
);

let digits1 = static_map(
|r| Radix16Decomposition::<33>::new(&r),
&r1s_c,
Radix16Decomposition::<33>::default(),
);
let digits2 = static_map(
|r| Radix16Decomposition::<33>::new(&r),
&r2s_c,
Radix16Decomposition::<33>::default(),
);
fn lincomb(
xks: &[(ProjectivePoint, Scalar)],
tables: &mut [(LookupTable, LookupTable)],
digits: &mut [(Radix16Decomposition<33>, Radix16Decomposition<33>)],
) -> ProjectivePoint {
xks.iter().enumerate().for_each(|(i, (x, k))| {
let (r1, r2) = decompose_scalar(k);
let x_beta = x.endomorphism();
let (r1_sign, r2_sign) = (r1.is_high(), r2.is_high());

let (r1_c, r2_c) = (
Scalar::conditional_select(&r1, &-r1, r1_sign),
Scalar::conditional_select(&r2, &-r2, r2_sign),
);

tables[i] = (
LookupTable::from(&ProjectivePoint::conditional_select(x, &-*x, r1_sign)),
LookupTable::from(&ProjectivePoint::conditional_select(
&x_beta, &-x_beta, r2_sign,
)),
);

digits[i] = (
Radix16Decomposition::<33>::new(&r1_c),
Radix16Decomposition::<33>::new(&r2_c),
)
});

let mut acc = ProjectivePoint::IDENTITY;
for component in 0..N {
acc += &tables1[component].select(digits1[component].0[32]);
acc += &tables2[component].select(digits2[component].0[32]);
for component in 0..xks.len() {
let (digit1, digit2) = digits[component];
let (table1, table2) = tables[component];

acc += &table1.select(digit1.0[32]);
acc += &table2.select(digit2.0[32]);
}

for i in (0..32).rev() {
for _j in 0..4 {
acc = acc.double();
}

for component in 0..N {
acc += &tables1[component].select(digits1[component].0[i]);
acc += &tables2[component].select(digits2[component].0[i]);
for component in 0..xks.len() {
let (digit1, digit2) = digits[component];
let (table1, table2) = tables[component];

acc += &table1.select(digit1.0[i]);
acc += &table2.select(digit2.0[i]);
}
}
acc
Expand Down Expand Up @@ -429,18 +411,7 @@ impl MulByGenerator for ProjectivePoint {

#[inline(always)]
fn mul(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint {
lincomb(&[*x], &[*k])
}

impl LinearCombination for ProjectivePoint {
fn lincomb(
x: &ProjectivePoint,
k: &Scalar,
y: &ProjectivePoint,
l: &Scalar,
) -> ProjectivePoint {
lincomb(&[*x, *y], &[*k, *l])
}
ProjectivePoint::lincomb_ext(&[(*x, *k)])
}

impl Mul<Scalar> for ProjectivePoint {
Expand Down Expand Up @@ -481,9 +452,10 @@ impl MulAssign<&Scalar> for ProjectivePoint {

#[cfg(test)]
mod tests {
use super::*;
use crate::arithmetic::{ProjectivePoint, Scalar};
use elliptic_curve::{
ops::{LinearCombination, MulByGenerator},
ops::{LinearCombination as _, MulByGenerator},
rand_core::OsRng,
Field, Group,
};
Expand All @@ -507,4 +479,19 @@ mod tests {
let test = ProjectivePoint::mul_by_generator(&k);
assert_eq!(reference, test);
}

#[cfg(feature = "alloc")]
#[test]
fn test_lincomb_slice() {
let x = ProjectivePoint::random(&mut OsRng);
let y = ProjectivePoint::random(&mut OsRng);
let k = Scalar::random(&mut OsRng);
let l = Scalar::random(&mut OsRng);

let reference = &x * &k + &y * &l;
let points_and_scalars = vec![(x, k), (y, l)];

let test = ProjectivePoint::lincomb_ext(points_and_scalars.as_slice());
assert_eq!(reference, test);
}
}
2 changes: 1 addition & 1 deletion k256/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub mod test_vectors;
pub use elliptic_curve::{self, bigint::U256};

#[cfg(feature = "arithmetic")]
pub use arithmetic::{affine::AffinePoint, lincomb, projective::ProjectivePoint, scalar::Scalar};
pub use arithmetic::{affine::AffinePoint, projective::ProjectivePoint, scalar::Scalar};

#[cfg(feature = "expose-field")]
pub use arithmetic::FieldElement;
Expand Down

0 comments on commit 676e728

Please sign in to comment.