Skip to content

Commit

Permalink
refactor: add byte array type for deserialization (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobkaufmann authored Mar 13, 2024
1 parent 148b4e7 commit 67509da
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 30 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ version = "0.1.0"
edition = "2021"

[dependencies]
alloy-primitives = { version = "0.4.2", features = ["std", "serde"] }
blst = "0.3.11"
criterion = "0.5.1"
hex = "0.4.3"
rand = { version = "0.8.5", optional = true }
serde = { version = "1.0.189", features = ["derive"] }
serde_json = "1.0.107"
Expand Down
52 changes: 52 additions & 0 deletions src/bytes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use serde::{de::Visitor, Deserialize};

#[derive(Clone, Debug)]
pub struct Bytes(Vec<u8>);

impl AsRef<[u8]> for Bytes {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}

pub struct BytesVisitor;

impl<'de> Visitor<'de> for BytesVisitor {
type Value = Vec<u8>;

fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
formatter.write_str(
"a variable-length byte array represented by a raw byte array or a hex-encoded string",
)
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
let v = v.strip_prefix("0x").unwrap_or(v);
let v = hex::decode(v).map_err(E::custom)?;
Ok(v)
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
let v = hex::decode(v).map_err(E::custom)?;
Ok(v)
}
}

impl<'de> Deserialize<'de> for Bytes {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
if deserializer.is_human_readable() {
deserializer.deserialize_str(BytesVisitor).map(Bytes)
} else {
deserializer.deserialize_bytes(BytesVisitor).map(Bytes)
}
}
}
33 changes: 9 additions & 24 deletions src/kzg/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@ use std::{
use super::{Bytes32, Bytes48, Commitment, Error, Polynomial, Proof};
use crate::{
blob::{Blob, Error as BlobError},
bls::{self, Decompress, ECGroupError, Error as BlsError, Fr, P1, P2},
bls::{self, Decompress, Error as BlsError, Fr, P1, P2},
bytes::Bytes,
math,
};

use alloy_primitives::{hex, Bytes, FixedBytes};
use serde::Deserialize;

#[derive(Debug)]
pub enum LoadSetupError {
Bls(BlsError),
Io(io::Error),
Hex(hex::FromHexError),
Serde(serde_json::Error),
InvalidLenG1Lagrange,
InvalidLenG2Monomial,
}

#[derive(serde::Deserialize, serde::Serialize)]
#[derive(Deserialize)]
struct SetupUnchecked {
g1_lagrange: Vec<Bytes>,
g2_monomial: Vec<Bytes>,
Expand Down Expand Up @@ -52,13 +52,6 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {

let mut g1_lagrange: Box<[P1; G1]> = Box::new([P1::default(); G1]);
for (i, point) in setup.g1_lagrange.iter().enumerate() {
if point.len() != 48 {
return Err(LoadSetupError::Bls(BlsError::from(
ECGroupError::InvalidEncoding,
)));
}
// TODO: skip unnecessary allocation
let point = FixedBytes::<48>::from_slice(point);
let point =
P1::decompress(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
g1_lagrange[i] = point;
Expand All @@ -67,13 +60,6 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {

let mut g2_monomial: Box<[P2; G2]> = Box::new([P2::default(); G2]);
for (i, point) in setup.g2_monomial.iter().enumerate() {
if point.len() != 96 {
return Err(LoadSetupError::Bls(BlsError::from(
ECGroupError::InvalidEncoding,
)));
}
// TODO: skip unnecessary allocation
let point = FixedBytes::<96>::from_slice(point);
let point =
P2::decompress(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
g2_monomial[i] = point;
Expand Down Expand Up @@ -287,12 +273,11 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
mod tests {
use super::*;

use crate::{
bls::Compress,
kzg::spec::{
BlobToCommitment, ComputeBlobProof, ComputeProof, VerifyBlobProof,
VerifyBlobProofBatch, VerifyProof,
},
use crate::bls::Compress;

use crate::kzg::spec::{
BlobToCommitment, ComputeBlobProof, ComputeProof, VerifyBlobProof, VerifyBlobProofBatch,
VerifyProof,
};

use std::{
Expand Down
12 changes: 7 additions & 5 deletions src/kzg/spec.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
use alloy_primitives::{Bytes, FixedBytes};
use serde::Deserialize;

use crate::bls::{Fr, P1};
use crate::{
bls::{Fr, P1},
bytes::Bytes,
};

use super::{Bytes32, Bytes48};

fn bytes32_from_bytes(bytes: &Bytes) -> Option<Bytes32> {
let bytes = FixedBytes::<{ Fr::BYTES }>::try_from(bytes.as_ref()).ok();
let bytes: Option<[u8; Fr::BYTES]> = TryFrom::try_from(bytes.as_ref()).ok();
bytes.map(Into::<Bytes32>::into)
}

fn bytes48_from_bytes(bytes: &Bytes) -> Option<Bytes48> {
let bytes = FixedBytes::<{ P1::BYTES }>::try_from(bytes.as_ref()).ok()?;
Some(bytes.into())
let bytes: Option<[u8; P1::BYTES]> = TryFrom::try_from(bytes.as_ref()).ok();
bytes.map(Into::<Bytes48>::into)
}

#[derive(Deserialize)]
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod bls;
mod bytes;
mod math;

pub use bls::{Compress, Decompress};
Expand Down

0 comments on commit 67509da

Please sign in to comment.