Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

k256: refactored lincomb to use LinearCombinationExt trait #974

Merged
merged 12 commits into from
Nov 15, 2023
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion k256/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ pub(crate) mod scalar;
mod dev;

pub use field::FieldElement;
pub use mul::lincomb;

use self::{affine::AffinePoint, projective::ProjectivePoint, scalar::Scalar};
use crate::Secp256k1;
Expand Down
179 changes: 83 additions & 96 deletions k256/src/arithmetic/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ use crate::arithmetic::{
scalar::{Scalar, WideScalar},
ProjectivePoint,
};

use core::ops::{Mul, MulAssign};
use elliptic_curve::ops::LinearCombinationExt as LinearCombination;
use elliptic_curve::{
ops::{LinearCombination, MulByGenerator},
ops::MulByGenerator,
scalar::IsHigh,
subtle::{Choice, ConditionallySelectable, ConstantTimeEq},
};
Expand Down Expand Up @@ -278,103 +280,83 @@ impl<const D: usize> Default for Radix16Decomposition<D> {
}
}

/// Maps an array `x` to an array using the predicate `f`.
/// We can't use the standard `map()` because as of Rust 1.51 we cannot collect into arrays.
/// Consequently, since we cannot have an uninitialized array (without `unsafe`),
/// a default value needs to be provided.
fn static_map<T: Copy, V: Copy, const N: usize>(
f: impl Fn(T) -> V,
x: &[T; N],
default: V,
) -> [V; N] {
let mut res = [default; N];
for i in 0..N {
res[i] = f(x[i]);
impl<const N: usize> LinearCombination<[(ProjectivePoint, Scalar); N]> for ProjectivePoint {
fn lincomb_ext(points_and_scalars: &[(ProjectivePoint, Scalar); N]) -> Self {
let mut tables = [(LookupTable::default(), LookupTable::default()); N];
let mut digits = [(
Radix16Decomposition::<33>::default(),
Radix16Decomposition::<33>::default(),
); N];

lincomb(points_and_scalars, &mut tables, &mut digits)
}
res
}

/// Maps two arrays `x` and `y` into an array using a predicate `f` that takes two arguments.
fn static_zip_map<T: Copy, S: Copy, V: Copy, const N: usize>(
f: impl Fn(T, S) -> V,
x: &[T; N],
y: &[S; N],
default: V,
) -> [V; N] {
let mut res = [default; N];
for i in 0..N {
res[i] = f(x[i], y[i]);
#[cfg(feature = "alloc")]
impl LinearCombination<[(ProjectivePoint, Scalar)]> for ProjectivePoint {
fn lincomb_ext(points_and_scalars: &[(ProjectivePoint, Scalar)]) -> Self {
let mut tables =
vec![(LookupTable::default(), LookupTable::default()); points_and_scalars.len()];
let mut digits = vec![
(
Radix16Decomposition::<33>::default(),
Radix16Decomposition::<33>::default(),
);
points_and_scalars.len()
];

lincomb(points_and_scalars, &mut tables, &mut digits)
}
res
}

/// Calculates a linear combination `sum(x[i] * k[i])`, `i = 0..N`
#[inline(always)]
pub fn lincomb<const N: usize>(xs: &[ProjectivePoint; N], ks: &[Scalar; N]) -> ProjectivePoint {
let rs = static_map(
|k| decompose_scalar(&k),
ks,
(Scalar::default(), Scalar::default()),
);
let r1s = static_map(|(r1, _r2)| r1, &rs, Scalar::default());
let r2s = static_map(|(_r1, r2)| r2, &rs, Scalar::default());

let xs_beta = static_map(|x| x.endomorphism(), xs, ProjectivePoint::default());

let r1_signs = static_map(|r| r.is_high(), &r1s, Choice::from(0u8));
let r2_signs = static_map(|r| r.is_high(), &r2s, Choice::from(0u8));

let r1s_c = static_zip_map(
|r, r_sign| Scalar::conditional_select(&r, &-r, r_sign),
&r1s,
&r1_signs,
Scalar::default(),
);
let r2s_c = static_zip_map(
|r, r_sign| Scalar::conditional_select(&r, &-r, r_sign),
&r2s,
&r2_signs,
Scalar::default(),
);

let tables1 = static_zip_map(
|x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)),
xs,
&r1_signs,
LookupTable::default(),
);
let tables2 = static_zip_map(
|x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)),
&xs_beta,
&r2_signs,
LookupTable::default(),
);

let digits1 = static_map(
|r| Radix16Decomposition::<33>::new(&r),
&r1s_c,
Radix16Decomposition::<33>::default(),
);
let digits2 = static_map(
|r| Radix16Decomposition::<33>::new(&r),
&r2s_c,
Radix16Decomposition::<33>::default(),
);
fn lincomb(
xks: &[(ProjectivePoint, Scalar)],
tables: &mut [(LookupTable, LookupTable)],
digits: &mut [(Radix16Decomposition<33>, Radix16Decomposition<33>)],
) -> ProjectivePoint {
xks.iter().enumerate().for_each(|(i, (x, k))| {
let (r1, r2) = decompose_scalar(k);
let x_beta = x.endomorphism();
let (r1_sign, r2_sign) = (r1.is_high(), r2.is_high());

let (r1_c, r2_c) = (
Scalar::conditional_select(&r1, &-r1, r1_sign),
Scalar::conditional_select(&r2, &-r2, r2_sign),
);

tables[i] = (
LookupTable::from(&ProjectivePoint::conditional_select(x, &-*x, r1_sign)),
LookupTable::from(&ProjectivePoint::conditional_select(
&x_beta, &-x_beta, r2_sign,
)),
);

digits[i] = (
Radix16Decomposition::<33>::new(&r1_c),
Radix16Decomposition::<33>::new(&r2_c),
)
});

let mut acc = ProjectivePoint::IDENTITY;
for component in 0..N {
acc += &tables1[component].select(digits1[component].0[32]);
acc += &tables2[component].select(digits2[component].0[32]);
for component in 0..xks.len() {
let (digit1, digit2) = digits[component];
let (table1, table2) = tables[component];

acc += &table1.select(digit1.0[32]);
acc += &table2.select(digit2.0[32]);
}

for i in (0..32).rev() {
for _j in 0..4 {
acc = acc.double();
}

for component in 0..N {
acc += &tables1[component].select(digits1[component].0[i]);
acc += &tables2[component].select(digits2[component].0[i]);
for component in 0..xks.len() {
let (digit1, digit2) = digits[component];
let (table1, table2) = tables[component];

acc += &table1.select(digit1.0[i]);
acc += &table2.select(digit2.0[i]);
}
}
acc
Expand Down Expand Up @@ -429,18 +411,7 @@ impl MulByGenerator for ProjectivePoint {

#[inline(always)]
fn mul(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint {
lincomb(&[*x], &[*k])
}

impl LinearCombination for ProjectivePoint {
fn lincomb(
x: &ProjectivePoint,
k: &Scalar,
y: &ProjectivePoint,
l: &Scalar,
) -> ProjectivePoint {
lincomb(&[*x, *y], &[*k, *l])
}
ProjectivePoint::lincomb_ext(&[(*x, *k)])
}

impl Mul<Scalar> for ProjectivePoint {
Expand Down Expand Up @@ -481,9 +452,10 @@ impl MulAssign<&Scalar> for ProjectivePoint {

#[cfg(test)]
mod tests {
use super::*;
use crate::arithmetic::{ProjectivePoint, Scalar};
use elliptic_curve::{
ops::{LinearCombination, MulByGenerator},
ops::{LinearCombination as _, MulByGenerator},
rand_core::OsRng,
Field, Group,
};
Expand All @@ -507,4 +479,19 @@ mod tests {
let test = ProjectivePoint::mul_by_generator(&k);
assert_eq!(reference, test);
}

#[cfg(feature = "alloc")]
#[test]
fn test_lincomb_slice() {
let x = ProjectivePoint::random(&mut OsRng);
let y = ProjectivePoint::random(&mut OsRng);
let k = Scalar::random(&mut OsRng);
let l = Scalar::random(&mut OsRng);

let reference = &x * &k + &y * &l;
let points_and_scalars = vec![(x, k), (y, l)];

let test = ProjectivePoint::lincomb_ext(points_and_scalars.as_slice());
assert_eq!(reference, test);
}
}
2 changes: 1 addition & 1 deletion k256/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub mod test_vectors;
pub use elliptic_curve::{self, bigint::U256};

#[cfg(feature = "arithmetic")]
pub use arithmetic::{affine::AffinePoint, lincomb, projective::ProjectivePoint, scalar::Scalar};
pub use arithmetic::{affine::AffinePoint, projective::ProjectivePoint, scalar::Scalar};

#[cfg(feature = "expose-field")]
pub use arithmetic::FieldElement;
Expand Down
Loading