From 4e266df7e4348787aea83906027d7e506e9db26d Mon Sep 17 00:00:00 2001 From: mridul Date: Fri, 8 Nov 2024 21:14:25 +0530 Subject: [PATCH 01/11] PROD-9297 porting mbedtls combined error PR - initial commit --- Cargo.lock | 117 ++++----- mbedtls/Cargo.toml | 4 +- mbedtls/src/bignum/mod.rs | 23 +- mbedtls/src/cipher/raw/mod.rs | 12 +- mbedtls/src/ecp/mod.rs | 24 +- mbedtls/src/error.rs | 400 +++++++++++++++++++++---------- mbedtls/src/hash/mod.rs | 28 +-- mbedtls/src/lib.rs | 2 +- mbedtls/src/pk/dsa/mod.rs | 33 +-- mbedtls/src/pk/ec.rs | 4 +- mbedtls/src/pk/mod.rs | 98 ++++---- mbedtls/src/pkcs12/mod.rs | 7 +- mbedtls/src/private.rs | 26 +- mbedtls/src/ssl/async_io.rs | 35 ++- mbedtls/src/ssl/config.rs | 16 +- mbedtls/src/ssl/context.rs | 30 +-- mbedtls/src/ssl/io.rs | 24 +- mbedtls/src/x509/certificate.rs | 44 ++-- mbedtls/src/x509/csr.rs | 12 +- mbedtls/tests/async_session.rs | 20 +- mbedtls/tests/bignum.rs | 6 +- mbedtls/tests/client_server.rs | 30 +-- mbedtls/tests/ec.rs | 4 +- mbedtls/tests/hyper.rs | 2 +- mbedtls/tests/rsa.rs | 4 +- mbedtls/tests/ssl_conf_ca_cb.rs | 8 +- mbedtls/tests/ssl_conf_verify.rs | 8 +- 27 files changed, 567 insertions(+), 454 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 017c5735b..adbbadb3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "anes" @@ -31,8 +31,8 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4655ae1a7b0cdf149156f780c5bf3f1352bc53cbd9e0a361a7ef7b22947e965" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "syn 1.0.64", ] @@ -72,12 +72,12 @@ dependencies = [ "log 0.4.8", "peeking_take_while", "prettyplease", - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "regex", "rustc-hash", "shlex", - "syn 2.0.32", + "syn 2.0.87", "which", ] @@ -429,8 +429,8 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "syn 1.0.64", ] @@ -733,7 +733,7 @@ dependencies = [ "lazy_static", "libc", "libz-sys", - "quote 1.0.33", + "quote", "syn 1.0.64", ] @@ -915,44 +915,26 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" dependencies = [ - "proc-macro2 1.0.66", - "syn 2.0.32", + "proc-macro2", + "syn 2.0.87", ] [[package]] name = "proc-macro2" -version = "0.4.30" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759" -dependencies = [ - "unicode-xid 0.1.0", -] - -[[package]] -name = "proc-macro2" -version = "1.0.66" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "0.6.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce23b6b870e8f94f81fb0a363d65d86675884b34a09043c81e5562f11c1f8e1" -dependencies = [ - "proc-macro2 0.4.30", -] - -[[package]] -name = "quote" -version = "1.0.33" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ - "proc-macro2 1.0.66", + "proc-macro2", ] [[package]] @@ -1108,9 +1090,9 @@ checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" [[package]] name = "serde" -version = "1.0.101" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9796c9b7ba2ffe7a9ce53c2287dfc48080f4b2b362fcc245a259b3a7201119dd" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] @@ -1137,13 +1119,13 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.70" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3525a779832b08693031b8ecfb0de81cd71cfd3812088fafe9a7496789572124" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ - "proc-macro2 0.4.30", - "quote 0.6.13", - "syn 0.14.9", + "proc-macro2", + "quote", + "syn 2.0.87", ] [[package]] @@ -1178,36 +1160,25 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" -[[package]] -name = "syn" -version = "0.14.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "261ae9ecaa397c42b960649561949d69311f08eeaea86a65696e6e46517cf741" -dependencies = [ - "proc-macro2 0.4.30", - "quote 0.6.13", - "unicode-xid 0.1.0", -] - [[package]] name = "syn" version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fd9d1e9976102a03c542daa2eff1b43f9d72306342f3f8b3ed5fb8908195d6f" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", - "unicode-xid 0.2.1", + "proc-macro2", + "quote", + "unicode-xid", ] [[package]] name = "syn" -version = "2.0.32" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "unicode-ident", ] @@ -1269,8 +1240,8 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "syn 1.0.64", ] @@ -1292,8 +1263,8 @@ version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "syn 1.0.64", ] @@ -1357,12 +1328,6 @@ dependencies = [ "tinyvec", ] -[[package]] -name = "unicode-xid" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc" - [[package]] name = "unicode-xid" version = "0.2.1" @@ -1427,9 +1392,9 @@ dependencies = [ "bumpalo", "log 0.4.8", "once_cell", - "proc-macro2 1.0.66", - "quote 1.0.33", - "syn 2.0.32", + "proc-macro2", + "quote", + "syn 2.0.87", "wasm-bindgen-shared", ] @@ -1439,7 +1404,7 @@ version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" dependencies = [ - "quote 1.0.33", + "quote", "wasm-bindgen-macro-support", ] @@ -1449,9 +1414,9 @@ version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", - "syn 2.0.32", + "proc-macro2", + "quote", + "syn 2.0.87", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/mbedtls/Cargo.toml b/mbedtls/Cargo.toml index 61b3068bd..05ee0bb8a 100644 --- a/mbedtls/Cargo.toml +++ b/mbedtls/Cargo.toml @@ -24,8 +24,8 @@ features = ["x509", "ssl"] [dependencies] bitflags = "1" -serde = { version = "1.0.7", default-features = false, features = ["alloc"] } -serde_derive = "1.0.7" +serde = { version = "1.0.214", default-features = false, features = ["alloc"] } +serde_derive = "1.0.214" byteorder = { version = "1.0.0", default-features = false } yasna = { version = "0.2", optional = true, features = [ "num-bigint", diff --git a/mbedtls/src/bignum/mod.rs b/mbedtls/src/bignum/mod.rs index 0df9e7bc6..dda1c12aa 100644 --- a/mbedtls/src/bignum/mod.rs +++ b/mbedtls/src/bignum/mod.rs @@ -6,7 +6,10 @@ * option. This file may not be copied, modified, or distributed except * according to those terms. */ -use crate::error::{Error, IntoResult, Result}; + + #[cfg(feature = "std")] +use crate::error::Error; +use crate::error::{IntoResult, Result, codes}; use mbedtls_sys::*; #[cfg(not(feature = "std"))] @@ -161,7 +164,7 @@ impl Mpi { pub fn as_u32(&self) -> Result { if self.bit_length()? > 32 { // Not exactly correct but close enough - return Err(Error::MpiBufferTooSmall); + return Err(codes::MpiBufferTooSmall.into()); } Ok(self.get_limb(0) as u32) @@ -183,7 +186,7 @@ impl Mpi { let r = unsafe { mpi_write_string(&self.inner, radix, ::core::ptr::null_mut(), 0, &mut olen) }; if r != ERR_MPI_BUFFER_TOO_SMALL { - return Err(Error::from_mbedtls_code(r)); + return Err(r.into()); } let mut buf = vec![0u8; olen]; @@ -264,7 +267,7 @@ impl Mpi { let zero = Mpi::new(0)?; if self < &zero || self >= p { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } if self == &zero { return Ok(zero); @@ -273,12 +276,12 @@ impl Mpi { // This ignores p=2 (for which this algorithm is valid), as not // cryptographically interesting. if p.get_bit(0) == false || p <= &zero { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } if self.jacobi(p)? != 1 { // a is not a quadratic residue mod p - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } if (p % 4)?.as_u32()? == 3 { @@ -325,7 +328,7 @@ impl Mpi { bo = bo.mod_exp(&two, p)?; m += 1; if m >= r { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } } @@ -358,7 +361,7 @@ impl Mpi { let one = Mpi::new(1)?; if self < &zero || n < &zero || n.get_bit(0) == false { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } let mut x = self.modulo(n)?; @@ -431,7 +434,7 @@ impl Mpi { pub(super) fn mpi_inner_eq_const_time(x: &mpi, y: &mpi) -> core::prelude::v1::Result { match mpi_inner_cmp_const_time(x, y) { Ok(order) => Ok(order == Ordering::Equal), - Err(Error::MpiBadInputData) => Ok(false), + Err(Error::from(codes::MpiBadInputData)) => Ok(false), Err(e) => Err(e), } } @@ -779,7 +782,7 @@ mod tests { ]) .unwrap(); assert_eq!(mpi3.less_than_const_time(&mpi3), Ok(false)); - assert_eq!(mpi2.less_than_const_time(&mpi3), Err(Error::MpiBadInputData)); + assert_eq!(mpi2.less_than_const_time(&mpi3), Err(codes::MpiBadInputData.into())); } #[test] diff --git a/mbedtls/src/cipher/raw/mod.rs b/mbedtls/src/cipher/raw/mod.rs index 2a25e3435..03d875185 100644 --- a/mbedtls/src/cipher/raw/mod.rs +++ b/mbedtls/src/cipher/raw/mod.rs @@ -8,7 +8,7 @@ use mbedtls_sys::*; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{IntoResult, Result, codes}; mod serde; @@ -254,7 +254,7 @@ impl Cipher { }; if out_data.len() < required_size { - return Err(Error::CipherFullBlockExpected); + return Err(codes::CipherFullBlockExpected.into()); } let mut olen = 0; @@ -274,7 +274,7 @@ impl Cipher { pub fn finish(&mut self, out_data: &mut [u8]) -> Result { // Check that minimum required space is available in out_data buffer if out_data.len() < self.block_size() { - return Err(Error::CipherFullBlockExpected); + return Err(codes::CipherFullBlockExpected.into()); } let mut olen = 0; @@ -337,7 +337,7 @@ impl Cipher { .checked_sub(tag_len) .map_or(true, |cipher_len| cipher_len < plain.len()) { - return Err(Error::CipherBadInputData); + return Err(codes::CipherBadInputData.into()); } let iv = self.inner.iv; @@ -371,7 +371,7 @@ impl Cipher { .checked_sub(tag_len) .map_or(true, |cipher_len| plain.len() < cipher_len) { - return Err(Error::CipherBadInputData); + return Err(codes::CipherBadInputData.into()); } let iv = self.inner.iv; @@ -470,7 +470,7 @@ impl Cipher { pub fn cmac(&mut self, key: &[u8], data: &[u8], out_data: &mut [u8]) -> Result<()> { // Check that out_data buffer has enough space if out_data.len() < self.block_size() { - return Err(Error::CipherFullBlockExpected); + return Err(codes::CipherFullBlockExpected.into()); } self.reset()?; unsafe { diff --git a/mbedtls/src/ecp/mod.rs b/mbedtls/src/ecp/mod.rs index 6b6fffd7c..a69795d82 100644 --- a/mbedtls/src/ecp/mod.rs +++ b/mbedtls/src/ecp/mod.rs @@ -6,7 +6,7 @@ * option. This file may not be copied, modified, or distributed except * according to those terms. */ -use crate::error::{Error, IntoResult, Result}; +use crate::error::{Error, IntoResult, Result, codes}; use core::convert::TryFrom; use mbedtls_sys::*; @@ -107,7 +107,7 @@ impl EcGroup { || &order <= &zero || (&a == &zero && &b == &zero) { - return Err(Error::EcpBadInputData); + return Err(codes::EcpBadInputData.into()); } // Compute `order - 2`, needed below. @@ -128,7 +128,7 @@ impl EcGroup { Test that the provided generator satisfies the curve equation */ if unsafe { ecp_check_pubkey(&ret.inner, &ret.inner.G) } != 0 { - return Err(Error::EcpBadInputData); + return Err(codes::EcpBadInputData.into()); } /* @@ -155,7 +155,7 @@ impl EcGroup { let is_zero = unsafe { ecp_is_zero(&g_m.inner as *const ecp_point as *mut ecp_point) }; if is_zero != 1 { - return Err(Error::EcpBadInputData); + return Err(codes::EcpBadInputData.into()); } Ok(ret) @@ -193,7 +193,7 @@ impl EcGroup { EcGroupId::Curve25519 => Ok(8), EcGroupId::Curve448 => Ok(4), // Requires a point-counting algorithm such as SEA. - EcGroupId::None => Err(Error::EcpFeatureUnavailable), + EcGroupId::None => Err(codes::EcpFeatureUnavailable.into()), _ => Ok(1), } } @@ -206,7 +206,7 @@ impl EcGroup { match unsafe { ecp_check_pubkey(&self.inner, &point.inner) } { 0 => Ok(true), ERR_ECP_INVALID_KEY => Ok(false), - err => Err(Error::from_mbedtls_code(err)), + err => Err(err.into()), } } } @@ -249,7 +249,7 @@ impl EcPoint { } pub fn from_binary(group: &EcGroup, bin: &[u8]) -> Result { - let prefix = *bin.get(0).ok_or(Error::EcpBadInputData)?; + let prefix = *bin.get(0).ok_or(Error::from(codes::EcpBadInputData))?; if prefix == 0x02 || prefix == 0x03 { // Compressed point, which mbedtls does not understand @@ -260,7 +260,7 @@ impl EcPoint { let b = group.b()?; if bin.len() != (p.byte_length()? + 1) { - return Err(Error::EcpBadInputData); + return Err(codes::EcpBadInputData.into()); } let x = Mpi::from_binary(&bin[1..]).unwrap(); @@ -317,7 +317,7 @@ impl EcPoint { match unsafe { ecp_is_zero(&self.inner as *const ecp_point as *mut ecp_point) } { 0 => Ok(false), 1 => Ok(true), - _ => Err(Error::EcpInvalidKey), + _ => Err(codes::EcpInvalidKey.into()), } } @@ -402,11 +402,11 @@ Please use `mul_with_rng` instead." let mut ret = Self::init(); if group.contains_point(&pt1)? == false { - return Err(Error::EcpInvalidKey); + return Err(codes::EcpInvalidKey.into()); } if group.contains_point(&pt2)? == false { - return Err(Error::EcpInvalidKey); + return Err(codes::EcpInvalidKey.into()); } unsafe { @@ -430,7 +430,7 @@ Please use `mul_with_rng` instead." match r { 0 => Ok(true), ERR_ECP_BAD_INPUT_DATA => Ok(false), - x => Err(Error::from_mbedtls_code(x)), + x => Err(x.into()), } } diff --git a/mbedtls/src/error.rs b/mbedtls/src/error.rs index 690a7804e..5bc453190 100644 --- a/mbedtls/src/error.rs +++ b/mbedtls/src/error.rs @@ -6,9 +6,10 @@ * option. This file may not be copied, modified, or distributed except * according to those terms. */ -use core::convert::Infallible; use core::fmt; +use core::ops::BitOr; use core::str::Utf8Error; +use core::convert::Infallible; #[cfg(feature = "std")] use std::error::Error as StdError; @@ -23,6 +24,11 @@ pub trait IntoResult: Sized { } } +pub mod codes { + pub use super::HiError::*; + pub use super::LoError::*; +} + // This is intended not to overlap with mbedtls error codes. Utf8Error is // generated in the bindings when converting to rust UTF-8 strings. Only in rare // circumstances (callbacks from mbedtls to rust) do we need to pass a Utf8Error @@ -30,59 +36,105 @@ pub trait IntoResult: Sized { pub const ERR_UTF8_INVALID: c_int = -0x10000; macro_rules! error_enum { - {enum $n:ident {$($rust:ident = $c:ident,)*}} => { - #[derive(Debug, Eq, PartialEq)] + { + const MASK: c_int = $mask:literal; + enum $n:ident {$($rust:ident = $c:ident,)*} + } => { + #[non_exhaustive] + #[derive(Debug, Eq, PartialEq, Copy, Clone)] pub enum $n { $($rust,)* - Other(c_int), - Utf8Error(Option), - // Stable-Rust equivalent of `#[non_exhaustive]` attribute. This - // value should never be used by users of this crate! - #[doc(hidden)] - __Nonexhaustive, + Unknown(c_int) } - impl IntoResult for c_int { - fn into_result(self) -> Result { - let err_code = match self { - _ if self >= 0 => return Ok(self), - ERR_UTF8_INVALID => return Err(Error::Utf8Error(None)), - _ => -self, - }; - let (high_level_code, low_level_code) = (err_code & 0xFF80, err_code & 0x7F); - Err($n::from_mbedtls_code(if high_level_code > 0 { -high_level_code } else { -low_level_code })) + impl From for $n { + fn from(code: c_int) -> $n { + // check against mask here (not in match blook) to make it compile-time + $(const $c: c_int = $n::assert_in_mask(::mbedtls_sys::$c);)* + match -code { + $($c => return $n::$rust),*, + _ => return $n::Unknown(-code) + } } } - impl $n { - pub fn from_mbedtls_code(code: c_int) -> Self { - match code { - $(::mbedtls_sys::$c => $n::$rust),*, - _ => $n::Other(code) + impl From<$n> for c_int { + fn from(error: $n) -> c_int { + match error { + $($n::$rust => return ::mbedtls_sys::$c,)* + $n::Unknown(code) => return code, } } + } - pub fn as_str(&self) -> &'static str { - match self { - $(&$n::$rust => concat!("mbedTLS error ",stringify!($n::$rust)),)* - &$n::Other(_) => "mbedTLS unknown error", - &$n::Utf8Error(_) => "error converting to UTF-8", - &$n::__Nonexhaustive => unreachable!("__Nonexhaustive value should not be instantiated"), - } + impl $n { + const fn mask() -> c_int { + $mask + } + + const fn assert_in_mask(val: c_int) -> c_int { + assert!((-val & !Self::mask()) == 0); + val } - pub fn to_int(&self) -> c_int { - match *self { - $($n::$rust => ::mbedtls_sys::$c,)* - $n::Other(code) => code, - $n::Utf8Error(_) => ERR_UTF8_INVALID, - $n::__Nonexhaustive => unreachable!("__Nonexhaustive value should not be instantiated"), + pub fn as_str(&self)-> &'static str { + match self { + $($n::$rust => concat!("mbedTLS error ", stringify!($n::$rust)),)* + $n::Unknown(_) => concat!("mbedTLS unknown ", stringify!($n), " error") } } } }; } +#[derive(Debug, Eq, PartialEq)] +pub enum Error { + HighLevel(HiError), + LowLevel(LoError), + HighAndLowLevel(HiError, LoError), + Other(c_int), + Utf8Error(Option), +} + +impl Error { + + pub fn low_level(&self) -> Option { + match self { + Error::LowLevel(error) + | Error::HighAndLowLevel(_, error) => Some(*error), + _ => None + } + } + + pub fn high_level(&self) -> Option { + match self { + Error::HighLevel(error) + | Error::HighAndLowLevel(error, _) => Some(*error), + _ => None + } + } + + pub fn as_str(&self) -> &'static str { + match &self { + &Error::HighLevel(e) => e.as_str(), + &Error::LowLevel(e) => e.as_str(), + &Error::HighAndLowLevel(e, _) => e.as_str(), + &Error::Other(_) => "mbedTLS unknown error", + &Error::Utf8Error(_) => "error converting to UTF-8" + } + } + + pub fn to_int(&self) -> c_int { + match self { + &Error::HighLevel(error) => error.into(), + &Error::LowLevel(error) => error.into(), + &Error::HighAndLowLevel(hl_error, ll_error) => c_int::from(hl_error) + c_int::from(ll_error), + &Error::Other(error) => error, + &Error::Utf8Error(_) => ERR_UTF8_INVALID, + } + } +} + impl From for Error { fn from(e: Utf8Error) -> Error { Error::Utf8Error(Some(e)) @@ -95,57 +147,77 @@ impl From for Error { } } +impl From for Error { + fn from(error: LoError) -> Error { + Error::LowLevel(error) + } +} + +impl From for Error { + fn from(error: HiError) -> Error { + Error::HighLevel(error) + } +} + +impl BitOr for HiError { + type Output = Error; + fn bitor(self, rhs: LoError) -> Self::Output { + Error::HighAndLowLevel(self, rhs) + } +} +impl BitOr for LoError { + type Output = Error; + fn bitor(self, rhs: HiError) -> Self::Output { + Error::HighAndLowLevel(rhs, self) + } +} + +impl From for Error { + fn from(x: c_int) -> Error { + let (high_level_code, low_level_code) = (-x & HiError::mask(), -x & LoError::mask()); + if -x & (HiError::mask() | LoError::mask()) != -x || x >= 0 { + Error::Other(x) + } else if high_level_code == 0 { + Error::LowLevel(low_level_code.into()) + } else if low_level_code == 0 { + Error::HighLevel(high_level_code.into()) + } else { + Error::HighAndLowLevel(high_level_code.into(), low_level_code.into()) + } + } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - &Error::Utf8Error(Some(ref e)) => f.write_fmt(format_args!("Error converting to UTF-8: {}", e)), - &Error::Utf8Error(None) => f.write_fmt(format_args!("Error converting to UTF-8")), - &Error::Other(i) => f.write_fmt(format_args!("mbedTLS unknown error ({})", i)), - &Error::__Nonexhaustive => unreachable!("__Nonexhaustive value should not be instantiated"), - e @ _ => f.write_fmt(format_args!("mbedTLS error {:?}", e)), + &Error::Utf8Error(Some(ref e)) => { + write!(f, "Error converting to UTF-8: {}", e) + } + &Error::Utf8Error(None) => write!(f, "Error converting to UTF-8"), + &Error::LowLevel(e) => write!(f, "{}", e.as_str()), + &Error::HighLevel(e) => write!(f, "{}", e.as_str()), + &Error::HighAndLowLevel(hi, lo) => write!(f, "({}, {})", hi.as_str(), lo.as_str()), + &Error::Other(code) => write!(f, "mbedTLS unknown error code {}", code), } } } #[cfg(feature = "std")] -impl StdError for Error { - fn description(&self) -> &str { - self.as_str() +impl StdError for Error {} + +impl IntoResult for c_int { + fn into_result(self) -> Result { + match self { + 0.. => return Ok(self), + ERR_UTF8_INVALID => return Err(Error::Utf8Error(None)), + _ => return Err(Error::from(self)) + }; } } error_enum!( - enum Error { - AesBadInputData = ERR_AES_BAD_INPUT_DATA, - AesFeatureUnavailable = ERR_AES_FEATURE_UNAVAILABLE, - AesHwAccelFailed = ERR_AES_HW_ACCEL_FAILED, - AesInvalidInputLength = ERR_AES_INVALID_INPUT_LENGTH, - AesInvalidKeyLength = ERR_AES_INVALID_KEY_LENGTH, - Arc4HwAccelFailed = ERR_ARC4_HW_ACCEL_FAILED, - AriaFeatureUnavailable = ERR_ARIA_FEATURE_UNAVAILABLE, - AriaHwAccelFailed = ERR_ARIA_HW_ACCEL_FAILED, - AriaInvalidInputLength = ERR_ARIA_INVALID_INPUT_LENGTH, - Asn1AllocFailed = ERR_ASN1_ALLOC_FAILED, - Asn1BufTooSmall = ERR_ASN1_BUF_TOO_SMALL, - Asn1InvalidData = ERR_ASN1_INVALID_DATA, - Asn1InvalidLength = ERR_ASN1_INVALID_LENGTH, - Asn1LengthMismatch = ERR_ASN1_LENGTH_MISMATCH, - Asn1OutOfData = ERR_ASN1_OUT_OF_DATA, - Asn1UnexpectedTag = ERR_ASN1_UNEXPECTED_TAG, - Base64BufferTooSmall = ERR_BASE64_BUFFER_TOO_SMALL, - Base64InvalidCharacter = ERR_BASE64_INVALID_CHARACTER, - BlowfishHwAccelFailed = ERR_BLOWFISH_HW_ACCEL_FAILED, - BlowfishInvalidInputLength = ERR_BLOWFISH_INVALID_INPUT_LENGTH, - CamelliaHwAccelFailed = ERR_CAMELLIA_HW_ACCEL_FAILED, - CamelliaInvalidInputLength = ERR_CAMELLIA_INVALID_INPUT_LENGTH, - CcmAuthFailed = ERR_CCM_AUTH_FAILED, - CcmBadInput = ERR_CCM_BAD_INPUT, - CcmHwAccelFailed = ERR_CCM_HW_ACCEL_FAILED, - Chacha20BadInputData = ERR_CHACHA20_BAD_INPUT_DATA, - Chacha20FeatureUnavailable = ERR_CHACHA20_FEATURE_UNAVAILABLE, - Chacha20HwAccelFailed = ERR_CHACHA20_HW_ACCEL_FAILED, - ChachapolyAuthFailed = ERR_CHACHAPOLY_AUTH_FAILED, - ChachapolyBadState = ERR_CHACHAPOLY_BAD_STATE, + const MASK: c_int = 0x7F80; + enum HiError { CipherAllocFailed = ERR_CIPHER_ALLOC_FAILED, CipherAuthFailed = ERR_CIPHER_AUTH_FAILED, CipherBadInputData = ERR_CIPHER_BAD_INPUT_DATA, @@ -154,13 +226,6 @@ error_enum!( CipherHwAccelFailed = ERR_CIPHER_HW_ACCEL_FAILED, CipherInvalidContext = ERR_CIPHER_INVALID_CONTEXT, CipherInvalidPadding = ERR_CIPHER_INVALID_PADDING, - CmacHwAccelFailed = ERR_CMAC_HW_ACCEL_FAILED, - CtrDrbgEntropySourceFailed = ERR_CTR_DRBG_ENTROPY_SOURCE_FAILED, - CtrDrbgFileIoError = ERR_CTR_DRBG_FILE_IO_ERROR, - CtrDrbgInputTooBig = ERR_CTR_DRBG_INPUT_TOO_BIG, - CtrDrbgRequestTooBig = ERR_CTR_DRBG_REQUEST_TOO_BIG, - DesHwAccelFailed = ERR_DES_HW_ACCEL_FAILED, - DesInvalidInputLength = ERR_DES_INVALID_INPUT_LENGTH, DhmAllocFailed = ERR_DHM_ALLOC_FAILED, DhmBadInputData = ERR_DHM_BAD_INPUT_DATA, DhmCalcSecretFailed = ERR_DHM_CALC_SECRET_FAILED, @@ -181,51 +246,12 @@ error_enum!( EcpRandomFailed = ERR_ECP_RANDOM_FAILED, EcpSigLenMismatch = ERR_ECP_SIG_LEN_MISMATCH, EcpVerifyFailed = ERR_ECP_VERIFY_FAILED, - EntropyFileIoError = ERR_ENTROPY_FILE_IO_ERROR, - EntropyMaxSources = ERR_ENTROPY_MAX_SOURCES, - EntropyNoSourcesDefined = ERR_ENTROPY_NO_SOURCES_DEFINED, - EntropyNoStrongSource = ERR_ENTROPY_NO_STRONG_SOURCE, - EntropySourceFailed = ERR_ENTROPY_SOURCE_FAILED, - GcmAuthFailed = ERR_GCM_AUTH_FAILED, - GcmBadInput = ERR_GCM_BAD_INPUT, - GcmHwAccelFailed = ERR_GCM_HW_ACCEL_FAILED, HkdfBadInputData = ERR_HKDF_BAD_INPUT_DATA, - HmacDrbgEntropySourceFailed = ERR_HMAC_DRBG_ENTROPY_SOURCE_FAILED, - HmacDrbgFileIoError = ERR_HMAC_DRBG_FILE_IO_ERROR, - HmacDrbgInputTooBig = ERR_HMAC_DRBG_INPUT_TOO_BIG, - HmacDrbgRequestTooBig = ERR_HMAC_DRBG_REQUEST_TOO_BIG, - Md2HwAccelFailed = ERR_MD2_HW_ACCEL_FAILED, - Md4HwAccelFailed = ERR_MD4_HW_ACCEL_FAILED, - Md5HwAccelFailed = ERR_MD5_HW_ACCEL_FAILED, MdAllocFailed = ERR_MD_ALLOC_FAILED, MdBadInputData = ERR_MD_BAD_INPUT_DATA, MdFeatureUnavailable = ERR_MD_FEATURE_UNAVAILABLE, MdFileIoError = ERR_MD_FILE_IO_ERROR, MdHwAccelFailed = ERR_MD_HW_ACCEL_FAILED, - MpiAllocFailed = ERR_MPI_ALLOC_FAILED, - MpiBadInputData = ERR_MPI_BAD_INPUT_DATA, - MpiBufferTooSmall = ERR_MPI_BUFFER_TOO_SMALL, - MpiDivisionByZero = ERR_MPI_DIVISION_BY_ZERO, - MpiFileIoError = ERR_MPI_FILE_IO_ERROR, - MpiInvalidCharacter = ERR_MPI_INVALID_CHARACTER, - MpiNegativeValue = ERR_MPI_NEGATIVE_VALUE, - MpiNotAcceptable = ERR_MPI_NOT_ACCEPTABLE, - NetAcceptFailed = ERR_NET_ACCEPT_FAILED, - NetBadInputData = ERR_NET_BAD_INPUT_DATA, - NetBindFailed = ERR_NET_BIND_FAILED, - NetBufferTooSmall = ERR_NET_BUFFER_TOO_SMALL, - NetConnReset = ERR_NET_CONN_RESET, - NetConnectFailed = ERR_NET_CONNECT_FAILED, - NetInvalidContext = ERR_NET_INVALID_CONTEXT, - NetListenFailed = ERR_NET_LISTEN_FAILED, - NetPollFailed = ERR_NET_POLL_FAILED, - NetRecvFailed = ERR_NET_RECV_FAILED, - NetSendFailed = ERR_NET_SEND_FAILED, - NetSocketFailed = ERR_NET_SOCKET_FAILED, - NetUnknownHost = ERR_NET_UNKNOWN_HOST, - OidBufTooSmall = ERR_OID_BUF_TOO_SMALL, - OidNotFound = ERR_OID_NOT_FOUND, - PadlockDataMisaligned = ERR_PADLOCK_DATA_MISALIGNED, PemAllocFailed = ERR_PEM_ALLOC_FAILED, PemBadInputData = ERR_PEM_BAD_INPUT_DATA, PemFeatureUnavailable = ERR_PEM_FEATURE_UNAVAILABLE, @@ -258,10 +284,6 @@ error_enum!( Pkcs5FeatureUnavailable = ERR_PKCS5_FEATURE_UNAVAILABLE, Pkcs5InvalidFormat = ERR_PKCS5_INVALID_FORMAT, Pkcs5PasswordMismatch = ERR_PKCS5_PASSWORD_MISMATCH, - Poly1305BadInputData = ERR_POLY1305_BAD_INPUT_DATA, - Poly1305FeatureUnavailable = ERR_POLY1305_FEATURE_UNAVAILABLE, - Poly1305HwAccelFailed = ERR_POLY1305_HW_ACCEL_FAILED, - Ripemd160HwAccelFailed = ERR_RIPEMD160_HW_ACCEL_FAILED, RsaBadInputData = ERR_RSA_BAD_INPUT_DATA, RsaHwAccelFailed = ERR_RSA_HW_ACCEL_FAILED, RsaInvalidPadding = ERR_RSA_INVALID_PADDING, @@ -273,9 +295,6 @@ error_enum!( RsaRngFailed = ERR_RSA_RNG_FAILED, RsaUnsupportedOperation = ERR_RSA_UNSUPPORTED_OPERATION, RsaVerifyFailed = ERR_RSA_VERIFY_FAILED, - Sha1HwAccelFailed = ERR_SHA1_HW_ACCEL_FAILED, - Sha256HwAccelFailed = ERR_SHA256_HW_ACCEL_FAILED, - Sha512HwAccelFailed = ERR_SHA512_HW_ACCEL_FAILED, SslAllocFailed = ERR_SSL_ALLOC_FAILED, SslAsyncInProgress = ERR_SSL_ASYNC_IN_PROGRESS, SslBadHsCertificate = ERR_SSL_BAD_HS_CERTIFICATE, @@ -353,3 +372,128 @@ error_enum!( XteaInvalidInputLength = ERR_XTEA_INVALID_INPUT_LENGTH, } ); + +error_enum!( + const MASK: c_int = 0x7F; + enum LoError { + AesBadInputData = ERR_AES_BAD_INPUT_DATA, + AesInvalidInputLength = ERR_AES_INVALID_INPUT_LENGTH, + AesInvalidKeyLength = ERR_AES_INVALID_KEY_LENGTH, + AriaBadInputData = ERR_ARIA_BAD_INPUT_DATA, + AriaInvalidInputLength = ERR_ARIA_INVALID_INPUT_LENGTH, + Asn1AllocFailed = ERR_ASN1_ALLOC_FAILED, + Asn1BufTooSmall = ERR_ASN1_BUF_TOO_SMALL, + Asn1InvalidData = ERR_ASN1_INVALID_DATA, + Asn1InvalidLength = ERR_ASN1_INVALID_LENGTH, + Asn1LengthMismatch = ERR_ASN1_LENGTH_MISMATCH, + Asn1OutOfData = ERR_ASN1_OUT_OF_DATA, + Asn1UnexpectedTag = ERR_ASN1_UNEXPECTED_TAG, + Base64BufferTooSmall = ERR_BASE64_BUFFER_TOO_SMALL, + Base64InvalidCharacter = ERR_BASE64_INVALID_CHARACTER, + CamelliaBadInputData = ERR_CAMELLIA_BAD_INPUT_DATA, + CamelliaInvalidInputLength = ERR_CAMELLIA_INVALID_INPUT_LENGTH, + CcmAuthFailed = ERR_CCM_AUTH_FAILED, + CcmBadInput = ERR_CCM_BAD_INPUT, + Chacha20BadInputData = ERR_CHACHA20_BAD_INPUT_DATA, + ChachapolyAuthFailed = ERR_CHACHAPOLY_AUTH_FAILED, + ChachapolyBadState = ERR_CHACHAPOLY_BAD_STATE, + CtrDrbgEntropySourceFailed = ERR_CTR_DRBG_ENTROPY_SOURCE_FAILED, + CtrDrbgFileIoError = ERR_CTR_DRBG_FILE_IO_ERROR, + CtrDrbgInputTooBig = ERR_CTR_DRBG_INPUT_TOO_BIG, + CtrDrbgRequestTooBig = ERR_CTR_DRBG_REQUEST_TOO_BIG, + DesInvalidInputLength = ERR_DES_INVALID_INPUT_LENGTH, + EntropyFileIoError = ERR_ENTROPY_FILE_IO_ERROR, + EntropyMaxSources = ERR_ENTROPY_MAX_SOURCES, + EntropyNoSourcesDefined = ERR_ENTROPY_NO_SOURCES_DEFINED, + EntropyNoStrongSource = ERR_ENTROPY_NO_STRONG_SOURCE, + EntropySourceFailed = ERR_ENTROPY_SOURCE_FAILED, + ErrorCorruptionDetected = ERR_ERROR_CORRUPTION_DETECTED, + ErrorGenericError = ERR_ERROR_GENERIC_ERROR, + GcmAuthFailed = ERR_GCM_AUTH_FAILED, + GcmBadInput = ERR_GCM_BAD_INPUT, + GcmBufferTooSmall = ERR_GCM_BUFFER_TOO_SMALL, + HmacDrbgEntropySourceFailed = ERR_HMAC_DRBG_ENTROPY_SOURCE_FAILED, + HmacDrbgFileIoError = ERR_HMAC_DRBG_FILE_IO_ERROR, + HmacDrbgInputTooBig = ERR_HMAC_DRBG_INPUT_TOO_BIG, + HmacDrbgRequestTooBig = ERR_HMAC_DRBG_REQUEST_TOO_BIG, + LmsAllocFailed = ERR_LMS_ALLOC_FAILED, + LmsBadInputData = ERR_LMS_BAD_INPUT_DATA, + LmsBufferTooSmall = ERR_LMS_BUFFER_TOO_SMALL, + LmsOutOfPrivateKeys = ERR_LMS_OUT_OF_PRIVATE_KEYS, + LmsVerifyFailed = ERR_LMS_VERIFY_FAILED, + MpiAllocFailed = ERR_MPI_ALLOC_FAILED, + MpiBadInputData = ERR_MPI_BAD_INPUT_DATA, + MpiBufferTooSmall = ERR_MPI_BUFFER_TOO_SMALL, + MpiDivisionByZero = ERR_MPI_DIVISION_BY_ZERO, + MpiFileIoError = ERR_MPI_FILE_IO_ERROR, + MpiInvalidCharacter = ERR_MPI_INVALID_CHARACTER, + MpiNegativeValue = ERR_MPI_NEGATIVE_VALUE, + MpiNotAcceptable = ERR_MPI_NOT_ACCEPTABLE, + NetAcceptFailed = ERR_NET_ACCEPT_FAILED, + NetBadInputData = ERR_NET_BAD_INPUT_DATA, + NetBindFailed = ERR_NET_BIND_FAILED, + NetBufferTooSmall = ERR_NET_BUFFER_TOO_SMALL, + NetConnectFailed = ERR_NET_CONNECT_FAILED, + NetConnReset = ERR_NET_CONN_RESET, + NetInvalidContext = ERR_NET_INVALID_CONTEXT, + NetListenFailed = ERR_NET_LISTEN_FAILED, + NetPollFailed = ERR_NET_POLL_FAILED, + NetRecvFailed = ERR_NET_RECV_FAILED, + NetSendFailed = ERR_NET_SEND_FAILED, + NetSocketFailed = ERR_NET_SOCKET_FAILED, + NetUnknownHost = ERR_NET_UNKNOWN_HOST, + OidBufTooSmall = ERR_OID_BUF_TOO_SMALL, + OidNotFound = ERR_OID_NOT_FOUND, + PlatformFeatureUnsupported = ERR_PLATFORM_FEATURE_UNSUPPORTED, + PlatformHwAccelFailed = ERR_PLATFORM_HW_ACCEL_FAILED, + Poly1305BadInputData = ERR_POLY1305_BAD_INPUT_DATA, + Sha1BadInputData = ERR_SHA1_BAD_INPUT_DATA, + Sha256BadInputData = ERR_SHA256_BAD_INPUT_DATA, + Sha512BadInputData = ERR_SHA512_BAD_INPUT_DATA, + ThreadingBadInputData = ERR_THREADING_BAD_INPUT_DATA, + ThreadingMutexError = ERR_THREADING_MUTEX_ERROR, + } +); + +#[cfg(test)] +mod tests { + use super::{Error, codes, HiError, LoError}; + + #[test] + fn test_common_error_operations() { + let (hi, lo) = (codes::CipherAllocFailed, codes::AesBadInputData); + let (hi_only_error, lo_only_error, combined_error) = (Error::HighLevel(hi), Error::LowLevel(lo), Error::HighAndLowLevel(hi, lo)); + assert_eq!(combined_error.high_level().unwrap(), hi); + assert_eq!(combined_error.low_level().unwrap(), lo); + assert_eq!(hi_only_error.to_int(), -24960); + assert_eq!(lo_only_error.to_int(), -33); + assert_eq!(combined_error.to_int(), hi_only_error.to_int() + lo_only_error.to_int()); + assert_eq!(codes::CipherAllocFailed | codes::AesBadInputData, combined_error); + assert_eq!(codes::AesBadInputData | codes::CipherAllocFailed, combined_error); + } + + #[test] + fn test_error_display() { + let (hi, lo) = (HiError::CipherAllocFailed, LoError::AesBadInputData); + let (hi_only_error, lo_only_error, combined_error) = (Error::HighLevel(hi), Error::LowLevel(lo), Error::HighAndLowLevel(hi, lo)); + assert_eq!(format!("{}", hi_only_error), "mbedTLS error HiError :: CipherAllocFailed"); + assert_eq!(format!("{}", lo_only_error), "mbedTLS error LoError :: AesBadInputData"); + assert_eq!(format!("{}", combined_error), "(mbedTLS error HiError :: CipherAllocFailed, mbedTLS error LoError :: AesBadInputData)"); + } + + #[test] + fn test_error_from_int() { + // positive error code + assert_eq!(Error::from(0), Error::Other(0)); + assert_eq!(Error::from(1), Error::Other(1)); + // Lo, Hi, HiAndLo cases + assert_eq!(Error::from(-1), Error::LowLevel(LoError::ErrorGenericError)); + assert_eq!(Error::from(-0x80), Error::HighLevel(HiError::Unknown(-0x80))); + assert_eq!(Error::from(-0x81), Error::HighAndLowLevel(HiError::Unknown(-0x80), LoError::ErrorGenericError)); + assert_eq!(Error::from(-24993), Error::HighAndLowLevel(HiError::CipherAllocFailed, LoError::AesBadInputData)); + assert_eq!(Error::from(-24960), Error::HighLevel(HiError::CipherAllocFailed)); + assert_eq!(Error::from(-33), Error::LowLevel(LoError::AesBadInputData )); + // error code out of boundaries + assert_eq!(Error::from(-0x01FFFF), Error::Other(-0x01FFFF)); + } +} diff --git a/mbedtls/src/hash/mod.rs b/mbedtls/src/hash/mod.rs index c50f09911..05cf4a627 100644 --- a/mbedtls/src/hash/mod.rs +++ b/mbedtls/src/hash/mod.rs @@ -6,7 +6,7 @@ * option. This file may not be copied, modified, or distributed except * according to those terms. */ -use crate::error::{Error, IntoResult, Result}; +use crate::error::{IntoResult, Result, codes}; use mbedtls_sys::*; define!( @@ -97,7 +97,7 @@ impl Md { pub fn new(md: Type) -> Result { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; let mut ctx = Md::init(); @@ -117,7 +117,7 @@ impl Md { unsafe { let olen = (*self.inner.md_info).size as usize; if out.len() < olen { - return Err(Error::MdBadInputData); + return Err(codes::MdBadInputData.into()); } md_finish(&mut self.inner, out.as_mut_ptr()).into_result()?; Ok(olen) @@ -127,13 +127,13 @@ impl Md { pub fn hash(mdt: Type, data: &[u8], out: &mut [u8]) -> Result { let mdinfo: MdInfo = match mdt.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { let olen = mdinfo.inner.size as usize; if out.len() < olen { - return Err(Error::MdBadInputData); + return Err(codes::MdBadInputData.into()); } md(mdinfo.inner, data.as_ptr(), data.len(), out.as_mut_ptr()).into_result()?; Ok(olen) @@ -150,7 +150,7 @@ impl Hmac { pub fn new(md: Type, key: &[u8]) -> Result { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; let mut ctx = Md::init(); @@ -170,7 +170,7 @@ impl Hmac { unsafe { let olen = (*self.ctx.inner.md_info).size as usize; if out.len() < olen { - return Err(Error::MdBadInputData); + return Err(codes::MdBadInputData.into()); } md_hmac_finish(&mut self.ctx.inner, out.as_mut_ptr()).into_result()?; Ok(olen) @@ -180,13 +180,13 @@ impl Hmac { pub fn hmac(md: Type, key: &[u8], data: &[u8], out: &mut [u8]) -> Result { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { let olen = md.inner.size as usize; if out.len() < olen { - return Err(Error::MdBadInputData); + return Err(codes::MdBadInputData.into()); } md_hmac(md.inner, key.as_ptr(), key.len(), data.as_ptr(), data.len(), out.as_mut_ptr()).into_result()?; Ok(olen) @@ -221,7 +221,7 @@ impl Hkdf { pub fn hkdf(md: Type, salt: &[u8], ikm: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { @@ -265,7 +265,7 @@ impl Hkdf { pub fn hkdf_optional_salt(md: Type, maybe_salt: Option<&[u8]>, ikm: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { @@ -314,7 +314,7 @@ impl Hkdf { pub fn hkdf_extract(md: Type, maybe_salt: Option<&[u8]>, ikm: &[u8], prk: &mut [u8]) -> Result<()> { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { @@ -360,7 +360,7 @@ impl Hkdf { pub fn hkdf_expand(md: Type, prk: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { @@ -382,7 +382,7 @@ impl Hkdf { pub fn pbkdf2_hmac(md: Type, password: &[u8], salt: &[u8], iterations: u32, key: &mut [u8]) -> Result<()> { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { diff --git a/mbedtls/src/lib.rs b/mbedtls/src/lib.rs index 7a8620528..f05ffeefe 100644 --- a/mbedtls/src/lib.rs +++ b/mbedtls/src/lib.rs @@ -27,7 +27,7 @@ mod wrapper_macros; // API // ============== pub mod bignum; -mod error; +pub mod error; pub use crate::error::{Error, Result}; pub mod cipher; pub mod ecp; diff --git a/mbedtls/src/pk/dsa/mod.rs b/mbedtls/src/pk/dsa/mod.rs index 1fc1d82e8..9c970ceaa 100644 --- a/mbedtls/src/pk/dsa/mod.rs +++ b/mbedtls/src/pk/dsa/mod.rs @@ -9,8 +9,9 @@ use crate::bignum::Mpi; use crate::hash::{MdInfo, Type as MdType}; use crate::pk::rfc6979::generate_rfc6979_nonce; -use crate::rng::Random; -use crate::{Error, Result}; +#[cfg(not(feature = "std"))] +use crate::Error; +use crate::{Result, error::codes}; use bit_vec::BitVec; use num_bigint::BigUint; @@ -27,11 +28,11 @@ pub struct DsaParams { impl DsaParams { pub fn from_components(p: Mpi, q: Mpi, g: Mpi) -> Result { if g > p || q > p { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } if p.modulo(&q)? != Mpi::new(1)? { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } Ok(Self { p, q, g }) @@ -65,11 +66,11 @@ const DSA_OBJECT_IDENTIFIER: &[u64] = &[1, 2, 840, 10040, 4, 1]; impl DsaPublicKey { pub fn from_components(params: DsaParams, y: Mpi) -> Result { if y < Mpi::new(1)? || y >= params.p { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } // Verify that y is of order q modulo p if y.mod_exp(¶ms.q, ¶ms.p)? != Mpi::new(1)? { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } Ok(Self { params, y }) } @@ -93,9 +94,9 @@ impl DsaPublicKey { Ok((p, q, g, y)) }) }) - .map_err(|_| Error::PkInvalidPubkey)?; + .map_err(|_| codes::PkInvalidPubkey)?; - let y = yasna::parse_der(&y.to_bytes(), |r| r.read_biguint()).map_err(|_| Error::PkInvalidPubkey)?; + let y = yasna::parse_der(&y.to_bytes(), |r| r.read_biguint()).map_err(|_| codes::PkInvalidPubkey)?; let p = Mpi::from_binary(&p.to_bytes_be()).expect("Success"); let q = Mpi::from_binary(&q.to_bytes_be()).expect("Success"); @@ -140,7 +141,7 @@ impl DsaPublicKey { Ok((r, s)) }) }) - .map_err(|_| Error::X509InvalidSignature)?; + .map_err(|_| codes::X509InvalidSignature)?; let r = Mpi::from_binary(&r.to_bytes_be()).expect("Success"); let s = Mpi::from_binary(&s.to_bytes_be()).expect("Success"); @@ -152,14 +153,14 @@ impl DsaPublicKey { let zero = Mpi::new(0)?; if r <= &zero || s <= &zero { - return Err(Error::X509InvalidSignature); + return Err(codes::X509InvalidSignature.into()); } let p = &self.params.p; let q = &self.params.q; if r >= q || s >= q { - return Err(Error::X509InvalidSignature); + return Err(codes::X509InvalidSignature.into()); } let m = reduce_mod_q(pre_hashed_message, q)?; @@ -176,7 +177,7 @@ impl DsaPublicKey { let gsm_ysr = (&gsm * &ysr)?.modulo(p)?; if &gsm_ysr.modulo(q)? != r { - return Err(Error::X509InvalidSignature); + return Err(codes::X509InvalidSignature.into()); } Ok(()) @@ -225,7 +226,7 @@ fn encode_dsa_signature(r: &Mpi, s: &Mpi) -> Result> { impl DsaPrivateKey { pub fn from_components(params: DsaParams, x: Mpi) -> Result { if x <= Mpi::new(1)? || x >= params.q { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } Ok(Self { params, x }) } @@ -257,9 +258,9 @@ impl DsaPrivateKey { Ok((p, q, g, x)) }) }) - .map_err(|_| Error::PkInvalidPubkey)?; + .map_err(|_| codes::PkInvalidPubkey)?; - let x = yasna::parse_der(&x, |r| r.read_biguint()).map_err(|_| Error::PkInvalidPubkey)?; + let x = yasna::parse_der(&x, |r| r.read_biguint()).map_err(|_| codes::PkInvalidPubkey)?; let p = Mpi::from_binary(&p.to_bytes_be()).expect("Success"); let q = Mpi::from_binary(&q.to_bytes_be()).expect("Success"); @@ -345,7 +346,7 @@ impl DsaPrivateKey { let zero = Mpi::new(0)?; if r == zero || s == zero { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } encode_dsa_signature(&r, &s) } diff --git a/mbedtls/src/pk/ec.rs b/mbedtls/src/pk/ec.rs index 2cc8a0151..0f35a6b90 100644 --- a/mbedtls/src/pk/ec.rs +++ b/mbedtls/src/pk/ec.rs @@ -9,7 +9,7 @@ use mbedtls_sys::ECDSA_MAX_LEN as MBEDTLS_ECDSA_MAX_LEN; use mbedtls_sys::*; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{IntoResult, Result, codes}; define!( #[c_ty(ecp_group_id)] @@ -80,7 +80,7 @@ define!( impl Ecdh { pub fn from_keys(private: &EcpKeypair, public: &EcpKeypair) -> Result { if public.inner.grp.id == ECP_DP_NONE || public.inner.grp.id != private.inner.grp.id { - return Err(Error::EcpBadInputData); + return Err(codes::EcpBadInputData.into()); } let mut ret = Self::init(); diff --git a/mbedtls/src/pk/mod.rs b/mbedtls/src/pk/mod.rs index 63c77575e..c60685e78 100644 --- a/mbedtls/src/pk/mod.rs +++ b/mbedtls/src/pk/mod.rs @@ -12,7 +12,7 @@ use mbedtls_sys::*; use mbedtls_sys::types::raw_types::c_void; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{Error, IntoResult, Result, codes}; use crate::hash::Type as MdType; use crate::private::UnsafeFrom; use crate::rng::Random; @@ -463,7 +463,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn custom_algo_id(&self) -> Result<&[u64]> { if self.pk_type() != Type::Custom { - return Err(Error::PkInvalidAlg); + return Err(codes::PkInvalidAlg.into()); } unsafe { @@ -474,7 +474,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn custom_public_key(&self) -> Result<&[u8]> { if self.pk_type() != Type::Custom { - return Err(Error::PkInvalidAlg); + return Err(codes::PkInvalidAlg.into()); } let ctx = self.inner.pk_ctx as *const CustomPkContext; @@ -483,13 +483,13 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn custom_private_key(&self) -> Result<&[u8]> { if self.pk_type() != Type::Custom { - return Err(Error::PkInvalidAlg); + return Err(codes::PkInvalidAlg.into()); } let ctx = self.inner.pk_ctx as *const CustomPkContext; unsafe { if (*ctx).sk.len() == 0 { - return Err(Error::PkTypeMismatch); + return Err(codes::PkTypeMismatch.into()); } Ok(&(*ctx).sk) } @@ -537,7 +537,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn curve(&self) -> Result { match self.pk_type() { Type::Eckey | Type::EckeyDh | Type::Ecdsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } unsafe { Ok((*(self.inner.pk_ctx as *const ecp_keypair)).grp.id.into()) } @@ -556,14 +556,14 @@ Please use `private_from_ec_scalar_with_rng` instead." EcGroupId::SecP256R1 => Ok(vec![1, 2, 840, 10045, 3, 1, 7]), EcGroupId::SecP384R1 => Ok(vec![1, 3, 132, 0, 34]), EcGroupId::SecP521R1 => Ok(vec![1, 3, 132, 0, 35]), - _ => Err(Error::OidNotFound), + _ => Err(codes::OidNotFound.into()), } } pub fn ec_group(&self) -> Result { match self.pk_type() { Type::Eckey | Type::EckeyDh | Type::Ecdsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } match self.curve()? { @@ -588,7 +588,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn ec_public(&self) -> Result { match self.pk_type() { Type::Eckey | Type::EckeyDh | Type::Ecdsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let q = &unsafe { (*(self.inner.pk_ctx as *const ecp_keypair)).Q }; @@ -598,7 +598,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn ec_private(&self) -> Result { match self.pk_type() { Type::Eckey | Type::EckeyDh | Type::Ecdsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let d = &unsafe { (*(self.inner.pk_ctx as *const ecp_keypair)).d }; @@ -608,7 +608,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_public_modulus(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut n = Mpi::new(0)?; @@ -631,7 +631,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_private_prime1(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut p = Mpi::new(0)?; @@ -654,7 +654,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_private_prime2(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut q = Mpi::new(0)?; @@ -677,7 +677,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_private_exponent(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut d = Mpi::new(0)?; @@ -700,7 +700,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_crt_dp(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut dp = Mpi::new(0)?; @@ -721,7 +721,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_crt_dq(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut dq = Mpi::new(0)?; @@ -742,7 +742,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_crt_qp(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut qp = Mpi::new(0)?; @@ -763,7 +763,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_public_exponent(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut e: [u8; 4] = [0, 0, 0, 0]; @@ -797,14 +797,14 @@ Please use `private_from_ec_scalar_with_rng` instead." if unsafe { (*ctx).padding == RAW_RSA_DECRYPT } { let olen = self.len() / 8; if plain.len() < olen { - return Err(Error::RsaOutputTooLarge); + return Err(codes::RsaOutputTooLarge.into()); } // Don't process outside of {2, ..., n-2} let nm1 = self.rsa_public_modulus()?.sub(&Mpi::new(1)?)?; let c_mpi = Mpi::from_binary(cipher)?; if c_mpi <= Mpi::new(1).unwrap() || c_mpi >= nm1 { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } unsafe { @@ -843,11 +843,11 @@ Please use `private_from_ec_scalar_with_rng` instead." label: &[u8], ) -> Result { if self.pk_type() != Type::Rsa { - return Err(Error::PkTypeMismatch); + return Err(codes::PkTypeMismatch.into()); } let ctx = self.inner.pk_ctx as *mut rsa_context; if unsafe { (*ctx).padding != RSA_PKCS_V21 } { - return Err(Error::RsaInvalidPadding); + return Err(codes::RsaInvalidPadding.into()); } let mut ret = 0usize; @@ -899,15 +899,15 @@ Please use `private_from_ec_scalar_with_rng` instead." label: &[u8], ) -> Result { if self.pk_type() != Type::Rsa { - return Err(Error::PkTypeMismatch); + return Err(codes::PkTypeMismatch.into()); } let ctx = self.inner.pk_ctx as *mut rsa_context; if unsafe { (*ctx).padding != RSA_PKCS_V21 } { - return Err(Error::RsaInvalidPadding); + return Err(codes::RsaInvalidPadding.into()); } let olen = self.len() / 8; if cipher.len() < olen { - return Err(Error::RsaOutputTooLarge); + return Err(codes::RsaOutputTooLarge.into()); } unsafe { @@ -943,21 +943,21 @@ Please use `private_from_ec_scalar_with_rng` instead." // If hash or sig are allowed with size 0 (&[]) then mbedtls will attempt to // auto-detect size and cause an invalid write. if hash.len() == 0 || sig.len() == 0 { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } match self.pk_type() { Type::Rsa | Type::RsaAlt | Type::RsassaPss => { if sig.len() < (self.len() / 8) { - return Err(Error::PkSigLenMismatch); + return Err(codes::PkSigLenMismatch.into()); } } Type::Eckey | Type::Ecdsa => { if sig.len() < ECDSA_MAX_LEN { - return Err(Error::PkSigLenMismatch); + return Err(codes::PkSigLenMismatch.into()); } } - _ => return Err(Error::PkSigLenMismatch), + _ => return Err(codes::PkSigLenMismatch.into()), } let mut ret = 0usize; unsafe { @@ -980,14 +980,14 @@ Please use `private_from_ec_scalar_with_rng` instead." // If hash or sig are allowed with size 0 (&[]) then mbedtls will attempt to // auto-detect size and cause an invalid write. if hash.len() == 0 || sig.len() == 0 { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } use crate::rng::RngCallbackMut; if self.pk_type() == Type::Ecdsa || self.pk_type() == Type::Eckey { if sig.len() < ECDSA_MAX_LEN { - return Err(Error::PkSigLenMismatch); + return Err(codes::PkSigLenMismatch.into()); } // RFC 6979 signature scheme @@ -1017,7 +1017,7 @@ Please use `private_from_ec_scalar_with_rng` instead." } else if self.pk_type() == Type::Rsa { // Reject sign_deterministic being use for PSS if unsafe { (*(self.inner.pk_ctx as *mut rsa_context)).padding } != RSA_PKCS_V15 { - return Err(Error::PkInvalidAlg); + return Err(codes::PkInvalidAlg.into()); } // This is a PKCSv1.5 signature which is already deterministic; just pass it to @@ -1025,7 +1025,7 @@ Please use `private_from_ec_scalar_with_rng` instead." return self.sign(md, hash, sig, rng); } else { // Some non-deterministic scheme - return Err(Error::PkInvalidAlg); + return Err(codes::PkInvalidAlg.into()); } } @@ -1033,7 +1033,7 @@ Please use `private_from_ec_scalar_with_rng` instead." // If hash or sig are allowed with size 0 (&[]) then mbedtls will attempt to // auto-detect size and cause an invalid write. if hash.len() == 0 || sig.len() == 0 { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } unsafe { @@ -1056,13 +1056,13 @@ Please use `private_from_ec_scalar_with_rng` instead." )?; ecdh.calc_secret(shared, rng) }, - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } } pub fn write_private_der<'buf>(&mut self, buf: &'buf mut [u8]) -> Result> { match unsafe { pk_write_key_der(&mut self.inner, buf.as_mut_ptr(), buf.len()).into_result() } { - Err(Error::Asn1BufTooSmall) => Ok(None), + Err(e) if e.low_level() == Some(codes::Asn1BufTooSmall) => Ok(None), Err(e) => Err(e), Ok(n) => Ok(Some(&buf[buf.len() - (n as usize)..])), } @@ -1074,7 +1074,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn write_private_pem<'buf>(&mut self, buf: &'buf mut [u8]) -> Result> { match unsafe { pk_write_key_pem(&mut self.inner, buf.as_mut_ptr(), buf.len()).into_result() } { - Err(Error::Base64BufferTooSmall) => Ok(None), + Err(e) if e.low_level() == Some(codes::Base64BufferTooSmall) => Ok(None), Err(e) => Err(e), Ok(n) => Ok(Some(&buf[buf.len() - (n as usize)..])), } @@ -1091,7 +1091,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn write_public_der<'buf>(&mut self, buf: &'buf mut [u8]) -> Result> { match unsafe { pk_write_pubkey_der(&mut self.inner, buf.as_mut_ptr(), buf.len()).into_result() } { - Err(Error::Asn1BufTooSmall) => Ok(None), + Err(e) if e.low_level() == Some(codes::Asn1BufTooSmall) => Ok(None), Err(e) => Err(e), Ok(n) => Ok(Some(&buf[buf.len() - (n as usize)..])), } @@ -1103,7 +1103,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn write_public_pem<'buf>(&mut self, buf: &'buf mut [u8]) -> Result> { match unsafe { pk_write_pubkey_pem(&mut self.inner, buf.as_mut_ptr(), buf.len()).into_result() } { - Err(Error::Base64BufferTooSmall) => Ok(None), + Err(e) if e.low_level() == Some(codes::Base64BufferTooSmall) => Ok(None), Err(e) => Err(e), Ok(n) => Ok(Some(&buf[buf.len() - (n as usize)..])), } @@ -1339,30 +1339,30 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi .unwrap(); pk.verify(digest, data, &signature[0..len]).unwrap(); - assert_eq!(pk.verify(digest, data, &[]).unwrap_err(), Error::PkBadInputData); - assert_eq!(pk.verify(digest, &[], &signature[0..len]).unwrap_err(), Error::PkBadInputData); + assert_eq!(pk.verify(digest, data, &[]).unwrap_err(), codes::PkBadInputData.into()); + assert_eq!(pk.verify(digest, &[], &signature[0..len]).unwrap_err(), codes::PkBadInputData.into()); let mut dummy_sig = []; assert_eq!( pk.sign(digest, data, &mut dummy_sig, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::PkBadInputData + codes::PkBadInputData.into() ); assert_eq!( pk.sign(digest, &[], &mut signature, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::PkBadInputData + codes::PkBadInputData.into() ); assert_eq!( pk.sign_deterministic(digest, data, &mut dummy_sig, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::PkBadInputData + codes::PkBadInputData.into() ); assert_eq!( pk.sign_deterministic(digest, &[], &mut signature, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::PkBadInputData + codes::PkBadInputData.into() ); } } @@ -1465,7 +1465,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi assert_eq!( pk.encrypt(b"test", &mut cipher, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::RsaInvalidPadding + codes::RsaInvalidPadding.into() ); } @@ -1504,7 +1504,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi b"WRONG_LABEL" ) .unwrap_err(), - Error::RsaInvalidPadding + codes::RsaInvalidPadding.into() ); } @@ -1520,7 +1520,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi assert_eq!( pk.sign(Type::Sha256, data, &mut signature, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::RsaInvalidPadding + codes::RsaInvalidPadding.into() ); } @@ -1543,7 +1543,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi }); assert_eq!( pk.verify(digest, data, &signature[0..len]).unwrap_err(), - Error::RsaInvalidPadding + codes::RsaInvalidPadding.into() ); } diff --git a/mbedtls/src/pkcs12/mod.rs b/mbedtls/src/pkcs12/mod.rs index 54e5a82fa..14161aa78 100644 --- a/mbedtls/src/pkcs12/mod.rs +++ b/mbedtls/src/pkcs12/mod.rs @@ -35,7 +35,7 @@ use crate::cipher::{Cipher, Decryption, Fresh, Traditional}; use crate::hash::{pbkdf_pkcs12, Hmac, MdInfo, Type as MdType}; use crate::pk::Pk; use crate::x509::Certificate; -use crate::Error as MbedtlsError; +use crate::error::{Error as MbedtlsError, codes}; // Constants for various object identifiers used in PKCS12: @@ -686,7 +686,7 @@ impl Pfx { let md_info: MdInfo = match md.into() { Some(md) => md, - None => return Err(Pkcs12Error::from(MbedtlsError::MdBadInputData)), + None => return Err(Pkcs12Error::from(MbedtlsError::from(codes::MdBadInputData))), }; if stored_mac.len() != md_info.size() { @@ -862,6 +862,7 @@ impl BERDecodable for Pfx { mod tests { use crate::mbedtls::pkcs12::{ASN1Error, ASN1ErrorKind, Pfx, Pkcs12Error}; + use crate::error::{codes, Error}; #[test] fn parse_shibboleth() { @@ -1024,7 +1025,7 @@ mod tests { let pfx = parsed_pfx.decrypt(&wrong_password, None); assert!(pfx.is_err()); - assert_eq!(pfx.unwrap_err(), Pkcs12Error::Crypto(crate::Error::CipherInvalidPadding)); + assert_eq!(pfx.unwrap_err(), Pkcs12Error::Crypto(Error::from(codes::CipherInvalidPadding))); let pfx = parsed_pfx.decrypt(&wrong_password_correct_padding, None); assert!(pfx.is_err()); diff --git a/mbedtls/src/private.rs b/mbedtls/src/private.rs index 33df9b6f9..abd7d9810 100644 --- a/mbedtls/src/private.rs +++ b/mbedtls/src/private.rs @@ -13,7 +13,7 @@ use mbedtls_sys::types::raw_types::c_char; use mbedtls_sys::types::raw_types::{c_int, c_uchar}; use mbedtls_sys::types::size_t; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{Error, IntoResult, Result, LoError, HiError}; pub trait UnsafeFrom where @@ -34,18 +34,22 @@ where const MAX_VECTOR_ALLOCATION: usize = 4 * 1024 * 1024; let mut vec = Vec::with_capacity(2048 /* big because of bug in x509write */); + + let is_buf_too_small = |e: &Error| match (e.high_level(), e.low_level()) { + (Some(HiError::EcpBufferTooSmall + | HiError::SslBufferTooSmall + | HiError::X509BufferTooSmall), _) + | (_, Some(LoError::Asn1BufTooSmall + | LoError::Base64BufferTooSmall + | LoError::MpiBufferTooSmall + | LoError::NetBufferTooSmall + | LoError::OidBufTooSmall)) => true, + _ => false, + }; + loop { match f(vec.as_mut_ptr(), vec.capacity()).into_result() { - Err(Error::Asn1BufTooSmall) - | Err(Error::Base64BufferTooSmall) - | Err(Error::EcpBufferTooSmall) - | Err(Error::MpiBufferTooSmall) - | Err(Error::NetBufferTooSmall) - | Err(Error::OidBufTooSmall) - | Err(Error::SslBufferTooSmall) - | Err(Error::X509BufferTooSmall) - if vec.capacity() < MAX_VECTOR_ALLOCATION => - { + Err(e) if is_buf_too_small(&e) && vec.capacity() < MAX_VECTOR_ALLOCATION => { let cap = vec.capacity(); vec.reserve(cap * 2) } diff --git a/mbedtls/src/ssl/async_io.rs b/mbedtls/src/ssl/async_io.rs index eca3d84ae..f275595d0 100644 --- a/mbedtls/src/ssl/async_io.rs +++ b/mbedtls/src/ssl/async_io.rs @@ -9,7 +9,7 @@ #![cfg(all(feature = "std", feature = "async"))] use crate::{ - error::{Error, Result}, + error::{Error, Result, codes}, ssl::{ context::Context, io::{IoCallback, IoCallbackUnsafe}, @@ -47,17 +47,17 @@ impl<'a, 'b, 'c, IO: AsyncRead + AsyncWrite + std::marker::Unpin + 'static> IoCa let io = Pin::new(&mut self.1); match io.poll_read(self.0, &mut buf) { Poll::Ready(Ok(())) => Ok(buf.filled().len()), - Poll::Ready(Err(_)) => Err(Error::NetRecvFailed), - Poll::Pending => Err(Error::SslWantRead), + Poll::Ready(Err(_)) => Err(codes::NetRecvFailed.into()), + Poll::Pending => Err(codes::SslWantRead.into()), } } fn send(&mut self, buf: &[u8]) -> Result { let io = Pin::new(&mut self.1); match io.poll_write(self.0, buf) { - Poll::Ready(Err(_)) => Err(Error::NetSendFailed), + Poll::Ready(Err(_)) => Err(codes::NetSendFailed.into()), Poll::Ready(Ok(n)) => Ok(n), - Poll::Pending => Err(Error::SslWantWrite), + Poll::Pending => Err(codes::SslWantWrite.into()), } } } @@ -76,11 +76,11 @@ impl Context { fn poll(mut self: Pin<&mut Self>, ctx: &mut TaskContext) -> std::task::Poll { self.0 .with_bio_async(ctx, |ssl_ctx| match ssl_ctx.handshake() { - Err(Error::SslWantRead) | Err(Error::SslWantWrite) => Poll::Pending, + Err(e) if matches!(e.high_level(), Some(codes::SslWantRead | codes::SslWantWrite)) => Poll::Pending, Err(e) => Poll::Ready(Err(e)), Ok(()) => Poll::Ready(Ok(())), }) - .unwrap_or(Poll::Ready(Err(Error::NetSendFailed))) + .unwrap_or(Poll::Ready(Err(codes::NetSendFailed.into()))) } } @@ -100,15 +100,15 @@ where } self.with_bio_async(cx, |ssl_ctx| match ssl_ctx.recv(buf.initialize_unfilled()) { - Err(Error::SslPeerCloseNotify) => Poll::Ready(Ok(())), - Err(Error::SslWantRead) => Poll::Pending, + Err(e) if e.high_level() == Some(codes::SslPeerCloseNotify) => Poll::Ready(Ok(())), + Err(e) if e.high_level() == Some(codes::SslWantRead) => Poll::Pending, Err(e) => Poll::Ready(Err(crate::private::error_to_io_error(e))), Ok(i) => { buf.advance(i); Poll::Ready(Ok(())) } }) - .unwrap_or_else(|| Poll::Ready(Err(crate::private::error_to_io_error(Error::NetRecvFailed)))) + .unwrap_or_else(|| Poll::Ready(Err(crate::private::error_to_io_error(Error::from(codes::NetRecvFailed))))) } } @@ -122,12 +122,11 @@ where } self.with_bio_async(cx, |ssl_ctx| match ssl_ctx.async_write(buf) { - Err(Error::SslPeerCloseNotify) => Poll::Ready(Ok(0)), - Err(Error::SslWantWrite) => Poll::Pending, + Err(e) if e.high_level() == Some(codes::SslPeerCloseNotify) => Poll::Ready(Ok(0)), + Err(e) if e.high_level() == Some(codes::SslWantWrite) => Poll::Pending, Err(e) => Poll::Ready(Err(crate::private::error_to_io_error(e))), Ok(i) => Poll::Ready(Ok(i)), - }) - .unwrap_or_else(|| Poll::Ready(Err(crate::private::error_to_io_error(Error::NetSendFailed)))) + }).unwrap_or_else(|| Poll::Ready(Err(crate::private::error_to_io_error(Error::from(codes::NetSendFailed))))) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { @@ -137,9 +136,9 @@ where match self .with_bio_async(cx, Context::flush_output) - .unwrap_or(Err(Error::NetSendFailed)) + .unwrap_or(Err(codes::NetSendFailed.into())) { - Err(Error::SslWantWrite) => Poll::Pending, + Err(e) if e.high_level() == Some(codes::SslWantWrite) => Poll::Pending, Err(e) => Poll::Ready(Err(crate::private::error_to_io_error(e))), Ok(()) => Poll::Ready(Ok(())), } @@ -152,9 +151,9 @@ where match self .with_bio_async(cx, Context::close_notify) - .unwrap_or(Err(Error::NetSendFailed)) + .unwrap_or(Err(codes::NetSendFailed.into())) { - Err(Error::SslWantRead) | Err(Error::SslWantWrite) => Poll::Pending, + Err(e) if matches!(e.high_level(), Some(codes::SslWantRead | codes::SslWantWrite)) => Poll::Pending, Err(e) => { self.drop_io(); Poll::Ready(Err(crate::private::error_to_io_error(e))) diff --git a/mbedtls/src/ssl/config.rs b/mbedtls/src/ssl/config.rs index 4b6ab4cb2..55ed27a70 100644 --- a/mbedtls/src/ssl/config.rs +++ b/mbedtls/src/ssl/config.rs @@ -20,7 +20,9 @@ use mbedtls_sys::*; use crate::alloc::List as MbedtlsList; #[cfg(not(feature = "std"))] use crate::alloc_prelude::*; -use crate::error::{Error, IntoResult, Result}; +#[cfg(feature = "std")] +use crate::error::Error; +use crate::error::{Result, IntoResult, codes}; use crate::pk::dhparam::Dhm; use crate::pk::Pk; use crate::private::UnsafeFrom; @@ -117,11 +119,7 @@ impl NullTerminatedStrList { }; for item in list { - ret.c.push( - ::std::ffi::CString::new(*item) - .map_err(|_| Error::SslBadInputData)? - .into_raw(), - ); + ret.c.push(::std::ffi::CString::new(*item).map_err(|_| crate::error::codes::SslBadInputData)?.into_raw()); } ret.c.push(core::ptr::null_mut()); @@ -267,7 +265,7 @@ impl Config { Version::Tls1_1 => 2, Version::Tls1_2 => 3, _ => { - return Err(Error::SslBadHsProtocolVersion); + return Err(codes::SslBadHsProtocolVersion.into()); } }; @@ -282,7 +280,7 @@ impl Config { Version::Tls1_1 => 2, Version::Tls1_2 => 3, _ => { - return Err(Error::SslBadHsProtocolVersion); + return Err(codes::SslBadHsProtocolVersion.into()); } }; unsafe { ssl_conf_max_version(self.into(), 3, minor) }; @@ -323,7 +321,7 @@ impl Config { pub fn push_cert(&mut self, own_cert: Arc>, own_pk: Arc) -> Result<()> { if own_cert.is_empty() { - return Err(Error::SslBadInputData); + return Err(crate::error::codes::SslBadInputData.into()); } // Need to ensure own_cert/pk_key outlive the config. self.own_cert.push(own_cert.clone()); diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index 349a34e99..4319308bd 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -17,7 +17,7 @@ use mbedtls_sys::*; use crate::alloc::List as MbedtlsList; #[cfg(not(feature = "std"))] use crate::alloc_prelude::*; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{Error, Result, IntoResult, codes}; use crate::pk::Pk; use crate::private::UnsafeFrom; use crate::ssl::config::{AuthMode, Config, Version}; @@ -243,10 +243,7 @@ impl Context { // c-mbedtls's buffer, so we need to return size of bytes that has been buffered. // Since we know before this call `out_left` was 0, all buffer (with in the MBEDTLS_SSL_OUT_CONTENT_LEN part) is // buffered - Err(Error::SslWantWrite) => Ok(std::cmp::min( - unsafe { ssl_get_max_out_record_payload((&*self).into()).into_result()? as usize }, - buf.len(), - )), + Err(e) if e.high_level() == Some(codes::SslWantWrite) => Ok(std::cmp::min(unsafe { ssl_get_max_out_record_payload((&*self).into()).into_result()? as usize }, buf.len())), res => res, } } @@ -305,9 +302,8 @@ impl Context { pub fn handshake(&mut self) -> Result<()> { match self.inner_handshake() { Ok(()) => Ok(()), - Err(Error::SslWantRead) => Err(Error::SslWantRead), - Err(Error::SslWantWrite) => Err(Error::SslWantWrite), - Err(Error::SslHelloVerifyRequired) => { + Err(e) if matches!(e.high_level(), Some(codes::SslWantRead | codes::SslWantWrite)) => Err(e), + Err(e) if matches!(e.high_level(), Some(codes::SslHelloVerifyRequired)) => { unsafe { // `ssl_session_reset` resets the client ID but the user will call handshake // again in this case and the client ID is required for a DTLS connection setup @@ -325,7 +321,7 @@ impl Context { self.set_client_transport_id(&client_id)?; } } - Err(Error::SslHelloVerifyRequired) + Err(codes::SslHelloVerifyRequired.into()) } Err(e) => { self.close(); @@ -349,7 +345,7 @@ impl Context { #[cfg(not(feature = "std"))] fn set_hostname(&mut self, hostname: Option<&str>) -> Result<()> { match hostname { - Some(_) => Err(Error::SslBadInputData), + Some(_) => Err(codes::SslBadInputData.into()), None => Ok(()), } } @@ -357,7 +353,7 @@ impl Context { #[cfg(feature = "std")] fn set_hostname(&mut self, hostname: Option<&str>) -> Result<()> { if let Some(s) = hostname { - let cstr = ::std::ffi::CString::new(s).map_err(|_| Error::SslBadInputData)?; + let cstr = ::std::ffi::CString::new(s).map_err(|_| Error::from(codes::SslBadInputData))?; unsafe { ssl_set_hostname(self.into(), cstr.as_ptr()).into_result().map(|_| ()) } } else { Ok(()) @@ -442,7 +438,7 @@ impl Context { /// pub fn ciphersuite(&self) -> Result { if self.handle().session.is_null() { - return Err(Error::SslBadInputData); + return Err(codes::SslBadInputData.into()); } Ok(unsafe { self.handle().session.as_ref().unwrap().ciphersuite as u16 }) @@ -450,7 +446,7 @@ impl Context { pub fn peer_cert(&self) -> Result>> { if self.handle().session.is_null() { - return Err(Error::SslBadInputData); + return Err(codes::SslBadInputData.into()); } unsafe { @@ -459,7 +455,7 @@ impl Context { // variable for that. let peer_cert: &MbedtlsList = UnsafeFrom::from(&((*self.handle().session).peer_cert) as *const *mut x509_crt as *const *const x509_crt) - .ok_or(Error::SslBadInputData)?; + .ok_or(codes::SslBadInputData.into())?; Ok(Some(peer_cert)) } } @@ -559,7 +555,7 @@ impl HandshakeContext { pub fn set_authmode(&mut self, am: AuthMode) -> Result<()> { if self.inner.handshake as *const _ == ::core::ptr::null() { - return Err(Error::SslBadInputData); + return Err(codes::SslBadInputData.into()); } unsafe { ssl_set_hs_authmode(self.into(), am as i32) } @@ -569,7 +565,7 @@ impl HandshakeContext { pub fn set_ca_list(&mut self, chain: Option>>, crl: Option>) -> Result<()> { // mbedtls_ssl_set_hs_ca_chain does not check for NULL handshake. if self.inner.handshake as *const _ == ::core::ptr::null() { - return Err(Error::SslBadInputData); + return Err(codes::SslBadInputData.into()); } // This will override current handshake CA chain. @@ -596,7 +592,7 @@ impl HandshakeContext { pub fn push_cert(&mut self, chain: Arc>, key: Arc) -> Result<()> { // mbedtls_ssl_set_hs_own_cert does not check for NULL handshake. if self.inner.handshake as *const _ == ::core::ptr::null() || chain.is_empty() { - return Err(Error::SslBadInputData); + return Err(codes::SslBadInputData.into()); } // This will append provided certificate pointers in internal structures. diff --git a/mbedtls/src/ssl/io.rs b/mbedtls/src/ssl/io.rs index 237468e6d..ba81fbdf1 100644 --- a/mbedtls/src/ssl/io.rs +++ b/mbedtls/src/ssl/io.rs @@ -26,7 +26,7 @@ use mbedtls_sys::types::size_t; use super::context::Context; #[cfg(feature = "std")] -use crate::error::Error; +use crate::error::{Error, codes}; use crate::error::Result; /// A direct representation of the `mbedtls_ssl_send_t` and `mbedtls_ssl_recv_t` @@ -122,15 +122,15 @@ impl IoCallback for IO { impl IoCallback for IO { fn recv(&mut self, buf: &mut [u8]) -> Result { self.read(buf).map_err(|e| match e { - ref e if e.kind() == std::io::ErrorKind::WouldBlock => Error::SslWantRead, - _ => Error::NetRecvFailed, + ref e if e.kind() == std::io::ErrorKind::WouldBlock => Error::from(codes::SslWantRead), + _ => Error::from(codes::NetRecvFailed) }) } fn send(&mut self, buf: &[u8]) -> Result { self.write(buf).map_err(|e| match e { - ref e if e.kind() == std::io::ErrorKind::WouldBlock => Error::SslWantWrite, - _ => Error::NetSendFailed, + ref e if e.kind() == std::io::ErrorKind::WouldBlock => Error::from(codes::SslWantWrite), + _ => Error::from(codes::NetSendFailed) }) } } @@ -162,13 +162,13 @@ impl Io for ConnectedUdpSocket { fn recv(&mut self, buf: &mut [u8]) -> Result { match self.socket.recv(buf) { Ok(i) => Ok(i), - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(Error::SslWantRead), - Err(_) => Err(Error::NetRecvFailed), + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(codes::SslWantRead.into()), + Err(_) => Err(codes::NetRecvFailed.into()) } } fn send(&mut self, buf: &[u8]) -> Result { - self.socket.send(buf).map_err(|_| Error::NetSendFailed) + self.socket.send(buf).map_err(|_| codes::NetSendFailed.into()) } } @@ -191,8 +191,8 @@ impl> Io for Context { impl> Read for Context { fn read(&mut self, buf: &mut [u8]) -> IoResult { match self.recv(buf) { - Err(Error::SslPeerCloseNotify) => Ok(0), - Err(Error::SslWantRead) | Err(Error::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()), + Err(e) if e.high_level() == Some(codes::SslPeerCloseNotify) => Ok(0), + Err(e) if matches!(e.high_level(), Some(codes::SslWantRead | codes::SslWantWrite)) => Err(IoErrorKind::WouldBlock.into()), Err(e) => Err(crate::private::error_to_io_error(e)), Ok(i) => Ok(i), } @@ -208,8 +208,8 @@ impl> Read for Context { impl> Write for Context { fn write(&mut self, buf: &[u8]) -> IoResult { match self.send(buf) { - Err(Error::SslPeerCloseNotify) => Ok(0), - Err(Error::SslWantRead) | Err(Error::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()), + Err(e) if e.high_level() == Some(codes::SslPeerCloseNotify) => Ok(0), + Err(e) if matches!(e.high_level(), Some(codes::SslWantRead | codes::SslWantWrite)) => Err(IoErrorKind::WouldBlock.into()), Err(e) => Err(crate::private::error_to_io_error(e)), Ok(i) => Ok(i), } diff --git a/mbedtls/src/x509/certificate.rs b/mbedtls/src/x509/certificate.rs index 80cb94754..2cf10de71 100644 --- a/mbedtls/src/x509/certificate.rs +++ b/mbedtls/src/x509/certificate.rs @@ -16,7 +16,7 @@ use mbedtls_sys::*; use crate::alloc::{mbedtls_calloc, Box as MbedtlsBox, CString, List as MbedtlsList}; #[cfg(not(feature = "std"))] use crate::alloc_prelude::*; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{Error, IntoResult, Result, codes}; use crate::hash::Type as MdType; use crate::pk::Pk; use crate::private::UnsafeFrom; @@ -78,7 +78,7 @@ fn x509_buf_to_vec(buf: &x509_buf) -> Vec { fn x509_time_to_time(tm: &x509_time) -> Result