Skip to content

Commit aec741e

Browse files
committed
add more hkdf function and zeroize
1 parent 7951db2 commit aec741e

File tree

2 files changed

+179
-15
lines changed

2 files changed

+179
-15
lines changed

mbedtls/src/hash/mod.rs

Lines changed: 175 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,21 @@ impl MdInfo {
7878
}
7979
}
8080

81+
impl Clone for Md {
82+
fn clone(&self) -> Self {
83+
fn copy_md(md: &Md) -> Result<Md> {
84+
let mut ctx = Md::init();
85+
unsafe {
86+
md_setup(&mut ctx.inner, md.inner.md_info, 0).into_result()?;
87+
md_starts(&mut ctx.inner).into_result()?;
88+
md_clone(&mut ctx.inner, &md.inner).into_result()?;
89+
};
90+
Ok(ctx)
91+
}
92+
copy_md(self).expect("Md::copy success")
93+
}
94+
}
95+
8196
impl Md {
8297
pub fn new(md: Type) -> Result<Md> {
8398
let md: MdInfo = match md.into() {
@@ -126,6 +141,7 @@ impl Md {
126141
}
127142
}
128143

144+
#[derive(Clone)]
129145
pub struct Hmac {
130146
ctx: Md,
131147
}
@@ -178,12 +194,31 @@ impl Hmac {
178194
}
179195
}
180196

181-
pub struct Hkdf {
182-
_ctx: Md,
183-
}
197+
/// The HMAC-based Extract-and-Expand Key Derivation Function (HKDF) is specified by RFC 5869.
198+
#[derive(Debug)]
199+
pub struct Hkdf;
184200

185201
impl Hkdf {
186-
pub fn hkdf(md: Type, salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<()> {
202+
/// This is the HMAC-based Extract-and-Expand Key Derivation Function (HKDF).
203+
///
204+
/// # Parameters
205+
///
206+
/// * `md`: A hash function; `MdInfo::from(md).size()` denotes the length of the hash
207+
/// function output in bytes.
208+
/// * `salt`: An salt value (a non-secret random value);
209+
/// * `ikm`: The input keying material.
210+
/// * `info`: An optional context and application specific information
211+
/// string. This can be a zero-length string.
212+
/// * `okm`: The output keying material. The length of the output keying material in bytes
213+
/// must be less than or equal to 255 * `MdInfo::from(md).size()` bytes.
214+
///
215+
/// # Returns
216+
///
217+
/// * `()` on success.
218+
/// * [`Error::HkdfBadInputData`] when the parameters are invalid.
219+
/// * Any `Error::Md*` error for errors returned from the underlying
220+
/// MD layer.
221+
pub fn hkdf(md: Type, salt: &[u8], ikm: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> {
187222
let md: MdInfo = match md.into() {
188223
Some(md) => md,
189224
None => return Err(Error::MdBadInputData),
@@ -198,24 +233,149 @@ impl Hkdf {
198233
ikm.len(),
199234
info.as_ptr(),
200235
info.len(),
201-
key.as_mut_ptr(),
202-
key.len(),
236+
okm.as_mut_ptr(),
237+
okm.len(),
203238
)
204239
.into_result()?;
205240
Ok(())
206241
}
207242
}
208-
}
209243

210-
impl Clone for Md {
211-
fn clone(&self) -> Self {
212-
fn copy_md(md: &Md) -> Result<Md> {
213-
let md_type = unsafe { md_get_type(md.inner.md_info) };
214-
let mut m = Md::new(md_type.into())?;
215-
unsafe { md_clone(&mut m.inner, &md.inner) }.into_result()?;
216-
Ok(m)
244+
/// This is the HMAC-based Extract-and-Expand Key Derivation Function (HKDF).
245+
///
246+
/// # Parameters
247+
///
248+
/// * `md`: A hash function; `MdInfo::from(md).size()` denotes the length of the hash
249+
/// function output in bytes.
250+
/// * `salt`: An optional salt value (a non-secret random value);
251+
/// if the salt is not provided, a string of all zeros of
252+
/// `MdInfo::from(md).size()` length is used as the salt.
253+
/// * `ikm`: The input keying material.
254+
/// * `info`: An optional context and application specific information
255+
/// string. This can be a zero-length string.
256+
/// * `okm`: The output keying material. The length of the output keying material in bytes
257+
/// must be less than or equal to 255 * `MdInfo::from(md).size()` bytes.
258+
///
259+
/// # Returns
260+
///
261+
/// * `()` on success.
262+
/// * [`Error::HkdfBadInputData`] when the parameters are invalid.
263+
/// * Any `Error::Md*` error for errors returned from the underlying
264+
/// MD layer.
265+
pub fn hkdf_optional_salt(md: Type, maybe_salt: Option<&[u8]>, ikm: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> {
266+
let md: MdInfo = match md.into() {
267+
Some(md) => md,
268+
None => return Err(Error::MdBadInputData),
269+
};
270+
271+
unsafe {
272+
hkdf(
273+
md.inner,
274+
maybe_salt.map_or(::core::ptr::null(), |salt| salt.as_ptr()),
275+
maybe_salt.map_or(0, |salt| salt.len()),
276+
ikm.as_ptr(),
277+
ikm.len(),
278+
info.as_ptr(),
279+
info.len(),
280+
okm.as_mut_ptr(),
281+
okm.len(),
282+
)
283+
.into_result()?;
284+
Ok(())
285+
}
286+
}
287+
288+
/// Takes the input keying material `ikm` and extracts from it a
289+
/// fixed-length pseudorandom key `prk`.
290+
///
291+
/// # Warning
292+
///
293+
/// This function should only be used if the security of it has been
294+
/// studied and established in that particular context (eg. TLS 1.3
295+
/// key schedule). For standard HKDF security guarantees use
296+
/// `hkdf` instead.
297+
///
298+
/// # Parameters
299+
///
300+
/// * `md`: A hash function; `MdInfo::from(md).size()` denotes the length of the
301+
/// hash function output in bytes.
302+
/// * `salt`: An optional salt value (a non-secret random value);
303+
/// if the salt is not provided, a string of all zeros
304+
/// of `MdInfo::from(md).size()` length is used as the salt.
305+
/// * `ikm`: The input keying material.
306+
/// * `prk`: The output pseudorandom key of at least `MdInfo::from(md).size()` bytes.
307+
///
308+
/// # Returns
309+
///
310+
/// * `()` on success.
311+
/// * [`Error::HkdfBadInputData`] when the parameters are invalid.
312+
/// * Any `Error::Md*` error for errors returned from the underlying
313+
/// MD layer.
314+
pub fn hkdf_extract(md: Type, maybe_salt: Option<&[u8]>, ikm: &[u8], prk: &mut [u8]) -> Result<()> {
315+
let md: MdInfo = match md.into() {
316+
Some(md) => md,
317+
None => return Err(Error::MdBadInputData),
318+
};
319+
320+
unsafe {
321+
hkdf_extract(
322+
md.inner,
323+
maybe_salt.map_or(::core::ptr::null(), |salt| salt.as_ptr()),
324+
maybe_salt.map_or(0, |salt| salt.len()),
325+
ikm.as_ptr(),
326+
ikm.len(),
327+
prk.as_mut_ptr(),
328+
)
329+
.into_result()?;
330+
Ok(())
331+
}
332+
}
333+
334+
/// Expand the supplied `prk` into several additional pseudorandom keys, which is the output of the HKDF.
335+
///
336+
/// # Warning
337+
///
338+
/// This function should only be used if the security of it has been
339+
/// studied and established in that particular context (eg. TLS 1.3
340+
/// key schedule). For standard HKDF security guarantees use
341+
/// `hkdf` instead.
342+
///
343+
/// # Parameters
344+
///
345+
/// * `md`: A hash function; `MdInfo::from(md).size()` denotes the length of the
346+
/// hash function output in bytes.
347+
/// * `prk`: A pseudorandom key of at least `MdInfo::from(md).size()` bytes. `prk` is
348+
/// usually the output from the HKDF extract step.
349+
/// * `info`: An optional context and application specific information
350+
/// string. This can be a zero-length string.
351+
/// * `okm`: The output keying material. The length of the output keying material in bytes
352+
/// must be less than or equal to 255 * `MdInfo::from(md).size()` bytes.
353+
///
354+
/// # Returns
355+
///
356+
/// * `()` on success.
357+
/// * [`Error::HkdfBadInputData`] when the parameters are invalid.
358+
/// * Any `Error::Md*` error for errors returned from the underlying
359+
/// MD layer.
360+
pub fn hkdf_expand(md: Type, prk: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> {
361+
let md: MdInfo = match md.into() {
362+
Some(md) => md,
363+
None => return Err(Error::MdBadInputData),
364+
};
365+
366+
unsafe {
367+
hkdf_expand(
368+
md.inner,
369+
prk.as_ptr(),
370+
prk.len(),
371+
info.as_ptr(),
372+
info.len(),
373+
okm.as_mut_ptr(),
374+
okm.len(),
375+
)
376+
.into_result()?;
377+
Ok(())
217378
}
218-
copy_md(self).expect("Md::copy success")
219379
}
220380
}
221381

mbedtls/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ pub mod x509;
4444
#[cfg(feature = "pkcs12")]
4545
pub mod pkcs12;
4646

47+
pub fn zeroize(buf: &mut [u8]) {
48+
unsafe { mbedtls_sys::platform_zeroize(buf.as_mut_ptr() as *mut mbedtls_sys::types::raw_types::c_void, buf.len()) }
49+
}
50+
4751
// ==============
4852
// Utility
4953
// ==============

0 commit comments

Comments
 (0)