Skip to content

hash2curve: move oversized DST requirements to runtime errors #1901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions elliptic-curve/src/hash2curve/group_digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ pub trait GroupDigest: MapToCurve {
/// > hash function is used.
///
/// # Errors
/// See implementors of [`ExpandMsg`] for errors:
/// - [`ExpandMsgXmd`]
/// - [`ExpandMsgXof`]
/// - `len_in_bytes > u16::MAX`
/// - See implementors of [`ExpandMsg`] for additional errors:
/// - [`ExpandMsgXmd`]
/// - [`ExpandMsgXof`]
///
/// `len_in_bytes = <Self::FieldElement as FromOkm>::Length * 2`
///
Expand Down Expand Up @@ -53,9 +54,10 @@ pub trait GroupDigest: MapToCurve {
/// > points in this set are more likely to be output than others.
///
/// # Errors
/// See implementors of [`ExpandMsg`] for errors:
/// - [`ExpandMsgXmd`]
/// - [`ExpandMsgXof`]
/// - `len_in_bytes > u16::MAX`
/// - See implementors of [`ExpandMsg`] for additional errors:
/// - [`ExpandMsgXmd`]
/// - [`ExpandMsgXof`]
///
/// `len_in_bytes = <Self::FieldElement as FromOkm>::Length`
///
Expand All @@ -76,9 +78,10 @@ pub trait GroupDigest: MapToCurve {
/// and returns a scalar.
///
/// # Errors
/// See implementors of [`ExpandMsg`] for errors:
/// - [`ExpandMsgXmd`]
/// - [`ExpandMsgXof`]
/// - `len_in_bytes > u16::MAX`
/// - See implementors of [`ExpandMsg`] for additional errors:
/// - [`ExpandMsgXmd`]
/// - [`ExpandMsgXof`]
///
/// `len_in_bytes = <Self::Scalar as FromOkm>::Length`
///
Expand Down
14 changes: 8 additions & 6 deletions elliptic-curve/src/hash2curve/hash2field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

mod expand_msg;

use core::num::NonZeroUsize;
use core::num::NonZeroU16;

pub use expand_msg::{xmd::*, xof::*, *};

Expand All @@ -28,9 +28,10 @@ pub trait FromOkm {
/// <https://www.rfc-editor.org/rfc/rfc9380.html#name-hash_to_field-implementatio>
///
/// # Errors
/// See implementors of [`ExpandMsg`] for errors:
/// - [`ExpandMsgXmd`]
/// - [`ExpandMsgXof`]
/// - `len_in_bytes > u16::MAX`
/// - See implementors of [`ExpandMsg`] for additional errors:
/// - [`ExpandMsgXmd`]
/// - [`ExpandMsgXof`]
///
/// `len_in_bytes = T::Length * out.len()`
///
Expand All @@ -42,9 +43,10 @@ where
E: ExpandMsg<K>,
T: FromOkm + Default,
{
let len_in_bytes = T::Length::to_usize()
let len_in_bytes = T::Length::USIZE
.checked_mul(out.len())
.and_then(NonZeroUsize::new)
.and_then(|len| len.try_into().ok())
.and_then(NonZeroU16::new)
.ok_or(Error)?;
let mut tmp = Array::<u8, <T as FromOkm>::Length>::default();
let mut expander = E::expand_message(data, domain, len_in_bytes)?;
Expand Down
25 changes: 13 additions & 12 deletions elliptic-curve/src/hash2curve/hash2field/expand_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use core::num::NonZero;

use crate::{Error, Result};
use digest::{Digest, ExtendableOutput, Update, XofReader};
use hybrid_array::typenum::{IsLess, True, U256};
use hybrid_array::{Array, ArraySize};

/// Salt when the DST is too long
Expand All @@ -34,7 +33,7 @@ pub trait ExpandMsg<K> {
fn expand_message<'dst>(
msg: &[&[u8]],
dst: &'dst [&[u8]],
len_in_bytes: NonZero<usize>,
len_in_bytes: NonZero<u16>,
) -> Result<Self::Expander<'dst>>;
}

Expand All @@ -50,20 +49,14 @@ pub trait Expander {
///
/// [dst]: https://www.rfc-editor.org/rfc/rfc9380.html#name-using-dsts-longer-than-255-
#[derive(Debug)]
pub(crate) enum Domain<'a, L>
where
L: ArraySize + IsLess<U256, Output = True>,
{
pub(crate) enum Domain<'a, L: ArraySize> {
/// > 255
Hashed(Array<u8, L>),
/// <= 255
Array(&'a [&'a [u8]]),
}

impl<'a, L> Domain<'a, L>
where
L: ArraySize + IsLess<U256, Output = True>,
{
impl<'a, L: ArraySize> Domain<'a, L> {
pub fn xof<X>(dst: &'a [&'a [u8]]) -> Result<Self>
where
X: Default + ExtendableOutput + Update,
Expand All @@ -72,6 +65,10 @@ where
if dst.iter().map(|slice| slice.len()).sum::<usize>() == 0 {
Err(Error)
} else if dst.iter().map(|slice| slice.len()).sum::<usize>() > MAX_DST_LEN {
if L::USIZE > u8::MAX.into() {
return Err(Error);
}

let mut data = Array::<u8, L>::default();
let mut hash = X::default();
hash.update(OVERSIZE_DST_SALT);
Expand All @@ -96,6 +93,10 @@ where
if dst.iter().map(|slice| slice.len()).sum::<usize>() == 0 {
Err(Error)
} else if dst.iter().map(|slice| slice.len()).sum::<usize>() > MAX_DST_LEN {
if L::USIZE > u8::MAX.into() {
return Err(Error);
}

Ok(Self::Hashed({
let mut hash = X::new();
hash.update(OVERSIZE_DST_SALT);
Expand Down Expand Up @@ -124,8 +125,8 @@ where

pub fn len(&self) -> u8 {
match self {
// Can't overflow because it's enforced on a type level.
Self::Hashed(_) => L::to_u8(),
// Can't overflow because it's checked on creation.
Self::Hashed(_) => L::U8,
// Can't overflow because it's checked on creation.
Self::Array(d) => {
u8::try_from(d.iter().map(|d| d.len()).sum::<usize>()).expect("length overflow")
Expand Down
41 changes: 15 additions & 26 deletions elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use digest::{
FixedOutput, HashMarker,
array::{
Array,
typenum::{IsGreaterOrEqual, IsLess, IsLessOrEqual, Prod, True, U2, U256, Unsigned},
typenum::{IsGreaterOrEqual, IsLessOrEqual, Prod, True, U2, Unsigned},
},
block_api::BlockSizeUser,
};
Expand All @@ -18,22 +18,17 @@ use digest::{
///
/// # Errors
/// - `dst` contains no bytes
/// - `len_in_bytes > u16::MAX`
/// - `dst > 255 && HashT::OutputSize > 255`
/// - `len_in_bytes > 255 * HashT::OutputSize`
#[derive(Debug)]
pub struct ExpandMsgXmd<HashT>(PhantomData<HashT>)
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256, Output = True>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>;

impl<HashT, K> ExpandMsg<K> for ExpandMsgXmd<HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
// If DST is larger than 255 bytes, the length of the computed DST will depend on the output
// size of the hash, which is still not allowed to be larger than 255.
// https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-6
HashT::OutputSize: IsLess<U256, Output = True>,
// The number of bits output by `HashT` MUST be at most `HashT::BlockSize`.
// https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-4
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
Expand All @@ -47,17 +42,17 @@ where
fn expand_message<'dst>(
msg: &[&[u8]],
dst: &'dst [&[u8]],
len_in_bytes: NonZero<usize>,
len_in_bytes: NonZero<u16>,
) -> Result<Self::Expander<'dst>> {
let len_in_bytes_u16 = u16::try_from(len_in_bytes.get()).map_err(|_| Error)?;
let b_in_bytes = HashT::OutputSize::USIZE;

// `255 * <b_in_bytes>` can not exceed `u16::MAX`
if len_in_bytes_u16 > 255 * HashT::OutputSize::to_u16() {
if usize::from(len_in_bytes.get()) > 255 * b_in_bytes {
return Err(Error);
}

let b_in_bytes = HashT::OutputSize::to_usize();
let ell = u8::try_from(len_in_bytes.get().div_ceil(b_in_bytes)).map_err(|_| Error)?;
let ell = u8::try_from(usize::from(len_in_bytes.get()).div_ceil(b_in_bytes))
.expect("should never pass the previous check");

let domain = Domain::xmd::<HashT>(dst)?;
let mut b_0 = HashT::default();
Expand All @@ -67,7 +62,7 @@ where
b_0.update(msg);
}

b_0.update(&len_in_bytes_u16.to_be_bytes());
b_0.update(&len_in_bytes.get().to_be_bytes());
b_0.update(&[0]);
domain.update_hash(&mut b_0);
b_0.update(&[domain.len()]);
Expand Down Expand Up @@ -96,7 +91,6 @@ where
pub struct ExpanderXmd<'a, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256, Output = True>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
{
b_0: Array<u8, HashT::OutputSize>,
Expand All @@ -110,7 +104,6 @@ where
impl<HashT> ExpanderXmd<'_, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256, Output = True>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
{
fn next(&mut self) -> bool {
Expand Down Expand Up @@ -140,7 +133,6 @@ where
impl<HashT> Expander for ExpanderXmd<'_, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256, Output = True>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
{
fn fill_bytes(&mut self, okm: &mut [u8]) {
Expand All @@ -157,11 +149,10 @@ where
#[cfg(test)]
mod test {
use super::*;
use core::mem::size_of;
use hex_literal::hex;
use hybrid_array::{
ArraySize,
typenum::{U4, U8, U32, U128},
typenum::{IsLess, U4, U8, U32, U128, U65536},
};
use sha2::Sha256;

Expand All @@ -172,9 +163,8 @@ mod test {
bytes: &[u8],
) where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256, Output = True>,
{
let block = HashT::BlockSize::to_usize();
let block = HashT::BlockSize::USIZE;
assert_eq!(
Array::<u8, HashT::BlockSize>::default().as_slice(),
&bytes[..block]
Expand Down Expand Up @@ -206,25 +196,24 @@ mod test {

impl TestVector {
#[allow(clippy::panic_in_result_fn)]
fn assert<HashT, L: ArraySize>(
fn assert<HashT, L>(
&self,
dst: &'static [u8],
domain: &Domain<'_, HashT::OutputSize>,
) -> Result<()>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256, Output = True>
+ IsLessOrEqual<HashT::BlockSize, Output = True>
+ Mul<U8>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
HashT::OutputSize: IsGreaterOrEqual<U8, Output = True>,
L: ArraySize + IsLess<U65536, Output = True>,
{
assert_message::<HashT>(self.msg, domain, L::to_u16(), self.msg_prime);
assert_message::<HashT>(self.msg, domain, L::U16, self.msg_prime);

let dst = [dst];
let mut expander = <ExpandMsgXmd<HashT> as ExpandMsg<U4>>::expand_message(
&[self.msg],
&dst,
NonZero::new(L::to_usize()).ok_or(Error)?,
NonZero::new(L::U16).ok_or(Error)?,
)?;

let mut uniform_bytes = Array::<u8, L>::default();
Expand Down
20 changes: 11 additions & 9 deletions elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
//! `expand_message_xof` for the `ExpandMsg` trait

use super::{Domain, ExpandMsg, Expander};
use crate::{Error, Result};
use crate::Result;
use core::{fmt, num::NonZero, ops::Mul};
use digest::{
CollisionResistance, ExtendableOutput, HashMarker, Update, XofReader, typenum::IsGreaterOrEqual,
};
use hybrid_array::{
ArraySize,
typenum::{IsLess, Prod, True, U2, U256},
typenum::{Prod, True, U2},
};

/// Implements `expand_message_xof` via the [`ExpandMsg`] trait:
/// <https://www.rfc-editor.org/rfc/rfc9380.html#name-expand_message_xof>
///
/// # Errors
/// - `dst` contains no bytes
/// - `len_in_bytes > u16::MAX`
/// - `dst > 255 && K * 2 > 255`
pub struct ExpandMsgXof<HashT>
where
HashT: Default + ExtendableOutput + Update + HashMarker,
Expand All @@ -41,7 +41,7 @@ where
HashT: Default + ExtendableOutput + Update + HashMarker,
// If DST is larger than 255 bytes, the length of the computed DST is calculated by `K * 2`.
// https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-2.1
K: Mul<U2, Output: ArraySize + IsLess<U256, Output = True>>,
K: Mul<U2, Output: ArraySize>,
// The collision resistance of `HashT` MUST be at least `K` bits.
// https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.2-2.1
HashT: CollisionResistance<CollisionResistance: IsGreaterOrEqual<K, Output = True>>,
Expand All @@ -51,9 +51,9 @@ where
fn expand_message<'dst>(
msg: &[&[u8]],
dst: &'dst [&[u8]],
len_in_bytes: NonZero<usize>,
len_in_bytes: NonZero<u16>,
) -> Result<Self::Expander<'dst>> {
let len_in_bytes = u16::try_from(len_in_bytes.get()).map_err(|_| Error)?;
let len_in_bytes = len_in_bytes.get();

let domain = Domain::<Prod<K, U2>>::xof::<HashT>(dst)?;
let mut reader = HashT::default();
Expand Down Expand Up @@ -81,12 +81,14 @@ where

#[cfg(test)]
mod test {
use crate::Error;

use super::*;
use core::mem::size_of;
use hex_literal::hex;
use hybrid_array::{
Array, ArraySize,
typenum::{U16, U32, U128},
typenum::{IsLess, U16, U32, U128, U65536},
};
use sha3::Shake128;

Expand Down Expand Up @@ -124,14 +126,14 @@ mod test {
+ Update
+ HashMarker
+ CollisionResistance<CollisionResistance: IsGreaterOrEqual<U16, Output = True>>,
L: ArraySize,
L: ArraySize + IsLess<U65536, Output = True>,
{
assert_message(self.msg, domain, L::to_u16(), self.msg_prime);

let mut expander = <ExpandMsgXof<HashT> as ExpandMsg<U16>>::expand_message(
&[self.msg],
&[dst],
NonZero::new(L::to_usize()).ok_or(Error)?,
NonZero::new(L::U16).ok_or(Error)?,
)?;

let mut uniform_bytes = Array::<u8, L>::default();
Expand Down