Skip to content

Commit

Permalink
feat: add KZG methods that accept raw bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobkaufmann committed Feb 17, 2024
1 parent 17dd1fd commit eb80569
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/bls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,13 +355,13 @@ impl P1 {
pub const BITS: usize = 384;
pub const BYTES: usize = Self::BITS / 8;

pub fn deserialize(bytes: impl AsRef<[u8; Self::BYTES]>) -> Result<Self, ECGroupError> {
pub fn deserialize(bytes: &[u8; Self::BYTES]) -> Result<Self, ECGroupError> {
let mut affine = MaybeUninit::<blst_p1_affine>::uninit();
let mut out = MaybeUninit::<blst_p1>::uninit();
unsafe {
// NOTE: deserialize performs a curve check but not a subgroup check. if that changes,
// then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests.
match blst_p1_deserialize(affine.as_mut_ptr(), bytes.as_ref().as_ptr()) {
match blst_p1_deserialize(affine.as_mut_ptr(), bytes.as_ptr()) {
BLST_ERROR::BLST_SUCCESS => {}
BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding),
BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve),
Expand Down Expand Up @@ -515,7 +515,7 @@ impl P2 {
pub const BITS: usize = 768;
pub const BYTES: usize = Self::BITS / 8;

pub fn deserialize(bytes: impl AsRef<[u8; Self::BYTES]>) -> Result<Self, ECGroupError> {
pub fn deserialize(bytes: &[u8; Self::BYTES]) -> Result<Self, ECGroupError> {
let mut affine = MaybeUninit::<blst_p2_affine>::uninit();
let mut out = MaybeUninit::<blst_p2>::uninit();
unsafe {
Expand Down
15 changes: 14 additions & 1 deletion src/kzg/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::bls;
use crate::{blob, bls};

mod poly;
mod setup;
Expand All @@ -7,9 +7,22 @@ pub type Proof = bls::P1;
pub type Commitment = bls::P1;

pub enum Error {
Blob(blob::Error),
Bls(bls::Error),
}

impl From<blob::Error> for Error {
fn from(value: blob::Error) -> Self {
Self::Blob(value)
}
}

impl From<bls::Error> for Error {
fn from(value: bls::Error) -> Self {
Self::Bls(value)
}
}

pub(crate) use poly::Polynomial;

pub use setup::Setup;
129 changes: 115 additions & 14 deletions src/kzg/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use std::{
path::Path,
};

use super::{Commitment, Polynomial, Proof};
use super::{Commitment, Error, Polynomial, Proof};
use crate::{
blob::Blob,
blob::{Blob, Error as BlobError},
bls::{self, ECGroupError, Error as BlsError, Fr, P1, P2},
math,
};
Expand Down Expand Up @@ -60,7 +60,7 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
// TODO: skip unnecessary allocation
let point = FixedBytes::<48>::from_slice(point);
let point =
P1::deserialize(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
P1::deserialize(&point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
g1_lagrange[i] = point;
}
let g1_lagrange_brp = math::bit_reversal_permutation_boxed_array(g1_lagrange.as_slice());
Expand All @@ -75,7 +75,7 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
// TODO: skip unnecessary allocation
let point = FixedBytes::<96>::from_slice(point);
let point =
P2::deserialize(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
P2::deserialize(&point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
g2_monomial[i] = point;
}

Expand All @@ -101,6 +101,25 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
bls::verify_pairings(pairing1, pairing2)
}

pub fn verify_proof_raw(
&self,
proof: &[u8; Proof::BYTES],
commitment: &[u8; Commitment::BYTES],
point: &[u8; Fr::BYTES],
eval: &[u8; Fr::BYTES],
) -> Result<bool, Error> {
let proof = Proof::deserialize(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let commitment = Commitment::deserialize(commitment)
.map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let point =
Fr::from_be_slice(point).map_err(|err| Error::from(BlsError::FiniteField(err)))?;
let eval =
Fr::from_be_slice(eval).map_err(|err| Error::from(BlsError::FiniteField(err)))?;

let verified = self.verify_proof(&proof, &commitment, &point, &eval);
Ok(verified)
}

pub fn verify_proof_batch(
&self,
proofs: impl AsRef<[Proof]>,
Expand Down Expand Up @@ -153,10 +172,43 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
blob.commitment(self)
}

pub fn blob_to_commitment_raw(&self, blob: impl AsRef<[u8]>) -> Result<Commitment, BlobError> {
let blob = Blob::<G1>::from_slice(blob)?;
let commitment = self.blob_to_commitment(&blob);
Ok(commitment)
}

pub fn blob_proof(&self, blob: &Blob<G1>, commitment: &Commitment) -> Proof {
blob.proof(commitment, self)
}

pub fn blob_proof_raw(
&self,
blob: impl AsRef<[u8]>,
commitment: &[u8; Commitment::BYTES],
) -> Result<Proof, Error> {
let blob: Blob<G1> = Blob::from_slice(blob).map_err(Error::from)?;
let commitment = Commitment::deserialize(commitment)
.map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let proof = blob.proof(&commitment, self);
Ok(proof)
}

pub fn proof_raw(
&self,
blob: impl AsRef<[u8]>,
point: &[u8; Fr::BYTES],
) -> Result<Proof, Error> {
let blob: Blob<G1> = Blob::from_slice(blob).map_err(Error::from)?;
let point =
Fr::from_be_slice(point).map_err(|err| Error::from(BlsError::FiniteField(err)))?;

let poly = Polynomial(&blob.elements);
let (_eval, proof) = poly.prove(point, self);

Ok(proof)
}

pub fn verify_blob_proof(
&self,
blob: &Blob<G1>,
Expand All @@ -169,6 +221,21 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
self.verify_proof(proof, commitment, &challenge, &eval)
}

pub fn verify_blob_proof_raw(
&self,
blob: impl AsRef<[u8]>,
commitment: &[u8; Commitment::BYTES],
proof: &[u8; Proof::BYTES],
) -> Result<bool, Error> {
let blob: Blob<G1> = Blob::from_slice(blob).map_err(Error::from)?;
let commitment = Commitment::deserialize(commitment)
.map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let proof = Proof::deserialize(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;

let verified = self.verify_blob_proof(&blob, &commitment, &proof);
Ok(verified)
}

pub fn verify_blob_proof_batch(
&self,
blobs: impl AsRef<[Blob<G1>]>,
Expand All @@ -192,6 +259,40 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {

self.verify_proof_batch(proofs, commitments, challenges, evaluations)
}

pub fn verify_blob_proof_batch_raw<'a, 'b, B>(
&self,
blobs: impl AsRef<[B]>,
commitments: impl AsRef<[&'a [u8; Commitment::BYTES]]>,
proofs: impl AsRef<[&'b [u8; Proof::BYTES]]>,
) -> Result<bool, Error>
where
B: AsRef<[u8]>,
{
assert_eq!(blobs.as_ref().len(), commitments.as_ref().len());
assert_eq!(commitments.as_ref().len(), proofs.as_ref().len());

let blobs: Result<Vec<Blob<G1>>, _> =
blobs.as_ref().iter().map(Blob::<G1>::from_slice).collect();
let blobs = blobs.map_err(Error::from)?;

let commitments: Result<Vec<Commitment>, _> = commitments
.as_ref()
.iter()
.map(|comm| Commitment::deserialize(comm))
.collect();
let commitments = commitments.map_err(|err| Error::from(BlsError::ECGroup(err)))?;

let proofs: Result<Vec<Proof>, _> = proofs
.as_ref()
.iter()
.map(|proof| Proof::deserialize(proof))
.collect();
let proofs = proofs.map_err(|err| Error::from(BlsError::ECGroup(err)))?;

let verified = self.verify_blob_proof_batch(blobs, commitments, proofs);
Ok(verified)
}
}

#[cfg(test)]
Expand Down Expand Up @@ -287,7 +388,7 @@ mod tests {
return Err(());
}
let commitment = FixedBytes::<{ Commitment::BYTES }>::from_slice(&unchecked.commitment);
let commitment = Commitment::deserialize(commitment).map_err(|_| ())?;
let commitment = Commitment::deserialize(&commitment).map_err(|_| ())?;
Ok(Self { blob, commitment })
}
}
Expand Down Expand Up @@ -319,7 +420,7 @@ mod tests {
return Err(());
}
let commitment = FixedBytes::<{ Commitment::BYTES }>::from_slice(&unchecked.commitment);
let commitment = Commitment::deserialize(commitment).map_err(|_| ())?;
let commitment = Commitment::deserialize(&commitment).map_err(|_| ())?;

let z = Fr::from_be_slice(unchecked.z).map_err(|_| ())?;
let y = Fr::from_be_slice(unchecked.y).map_err(|_| ())?;
Expand All @@ -328,7 +429,7 @@ mod tests {
return Err(());
}
let proof = FixedBytes::<{ Proof::BYTES }>::from_slice(&unchecked.proof);
let proof = Proof::deserialize(proof).map_err(|_| ())?;
let proof = Proof::deserialize(&proof).map_err(|_| ())?;

Ok(Self {
commitment,
Expand Down Expand Up @@ -365,12 +466,12 @@ mod tests {
return Err(());
}
let commitment = FixedBytes::<{ Commitment::BYTES }>::from_slice(&unchecked.commitment);
let commitment = Commitment::deserialize(commitment).map_err(|_| ())?;
let commitment = Commitment::deserialize(&commitment).map_err(|_| ())?;
if unchecked.proof.len() != Proof::BYTES {
return Err(());
}
let proof = FixedBytes::<{ Proof::BYTES }>::from_slice(&unchecked.proof);
let proof = Proof::deserialize(proof).map_err(|_| ())?;
let proof = Proof::deserialize(&proof).map_err(|_| ())?;
Ok(Self {
blob,
commitment,
Expand Down Expand Up @@ -420,7 +521,7 @@ mod tests {
return Err(());
}
let commitment = FixedBytes::<{ Commitment::BYTES }>::from_slice(&commitment);
let commitment = Commitment::deserialize(commitment).map_err(|_| ())?;
let commitment = Commitment::deserialize(&commitment).map_err(|_| ())?;
commitments.push(commitment);
}

Expand All @@ -430,7 +531,7 @@ mod tests {
return Err(());
}
let proof = FixedBytes::<{ Proof::BYTES }>::from_slice(&proof);
let proof = Proof::deserialize(proof).map_err(|_| ())?;
let proof = Proof::deserialize(&proof).map_err(|_| ())?;
proofs.push(proof);
}

Expand Down Expand Up @@ -478,7 +579,7 @@ mod tests {
Ok(input) => {
let (proof, eval) = case.output.unwrap();
let expected_eval = Fr::from_be_bytes(eval).unwrap();
let expected_proof = P1::deserialize(proof).unwrap();
let expected_proof = P1::deserialize(&proof).unwrap();

let poly = Polynomial(&input.blob.elements);
let eval = poly.evaluate(input.z, &setup);
Expand Down Expand Up @@ -508,7 +609,7 @@ mod tests {
match ComputeBlobKzgProofInput::from_unchecked(case.input) {
Ok(input) => {
let proof = case.output.unwrap();
let proof = P1::deserialize(proof).unwrap();
let proof = P1::deserialize(&proof).unwrap();
let expected_proof = Proof::from(proof);

let proof = setup.blob_proof(&input.blob, &input.commitment);
Expand Down Expand Up @@ -537,7 +638,7 @@ mod tests {
match BlobToCommitmentInput::from_unchecked(case.input) {
Ok(input) => {
let comm = case.output.unwrap();
let comm = P1::deserialize(comm).unwrap();
let comm = P1::deserialize(&comm).unwrap();
let expected_comm = Commitment::from(comm);

let comm = input.blob.commitment(&setup);
Expand Down

0 comments on commit eb80569

Please sign in to comment.