diff --git a/fastcrypto-tbls/benches/polynomial.rs b/fastcrypto-tbls/benches/polynomial.rs index 4aa8cab7e..1cab24b77 100644 --- a/fastcrypto-tbls/benches/polynomial.rs +++ b/fastcrypto-tbls/benches/polynomial.rs @@ -9,6 +9,37 @@ use std::num::NonZeroU16; mod polynomial_benches { use super::*; + use fastcrypto_tbls::threshold_schnorr::gao::RSDecoder; + use fastcrypto_tbls::threshold_schnorr::S; + use fastcrypto_tbls::types::ShareIndex; + use itertools::Itertools; + + fn rs_decoder(c: &mut Criterion) { + const SIZES: [usize; 4] = [128, 256, 512, 1024]; + + for n in SIZES { + let k = n / 3; + let a = (1..=n) + .map(|i| ShareIndex::new(i as u16).unwrap()) + .collect_vec(); + let decoder = RSDecoder::new(a.clone(), k); + + let message: Vec = (0..k).map(|i| S::from((i * 10) as u128)).collect(); + let code_word = decoder.encode(message.clone()).unwrap(); + + // Introduce errors + let mut received = code_word.clone(); + received[4] = S::from(20u128); // Error at position 4 + received[2] = S::from(200u128); // Error at position 2 + + let mut rs_decoder: BenchmarkGroup<_> = c.benchmark_group("RS Decoder"); + rs_decoder.bench_function(format!("n={}, k={}", n, k).as_str(), |b| { + b.iter(|| { + decoder.decode(&received).unwrap(); + }) + }); + } + } fn polynomials(c: &mut Criterion) { const SIZES: [usize; 7] = [128, 256, 512, 1024, 2048, 4096, 8192]; @@ -77,7 +108,7 @@ mod polynomial_benches { criterion_group! { name = polynomial_benches; config = Criterion::default(); - targets = polynomials, + targets = polynomials, rs_decoder, } } diff --git a/fastcrypto-tbls/src/dkg_v1.rs b/fastcrypto-tbls/src/dkg_v1.rs index 39c38bf12..7b742f689 100644 --- a/fastcrypto-tbls/src/dkg_v1.rs +++ b/fastcrypto-tbls/src/dkg_v1.rs @@ -311,12 +311,12 @@ where return Err(FastCryptoError::InvalidMessage); }; - if self.t as usize != msg.vss_pk.degree() + 1 { + if self.t as usize != msg.vss_pk.degree_bound() + 1 { warn!( "DKG: Message sanity check failed for id {}, expected degree={}, got {}", msg.sender, self.t - 1, - msg.vss_pk.degree() + msg.vss_pk.degree_bound() ); return Err(FastCryptoError::InvalidMessage); } diff --git a/fastcrypto-tbls/src/dl_verification.rs b/fastcrypto-tbls/src/dl_verification.rs index 1e3dd61dc..6995ccb3d 100644 --- a/fastcrypto-tbls/src/dl_verification.rs +++ b/fastcrypto-tbls/src/dl_verification.rs @@ -47,7 +47,7 @@ pub fn verify_poly_evals( poly: &Poly, rng: &mut R, ) -> FastCryptoResult<()> { - assert!(poly.degree() > 0); + assert!(poly.degree_bound() > 0); if evals.is_empty() { return Ok(()); } @@ -59,7 +59,7 @@ pub fn verify_poly_evals( .iter() .map(|e| G::ScalarType::from(e.index.get().into())) .collect::>(); - let coeffs = batch_coefficients(&rs, &evals_as_scalars, poly.degree()); + let coeffs = batch_coefficients(&rs, &evals_as_scalars, poly.degree_bound()); let rhs = G::multi_scalar_mul(&coeffs, poly.as_vec()).expect("sizes match"); if lhs != rhs { diff --git a/fastcrypto-tbls/src/lib.rs b/fastcrypto-tbls/src/lib.rs index 74960ddec..00451a4c4 100644 --- a/fastcrypto-tbls/src/lib.rs +++ b/fastcrypto-tbls/src/lib.rs @@ -20,7 +20,6 @@ pub mod nodes; pub mod polynomial; pub mod random_oracle; pub mod tbls; -#[cfg(any(test, feature = "experimental"))] pub mod threshold_schnorr; pub mod types; diff --git a/fastcrypto-tbls/src/polynomial.rs b/fastcrypto-tbls/src/polynomial.rs index d74798306..df69c64a5 100644 --- a/fastcrypto-tbls/src/polynomial.rs +++ b/fastcrypto-tbls/src/polynomial.rs @@ -5,7 +5,8 @@ // modified for our needs. // -use crate::types::{IndexedValue, ShareIndex}; +use crate::types; +use crate::types::{to_scalar, IndexedValue, ShareIndex}; use fastcrypto::error::{FastCryptoError, FastCryptoResult}; use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar}; use fastcrypto::traits::AllowedRng; @@ -13,7 +14,8 @@ use itertools::{Either, Itertools}; use serde::{Deserialize, Serialize}; use std::borrow::Borrow; use std::collections::HashSet; -use std::ops::{AddAssign, Mul}; +use std::mem::swap; +use std::ops::{Add, AddAssign, Mul, MulAssign, SubAssign}; /// Types @@ -29,13 +31,26 @@ pub type PublicPoly = Poly; /// Vector related operations. -impl Poly { - /// Returns the degree of the polynomial - pub fn degree(&self) -> usize { +impl Poly { + /// Returns an upper bound for the degree of the polynomial. + /// The returned number is equal to the size of the underlying coefficient vector - 1, + /// and in case some of the leading elements are zero, the actual degree will be smaller. + /// See also [Poly::degree]. + pub fn degree_bound(&self) -> usize { // e.g. c_0 + c_1 * x + c_2 * x^2 + c_3 * x^3 // ^ 4 coefficients correspond to a 3rd degree poly self.0.len() - 1 } + + /// Returns the degree of the polynomial. + pub fn degree(&self) -> usize { + self.0.iter().rposition(|&c| c != C::zero()).unwrap_or(0) + } + + /// Removes leading zero coefficients. + pub(crate) fn reduce(&mut self) { + self.0.truncate(self.degree() + 1); + } } impl From> for Poly { @@ -62,6 +77,43 @@ impl Mul<&C> for Poly { } } +impl Mul<&Poly> for &Poly { + type Output = Poly; + + fn mul(self, rhs: &Poly) -> Poly { + if self.is_zero() || rhs.is_zero() { + return Poly::zero(); + } + let mut result = vec![C::zero(); self.degree() + rhs.degree() + 1]; + for (i, a) in self.0.iter().enumerate() { + for (j, b) in rhs.0.iter().enumerate() { + result[i + j] += *a * *b; + } + } + Poly::from(result) + } +} + +impl Add<&Poly> for Poly { + type Output = Poly; + + fn add(mut self, rhs: &Poly) -> Poly { + self += rhs; + self + } +} + +impl SubAssign> for Poly { + fn sub_assign(&mut self, rhs: Poly) { + if self.0.len() < rhs.0.len() { + self.0.resize(rhs.0.len(), C::zero()); + } + for (a, b) in self.0.iter_mut().zip(&rhs.0) { + *a -= *b; + } + } +} + /// GroupElement operations. impl Poly { @@ -70,6 +122,14 @@ impl Poly { Self::from(vec![C::zero()]) } + pub(crate) fn is_zero(&self) -> bool { + self.0.iter().all(|&c| c == C::zero()) + } + + pub fn one() -> Self { + Self::from(vec![C::generator()]) + } + // TODO: Some of the functions/steps below may be executed many times in practice thus cache can be // used to improve efficiency (e.g., eval(i) may be called with the same index every time a partial // signature from party i is verified). @@ -77,7 +137,7 @@ impl Poly { /// Evaluates the polynomial at the specified value. pub fn eval(&self, i: ShareIndex) -> Eval { // Use Horner's Method to evaluate the polynomial. - let xi = C::ScalarType::from(i.get().into()); + let xi: C::ScalarType = to_scalar(i); let res = self .0 .iter() @@ -121,9 +181,8 @@ impl Poly { return Err(FastCryptoError::InvalidInput); } - let full_numerator = indices.iter().fold(C::ScalarType::generator(), |acc, i| { - acc * C::ScalarType::from(*i) - }); + let full_numerator = + C::ScalarType::product(indices.iter().map(|i| C::ScalarType::from(*i))); let mut coeffs = Vec::new(); for i in &indices { @@ -168,10 +227,7 @@ impl Poly { ) -> FastCryptoResult { let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares.clone())?; let plain_shares = shares.map(|s| s.borrow().value); - let res = coeffs - .iter() - .zip(plain_shares) - .fold(C::zero(), |acc, (c, s)| acc + (s * *c)); + let res = C::sum(coeffs.iter().zip(plain_shares).map(|(c, s)| s * c)); Ok(res) } @@ -196,7 +252,7 @@ impl Poly { panic!( "Index out of bounds: requested {}, but polynomial has degree {}", i, - self.degree() + self.degree_bound() ); } &self.0[i] @@ -206,6 +262,14 @@ impl Poly { pub fn as_vec(&self) -> &Vec { &self.0 } + + pub fn to_vec(self) -> Vec { + self.0 + } + + fn sum(terms: impl Iterator>) -> Poly { + terms.fold(Poly::zero(), |sum, term| sum + &term) + } } /// Scalar operations. @@ -253,12 +317,12 @@ impl Poly { if !points.iter().map(|p| p.index).all_unique() { return Err(FastCryptoError::InvalidInput); } - let x = C::from(index.get() as u128); + let x: C = to_scalar(index); // Convert indices to scalars for interpolation. let indices = points .iter() - .map(|p| C::from(p.index.get() as u128)) + .map(|p| to_scalar(p.index)) .collect::>(); let value = C::sum(indices.iter().enumerate().map(|(j, x_j)| { @@ -273,6 +337,110 @@ impl Poly { Ok(Eval { index, value }) } + + /// Given a set of shares with unique indices, compute the polynomial that + /// goes through all the points. The degree of the resulting polynomial is + /// at most `points.len() - 1`. + /// Returns an error if the input is invalid (e.g., empty or duplicate indices). + pub fn interpolate(points: &[Eval]) -> FastCryptoResult> { + if points.is_empty() || !points.iter().map(|p| p.index).all_unique() { + return Err(FastCryptoError::InvalidInput); + } + let x: Vec = points + .iter() + .map(|e| types::to_scalar(e.index)) + .collect_vec(); + + // Compute the full numerator polynomial: (x - x_1)(x - x_2)...(x - x_t) + let mut full_numerator = Poly::one(); + for x_i in &x { + full_numerator *= MonicLinear(-*x_i); + } + + Ok(Poly::sum(points.iter().enumerate().map(|(j, p_j)| { + let denominator = C::product( + x.iter() + .enumerate() + .filter(|(i, _)| *i != j) + .map(|(_, x_i)| x[j] - x_i), + ); + // Safe since (x - x[j]) divides full_numerator per definition + div_exact(&full_numerator, &MonicLinear(-x[j])) * &(p_j.value / denominator).unwrap() + }))) + } + + /// Returns the leading term of the polynomial. + /// If the polynomial is zero, returns a monomial with coefficient zero and degree zero. + fn lead(&self) -> Monomial { + if self.is_zero() { + return Monomial { + coefficient: C::zero(), + degree: 0, + }; + } + let degree = self.degree(); + Monomial { + coefficient: self.0[degree], + degree, + } + } + + /// Divide self by divisor, returning the quotient and remainder. + /// Returns an error if divisor is zero. + pub fn div_rem(&self, divisor: &Poly) -> FastCryptoResult<(Poly, Poly)> { + if divisor.is_zero() { + return Err(FastCryptoError::InvalidInput); + } + let mut remainder = self.clone(); + let mut quotient = Self::zero(); + + let lead_inverse = divisor.lead().coefficient.inverse()?; + + // Function to divide a term by the leading term of the divisor. + // This panics if the degree of the given term is less than that of the divisor. + let divider = |p: Monomial| Monomial { + coefficient: p.coefficient * lead_inverse, + degree: p.degree - divisor.degree(), + }; + + while !remainder.is_zero() && remainder.degree() >= divisor.degree() { + let tmp = divider(remainder.lead()); + quotient += &tmp; + remainder -= divisor * &tmp; + remainder.reduce(); + } + Ok((quotient, remainder)) + } + + /// Compute the extended GCD of two polynomials. + /// Returns (g, x, y, s, t) such that g = self * x + other * y. + /// The loop stops when the degree of g is less than degree_bound. + pub fn partial_extended_gcd( + &self, + other: &Poly, + degree_bound: usize, + ) -> FastCryptoResult<(Poly, Poly, Poly)> { + let mut r = (self.clone(), other.clone()); + let mut s = (Poly::one(), Poly::zero()); + let mut t = (Poly::zero(), Poly::one()); + + while r.0.degree() >= degree_bound && !r.1.is_zero() { + let (q, r_new) = r.0.div_rem(&r.1)?; + r = (r.1, r_new); + r.0.reduce(); + + t.0 -= &q * &t.1; + s.0 -= &q * &s.1; + + swap(&mut t.0, &mut t.1); + swap(&mut s.0, &mut s.1); + } + Ok((r.0, s.0, t.0)) + } + + pub fn extended_gcd(&self, other: &Poly) -> FastCryptoResult<(Poly, Poly, Poly)> { + self.partial_extended_gcd(other, 1) + } } impl Poly { @@ -288,3 +456,67 @@ impl Poly { Ok(res) } } + +/// This represents a monomial, e.g., 3 * x^2, where 3 is the coefficient and 2 is the degree. +struct Monomial { + coefficient: C, + degree: usize, +} + +impl AddAssign<&Monomial> for Poly { + fn add_assign(&mut self, rhs: &Monomial) { + if self.0.len() <= rhs.degree { + self.0.resize(rhs.degree + 1, C::zero()); + } + self.0[rhs.degree] += rhs.coefficient; + } +} + +impl Mul<&Monomial> for &Poly { + type Output = Poly; + + fn mul(self, rhs: &Monomial) -> Poly { + if rhs.coefficient == C::zero() { + return Poly::zero(); + } + let mut result = vec![C::zero(); self.degree_bound() + rhs.degree + 1]; + for (i, coefficient) in self.0.iter().enumerate() { + result[i + rhs.degree] = *coefficient * rhs.coefficient; + } + Poly::from(result) + } +} + +/// Represents a monic linear polynomial of the form x + c. +pub(crate) struct MonicLinear(pub C); + +impl MulAssign> for Poly { + fn mul_assign(&mut self, rhs: MonicLinear) { + if rhs.0 == C::zero() || self.is_zero() { + *self = Poly::zero(); + return; + } + self.0.push(*self.0.last().unwrap()); + for i in (1..self.0.len() - 1).rev() { + self.0[i] = self.0[i] * rhs.0 + self.0[i - 1]; + } + self.0[0] = self.0[0] * rhs.0; + } +} + +/// Assuming that `d` divides `n` exactly (or, that `d.0` is a root in `n`), return the quotient `n / d`. +fn div_exact(n: &Poly, d: &MonicLinear) -> Poly { + if n.is_zero() { + return Poly::zero(); + } + let mut result = n.0[1..].to_vec(); + for i in (0..result.len() - 1).rev() { + result[i] = result[i] - result[i + 1] * d.0; + } + Poly::from(result) +} + +#[cfg(test)] +pub(crate) fn poly_eq(a: &Poly, b: &Poly) -> bool { + a.0[..(a.degree() + 1)] == b.0[..(b.degree() + 1)] +} diff --git a/fastcrypto-tbls/src/tbls.rs b/fastcrypto-tbls/src/tbls.rs index c6d5ef37f..6236f9f9f 100644 --- a/fastcrypto-tbls/src/tbls.rs +++ b/fastcrypto-tbls/src/tbls.rs @@ -71,7 +71,7 @@ pub trait ThresholdBls { partial_sigs: impl Iterator>>, rng: &mut R, ) -> FastCryptoResult<()> { - assert!(vss_pk.degree() > 0 || !msg.is_empty()); + assert!(vss_pk.degree_bound() > 0 || !msg.is_empty()); let (evals_as_scalars, points): (Vec<_>, Vec<_>) = partial_sigs .map(|sig| { let sig = sig.borrow(); @@ -83,7 +83,7 @@ pub trait ThresholdBls { } let rs = get_random_scalars::(points.len(), rng); // TODO: should we cache it instead? that would replace t-wide msm with w-wide msm. - let coeffs = batch_coefficients(&rs, &evals_as_scalars, vss_pk.degree()); + let coeffs = batch_coefficients(&rs, &evals_as_scalars, vss_pk.degree_bound()); let pk = Self::Public::multi_scalar_mul(&coeffs, vss_pk.as_vec()).expect("sizes match"); let aggregated_sig = Self::Signature::multi_scalar_mul(&rs, &points).expect("sizes match"); diff --git a/fastcrypto-tbls/src/tests/dkg_v1_tests.rs b/fastcrypto-tbls/src/tests/dkg_v1_tests.rs index 600a1f22d..9f0ee29d0 100644 --- a/fastcrypto-tbls/src/tests/dkg_v1_tests.rs +++ b/fastcrypto-tbls/src/tests/dkg_v1_tests.rs @@ -621,7 +621,7 @@ fn create_message_generates_valid_message() { assert_eq!(msg.sender, 1); assert_eq!(msg.encrypted_shares.len(), 4); - assert_eq!(msg.vss_pk.degree(), 2); + assert_eq!(msg.vss_pk.degree_bound(), 2); } #[test] diff --git a/fastcrypto-tbls/src/tests/polynomial_tests.rs b/fastcrypto-tbls/src/tests/polynomial_tests.rs index 9a35c3bdc..d9d4272c7 100644 --- a/fastcrypto-tbls/src/tests/polynomial_tests.rs +++ b/fastcrypto-tbls/src/tests/polynomial_tests.rs @@ -11,6 +11,7 @@ use fastcrypto::groups::bls12381::{G1Element, G2Element, Scalar as BlsScalar}; use fastcrypto::groups::ristretto255::{RistrettoPoint, RistrettoScalar}; use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar}; use rand::prelude::*; +use std::iter; use std::num::NonZeroU16; const I10: NonZeroU16 = unsafe { NonZeroU16::new_unchecked(10) }; @@ -21,10 +22,10 @@ mod scalar_tests { use itertools::Itertools; #[test] - fn test_degree() { + fn test_degree_bound() { let s: usize = 5; let p = Poly::::rand(s as u16, &mut thread_rng()); - assert_eq!(p.degree(), s); + assert_eq!(p.degree_bound(), s); } #[test] @@ -112,6 +113,101 @@ mod scalar_tests { Poly::interpolate_at_index(ShareIndex::new(7).unwrap(), &shares).unwrap_err(); } + #[test] + fn test_interpolate() { + let degree = 12; + let threshold = degree + 1; + let poly = Poly::::rand(degree, &mut thread_rng()); + let mut shares = (1..50) + .map(|i| poly.eval(ShareIndex::new(i).unwrap())) + .collect::>(); + for _ in 0..10 { + shares.shuffle(&mut thread_rng()); + let used_shares = shares + .iter() + .take(threshold as usize) + .cloned() + .collect_vec(); + let interpolated = Poly::interpolate(&used_shares).unwrap(); + assert_eq!(interpolated, poly); + } + + // Using too few shares + for _ in 0..10 { + shares.shuffle(&mut thread_rng()); + let used_shares = shares + .iter() + .take(threshold as usize - 1) + .cloned() + .collect_vec(); + let interpolated = Poly::interpolate(&used_shares).unwrap(); + assert_ne!(interpolated, poly); + } + + // Using duplicate shares should fail + let mut shares = (1..=threshold) + .map(|i| poly.eval(ShareIndex::new(i).unwrap())) + .collect_vec(); // duplicate value 1 + shares.push(poly.eval(ShareIndex::new(1).unwrap())); + Poly::interpolate(&shares).unwrap_err(); + } + + #[test] + fn test_division() { + let mut rng = thread_rng(); + let degree_a = 8; + let degree_b = 5; + let a = crate::polynomial::Poly::from( + iter::from_fn(|| Some(S::rand(&mut rng))) + .take(degree_a + 1) + .collect_vec(), + ); + let b = crate::polynomial::Poly::from( + iter::from_fn(|| Some(S::rand(&mut rng))) + .take(degree_b + 1) + .collect_vec(), + ); + + let (q, r) = a.div_rem(&b).unwrap(); + assert!(r.degree() < b.degree()); + + let mut lhs = &q * &b; + lhs += &r; + assert!(poly_eq(&lhs, &a)); + } + + #[test] + fn test_extended_gcd() { + let mut rng = thread_rng(); + let degree_a = 8; + let degree_b = 5; + let a = crate::polynomial::Poly::from( + iter::from_fn(|| Some(S::rand(&mut rng))) + .take(degree_a + 1) + .collect_vec(), + ); + let b = crate::polynomial::Poly::from( + iter::from_fn(|| Some(S::rand(&mut rng))) + .take(degree_b + 1) + .collect_vec(), + ); + + let (g, x, y) = Poly::extended_gcd(&a, &b).unwrap(); + + assert!(poly_eq(&(&x * &a + &(&y * &b)), &g)); + } + + #[test] + fn test_degree() { + let coefficients = [1, 2, 3, 0, 0].iter().map(|&x| S::from(x)).collect_vec(); + let mut a = crate::polynomial::Poly::from(coefficients); + assert_eq!(a.degree(), 2); + assert_eq!(a.degree_bound(), 4); + a.reduce(); + assert_eq!(a.degree(), 2); + assert_eq!(a.degree_bound(), 2); + } + #[instantiate_tests()] mod ristretto_scalar {} @@ -141,7 +237,7 @@ mod points_tests { let one = G::ScalarType::generator(); let coeff = vec![one, one, one]; let p = Poly::::from(coeff); - assert_eq!(p.degree(), 2); + assert_eq!(p.degree_bound(), 2); let s1 = p.eval(NonZeroU16::new(10).unwrap()); let s2 = p.eval(NonZeroU16::new(20).unwrap()); let s3 = p.eval(NonZeroU16::new(30).unwrap()); @@ -173,7 +269,7 @@ mod points_tests { let one = G::generator(); let coeff = vec![one, one, one]; let p = Poly::::from(coeff); - assert_eq!(p.degree(), 2); + assert_eq!(p.degree_bound(), 2); let s1 = p.eval(NonZeroU16::new(10).unwrap()); let s2 = p.eval(NonZeroU16::new(20).unwrap()); let s3 = p.eval(NonZeroU16::new(30).unwrap()); diff --git a/fastcrypto-tbls/src/threshold_schnorr/gao.rs b/fastcrypto-tbls/src/threshold_schnorr/gao.rs new file mode 100644 index 000000000..1fca7ea1b --- /dev/null +++ b/fastcrypto-tbls/src/threshold_schnorr/gao.rs @@ -0,0 +1,129 @@ +// Copyright (c) 2022, Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use crate::polynomial::{Eval, MonicLinear, Poly}; +use crate::threshold_schnorr::S; +use crate::types::{to_scalar, ShareIndex}; +use fastcrypto::error::{FastCryptoError, FastCryptoResult}; +use itertools::Itertools; + +/// Decoder for Reed-Solomon codes. +/// This can correct up to (d-1)/2 errors, where d is the distance of the code. +/// The code is defined by the evaluation points `a` and the message length `k`. +/// The distance is given by `n - k + 1`, where `n` is the length of `a`. +/// +/// The implementation follows the Gao decoding algorithm (see https://www.math.clemson.edu/~sgao/papers/RS.pdf). +pub struct RSDecoder { + g0: Poly, + a: Vec, + k: usize, +} + +impl RSDecoder { + /// Create a new Gao decoder with the given evaluation points `a` and message length `k`. + pub fn new(a: Vec, k: usize) -> Self { + assert!(k < a.len(), "Message length must be less than block length"); + let mut g0 = Poly::one(); + for ai in &a { + g0 *= MonicLinear(-to_scalar::(ai)); + } + Self { g0, a, k } + } + + /// The length of the code words. + fn block_length(&self) -> usize { + self.a.len() + } + + /// The length of the messages. + fn message_length(&self) -> usize { + self.k + } + + /// The distance of the code. + fn distance(&self) -> usize { + self.block_length() - self.message_length() + 1 + } + + /// Compute the message polynomial. + /// Returns an error if the input length is wrong or if there are too many errors to correct. + pub fn compute_message_polynomial(&self, code_word: &[S]) -> FastCryptoResult> { + // The implementation follows Algorithm 1 in Gao's paper. + + if code_word.len() != self.block_length() { + return Err(FastCryptoError::InputLengthWrong(self.block_length())); + } + + // Step 1: Interpolation + let g1 = Poly::interpolate( + &self + .a + .iter() + .zip(code_word) + .map(|(&index, &value)| Eval { index, value }) + .collect_vec(), + )?; + + // Step 2: Partial GCD + let (g, _, v) = Poly::partial_extended_gcd( + &self.g0, + &g1, + (self.message_length() + self.block_length()) / 2, + )?; + + // Step 3: Long division + let (f1, r) = g.div_rem(&v)?; + if !r.is_zero() || f1.degree() >= self.k { + return Err(FastCryptoError::TooManyErrors((self.distance() - 1) / 2)); + } + Ok(f1) + } + + /// Encode the message using the Reed-Solomon code defined by the evaluation points `a`. + /// Returns an error if the message length is wrong. + pub fn encode(&self, message: Vec) -> FastCryptoResult> { + if message.len() != self.message_length() { + return Err(FastCryptoError::InputLengthWrong(self.message_length())); + } + let f = Poly::from(message); + Ok(self.a.iter().map(|&ai| f.eval(ai).value).collect_vec()) + } + + /// Try to correct the input and return the decoded message. + /// Returns an error if the input length is wrong or if there are too many errors to correct. + pub fn decode(&self, input: &[S]) -> FastCryptoResult> { + let mut f1 = self.compute_message_polynomial(input)?.to_vec(); + f1.truncate(self.k); + Ok(f1) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_gao_decoder() { + let a = (1..=7).map(|i| ShareIndex::new(i).unwrap()).collect_vec(); + let k = 3; + let decoder = RSDecoder::new(a.clone(), k); + + let message = vec![S::from(11u128), S::from(22u128), S::from(33u128)]; + let code_word = decoder.encode(message.clone()).unwrap(); + + // Introduce errors + let mut received = code_word.clone(); + received[4] = S::from(20u128); // Error at position 4 + received[2] = S::from(200u128); // Error at position 2 + + let decoded_message = decoder.decode(&received).unwrap(); + assert_eq!(decoded_message, message); + + // Test with too many errors + let mut received = code_word.clone(); + received[4] = S::from(20u128); // Error at position 4 + received[3] = S::from(2000u128); // Error at position 3 + received[2] = S::from(200u128); // Error at position 2 + assert!(decoder.decode(&received).is_err()); + } +} diff --git a/fastcrypto-tbls/src/threshold_schnorr/mod.rs b/fastcrypto-tbls/src/threshold_schnorr/mod.rs index 3e4e1a76c..f48161b16 100644 --- a/fastcrypto-tbls/src/threshold_schnorr/mod.rs +++ b/fastcrypto-tbls/src/threshold_schnorr/mod.rs @@ -33,6 +33,7 @@ pub mod avss; pub mod batch_avss; mod bcs; pub mod complaint; +pub mod gao; pub mod key_derivation; mod pascal_matrix; pub mod presigning; diff --git a/fastcrypto-tbls/src/types.rs b/fastcrypto-tbls/src/types.rs index 7c74b5c4a..b5cd42bc4 100644 --- a/fastcrypto-tbls/src/types.rs +++ b/fastcrypto-tbls/src/types.rs @@ -4,8 +4,9 @@ use crate::polynomial::{Eval, PublicPoly}; use crate::tbls; use fastcrypto::error::{FastCryptoError, FastCryptoResult}; -use fastcrypto::groups::{bls12381, GroupElement, HashToGroupElement, Pairing}; +use fastcrypto::groups::{bls12381, GroupElement, HashToGroupElement, Pairing, Scalar}; use serde::{Deserialize, Serialize}; +use std::borrow::Borrow; use std::num::NonZeroU16; /// Implementation of [ThresholdBls] for BLS12-381-min-sig. A variant for BLS12-381-min-pk can be @@ -83,3 +84,8 @@ impl UnindexedValues { Ok(values) } } + +#[inline] +pub(crate) fn to_scalar(index: impl Borrow) -> C { + C::from(index.borrow().get() as u128) +} diff --git a/fastcrypto/src/error.rs b/fastcrypto/src/error.rs index 41e0e031e..54d04f6d4 100644 --- a/fastcrypto/src/error.rs +++ b/fastcrypto/src/error.rs @@ -56,6 +56,10 @@ pub enum FastCryptoError { #[error("Out of presigs in the iterator, please create new presigs")] OutOfPresigs, + /// Error in error decoding because there are too many errors to correct. + #[error("Too many errors to correct in error decoding. Up to {0} errors can be corrected.")] + TooManyErrors(usize), + /// General cryptographic error. #[error("General cryptographic error: {0}")] GeneralError(String),