From 9b0dbdb36716003d4c2c6e3ede51bdaf3231dd85 Mon Sep 17 00:00:00 2001 From: mridul <104053102+mridul-manohar@users.noreply.github.com> Date: Thu, 14 Nov 2024 23:14:16 +0530 Subject: [PATCH] Porting combined errors from mbedtls changes (#372) Due to downgrade back to v2.28 (from ~3.X), the enhancement in PR #271 was lost in the mbedtls 2.8 branch being used. we need to port these combined errors from mbedtls changes back to the mbedtls 2.8 branch and build with the latest upgrade in rust toolchain version to apply the enhancement. This PR ports said changes to a latest branch forked from mbedtls master and builds the same on latest rustc 1.83.0-nightly (26d8e9255 2024-10-11) version. --- Cargo.lock | 119 ++++----- mbedtls/Cargo.toml | 6 +- mbedtls/src/bignum/mod.rs | 21 +- mbedtls/src/cipher/raw/mod.rs | 12 +- mbedtls/src/ecp/mod.rs | 24 +- mbedtls/src/error.rs | 404 +++++++++++++++++++++---------- mbedtls/src/hash/mod.rs | 28 +-- mbedtls/src/lib.rs | 4 +- mbedtls/src/pk/dsa/mod.rs | 32 +-- mbedtls/src/pk/ec.rs | 4 +- mbedtls/src/pk/mod.rs | 107 ++++---- mbedtls/src/pkcs12/mod.rs | 10 +- mbedtls/src/private.rs | 29 ++- mbedtls/src/ssl/async_io.rs | 34 +-- mbedtls/src/ssl/config.rs | 10 +- mbedtls/src/ssl/context.rs | 27 +-- mbedtls/src/ssl/io.rs | 30 ++- mbedtls/src/x509/certificate.rs | 47 ++-- mbedtls/src/x509/csr.rs | 12 +- mbedtls/tests/async_session.rs | 18 +- mbedtls/tests/bignum.rs | 6 +- mbedtls/tests/client_server.rs | 26 +- mbedtls/tests/ec.rs | 4 +- mbedtls/tests/hyper.rs | 2 +- mbedtls/tests/rsa.rs | 4 +- mbedtls/tests/ssl_conf_ca_cb.rs | 8 +- mbedtls/tests/ssl_conf_verify.rs | 8 +- 27 files changed, 586 insertions(+), 450 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 017c5735b..a0a22c46e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "anes" @@ -31,8 +31,8 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4655ae1a7b0cdf149156f780c5bf3f1352bc53cbd9e0a361a7ef7b22947e965" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "syn 1.0.64", ] @@ -72,12 +72,12 @@ dependencies = [ "log 0.4.8", "peeking_take_while", "prettyplease", - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "regex", "rustc-hash", "shlex", - "syn 2.0.32", + "syn 2.0.87", "which", ] @@ -429,8 +429,8 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "syn 1.0.64", ] @@ -679,7 +679,7 @@ checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" [[package]] name = "mbedtls" -version = "0.12.3" +version = "0.13.0" dependencies = [ "async-stream", "bit-vec", @@ -733,7 +733,7 @@ dependencies = [ "lazy_static", "libc", "libz-sys", - "quote 1.0.33", + "quote", "syn 1.0.64", ] @@ -915,44 +915,26 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" dependencies = [ - "proc-macro2 1.0.66", - "syn 2.0.32", + "proc-macro2", + "syn 2.0.87", ] [[package]] name = "proc-macro2" -version = "0.4.30" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759" -dependencies = [ - "unicode-xid 0.1.0", -] - -[[package]] -name = "proc-macro2" -version = "1.0.66" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "0.6.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce23b6b870e8f94f81fb0a363d65d86675884b34a09043c81e5562f11c1f8e1" -dependencies = [ - "proc-macro2 0.4.30", -] - -[[package]] -name = "quote" -version = "1.0.33" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ - "proc-macro2 1.0.66", + "proc-macro2", ] [[package]] @@ -1108,9 +1090,9 @@ checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" [[package]] name = "serde" -version = "1.0.101" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9796c9b7ba2ffe7a9ce53c2287dfc48080f4b2b362fcc245a259b3a7201119dd" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] @@ -1137,13 +1119,13 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.70" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3525a779832b08693031b8ecfb0de81cd71cfd3812088fafe9a7496789572124" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ - "proc-macro2 0.4.30", - "quote 0.6.13", - "syn 0.14.9", + "proc-macro2", + "quote", + "syn 2.0.87", ] [[package]] @@ -1178,36 +1160,25 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" -[[package]] -name = "syn" -version = "0.14.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "261ae9ecaa397c42b960649561949d69311f08eeaea86a65696e6e46517cf741" -dependencies = [ - "proc-macro2 0.4.30", - "quote 0.6.13", - "unicode-xid 0.1.0", -] - [[package]] name = "syn" version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fd9d1e9976102a03c542daa2eff1b43f9d72306342f3f8b3ed5fb8908195d6f" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", - "unicode-xid 0.2.1", + "proc-macro2", + "quote", + "unicode-xid", ] [[package]] name = "syn" -version = "2.0.32" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "unicode-ident", ] @@ -1269,8 +1240,8 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "syn 1.0.64", ] @@ -1292,8 +1263,8 @@ version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", + "proc-macro2", + "quote", "syn 1.0.64", ] @@ -1357,12 +1328,6 @@ dependencies = [ "tinyvec", ] -[[package]] -name = "unicode-xid" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc" - [[package]] name = "unicode-xid" version = "0.2.1" @@ -1427,9 +1392,9 @@ dependencies = [ "bumpalo", "log 0.4.8", "once_cell", - "proc-macro2 1.0.66", - "quote 1.0.33", - "syn 2.0.32", + "proc-macro2", + "quote", + "syn 2.0.87", "wasm-bindgen-shared", ] @@ -1439,7 +1404,7 @@ version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" dependencies = [ - "quote 1.0.33", + "quote", "wasm-bindgen-macro-support", ] @@ -1449,9 +1414,9 @@ version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" dependencies = [ - "proc-macro2 1.0.66", - "quote 1.0.33", - "syn 2.0.32", + "proc-macro2", + "quote", + "syn 2.0.87", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/mbedtls/Cargo.toml b/mbedtls/Cargo.toml index 61b3068bd..4f1003e35 100644 --- a/mbedtls/Cargo.toml +++ b/mbedtls/Cargo.toml @@ -2,7 +2,7 @@ name = "mbedtls" # We jumped from v0.9 to v0.12 because v0.10 and v0.11 were based on mbedtls 3.X, which # we decided not to support. -version = "0.12.3" +version = "0.13.0" authors = ["Jethro Beekman "] build = "build.rs" edition = "2018" @@ -24,8 +24,8 @@ features = ["x509", "ssl"] [dependencies] bitflags = "1" -serde = { version = "1.0.7", default-features = false, features = ["alloc"] } -serde_derive = "1.0.7" +serde = { version = "1.0.214", default-features = false, features = ["alloc"] } +serde_derive = "1.0.214" byteorder = { version = "1.0.0", default-features = false } yasna = { version = "0.2", optional = true, features = [ "num-bigint", diff --git a/mbedtls/src/bignum/mod.rs b/mbedtls/src/bignum/mod.rs index 0df9e7bc6..6be52df49 100644 --- a/mbedtls/src/bignum/mod.rs +++ b/mbedtls/src/bignum/mod.rs @@ -6,7 +6,8 @@ * option. This file may not be copied, modified, or distributed except * according to those terms. */ -use crate::error::{Error, IntoResult, Result}; +use crate::error::Error; +use crate::error::{codes, IntoResult, Result}; use mbedtls_sys::*; #[cfg(not(feature = "std"))] @@ -161,7 +162,7 @@ impl Mpi { pub fn as_u32(&self) -> Result { if self.bit_length()? > 32 { // Not exactly correct but close enough - return Err(Error::MpiBufferTooSmall); + return Err(codes::MpiBufferTooSmall.into()); } Ok(self.get_limb(0) as u32) @@ -183,7 +184,7 @@ impl Mpi { let r = unsafe { mpi_write_string(&self.inner, radix, ::core::ptr::null_mut(), 0, &mut olen) }; if r != ERR_MPI_BUFFER_TOO_SMALL { - return Err(Error::from_mbedtls_code(r)); + return Err(r.into()); } let mut buf = vec![0u8; olen]; @@ -264,7 +265,7 @@ impl Mpi { let zero = Mpi::new(0)?; if self < &zero || self >= p { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } if self == &zero { return Ok(zero); @@ -273,12 +274,12 @@ impl Mpi { // This ignores p=2 (for which this algorithm is valid), as not // cryptographically interesting. if p.get_bit(0) == false || p <= &zero { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } if self.jacobi(p)? != 1 { // a is not a quadratic residue mod p - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } if (p % 4)?.as_u32()? == 3 { @@ -325,7 +326,7 @@ impl Mpi { bo = bo.mod_exp(&two, p)?; m += 1; if m >= r { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } } @@ -358,7 +359,7 @@ impl Mpi { let one = Mpi::new(1)?; if self < &zero || n < &zero || n.get_bit(0) == false { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } let mut x = self.modulo(n)?; @@ -431,7 +432,7 @@ impl Mpi { pub(super) fn mpi_inner_eq_const_time(x: &mpi, y: &mpi) -> core::prelude::v1::Result { match mpi_inner_cmp_const_time(x, y) { Ok(order) => Ok(order == Ordering::Equal), - Err(Error::MpiBadInputData) => Ok(false), + Err(e) if e == codes::MpiBadInputData.into() => Ok(false), Err(e) => Err(e), } } @@ -779,7 +780,7 @@ mod tests { ]) .unwrap(); assert_eq!(mpi3.less_than_const_time(&mpi3), Ok(false)); - assert_eq!(mpi2.less_than_const_time(&mpi3), Err(Error::MpiBadInputData)); + assert_eq!(mpi2.less_than_const_time(&mpi3), Err(codes::MpiBadInputData.into())); } #[test] diff --git a/mbedtls/src/cipher/raw/mod.rs b/mbedtls/src/cipher/raw/mod.rs index 2a25e3435..5f260ba52 100644 --- a/mbedtls/src/cipher/raw/mod.rs +++ b/mbedtls/src/cipher/raw/mod.rs @@ -8,7 +8,7 @@ use mbedtls_sys::*; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{codes, IntoResult, Result}; mod serde; @@ -254,7 +254,7 @@ impl Cipher { }; if out_data.len() < required_size { - return Err(Error::CipherFullBlockExpected); + return Err(codes::CipherFullBlockExpected.into()); } let mut olen = 0; @@ -274,7 +274,7 @@ impl Cipher { pub fn finish(&mut self, out_data: &mut [u8]) -> Result { // Check that minimum required space is available in out_data buffer if out_data.len() < self.block_size() { - return Err(Error::CipherFullBlockExpected); + return Err(codes::CipherFullBlockExpected.into()); } let mut olen = 0; @@ -337,7 +337,7 @@ impl Cipher { .checked_sub(tag_len) .map_or(true, |cipher_len| cipher_len < plain.len()) { - return Err(Error::CipherBadInputData); + return Err(codes::CipherBadInputData.into()); } let iv = self.inner.iv; @@ -371,7 +371,7 @@ impl Cipher { .checked_sub(tag_len) .map_or(true, |cipher_len| plain.len() < cipher_len) { - return Err(Error::CipherBadInputData); + return Err(codes::CipherBadInputData.into()); } let iv = self.inner.iv; @@ -470,7 +470,7 @@ impl Cipher { pub fn cmac(&mut self, key: &[u8], data: &[u8], out_data: &mut [u8]) -> Result<()> { // Check that out_data buffer has enough space if out_data.len() < self.block_size() { - return Err(Error::CipherFullBlockExpected); + return Err(codes::CipherFullBlockExpected.into()); } self.reset()?; unsafe { diff --git a/mbedtls/src/ecp/mod.rs b/mbedtls/src/ecp/mod.rs index 6b6fffd7c..63b56e34c 100644 --- a/mbedtls/src/ecp/mod.rs +++ b/mbedtls/src/ecp/mod.rs @@ -6,7 +6,7 @@ * option. This file may not be copied, modified, or distributed except * according to those terms. */ -use crate::error::{Error, IntoResult, Result}; +use crate::error::{codes, Error, IntoResult, Result}; use core::convert::TryFrom; use mbedtls_sys::*; @@ -107,7 +107,7 @@ impl EcGroup { || &order <= &zero || (&a == &zero && &b == &zero) { - return Err(Error::EcpBadInputData); + return Err(codes::EcpBadInputData.into()); } // Compute `order - 2`, needed below. @@ -128,7 +128,7 @@ impl EcGroup { Test that the provided generator satisfies the curve equation */ if unsafe { ecp_check_pubkey(&ret.inner, &ret.inner.G) } != 0 { - return Err(Error::EcpBadInputData); + return Err(codes::EcpBadInputData.into()); } /* @@ -155,7 +155,7 @@ impl EcGroup { let is_zero = unsafe { ecp_is_zero(&g_m.inner as *const ecp_point as *mut ecp_point) }; if is_zero != 1 { - return Err(Error::EcpBadInputData); + return Err(codes::EcpBadInputData.into()); } Ok(ret) @@ -193,7 +193,7 @@ impl EcGroup { EcGroupId::Curve25519 => Ok(8), EcGroupId::Curve448 => Ok(4), // Requires a point-counting algorithm such as SEA. - EcGroupId::None => Err(Error::EcpFeatureUnavailable), + EcGroupId::None => Err(codes::EcpFeatureUnavailable.into()), _ => Ok(1), } } @@ -206,7 +206,7 @@ impl EcGroup { match unsafe { ecp_check_pubkey(&self.inner, &point.inner) } { 0 => Ok(true), ERR_ECP_INVALID_KEY => Ok(false), - err => Err(Error::from_mbedtls_code(err)), + err => Err(err.into()), } } } @@ -249,7 +249,7 @@ impl EcPoint { } pub fn from_binary(group: &EcGroup, bin: &[u8]) -> Result { - let prefix = *bin.get(0).ok_or(Error::EcpBadInputData)?; + let prefix = *bin.get(0).ok_or(Error::from(codes::EcpBadInputData))?; if prefix == 0x02 || prefix == 0x03 { // Compressed point, which mbedtls does not understand @@ -260,7 +260,7 @@ impl EcPoint { let b = group.b()?; if bin.len() != (p.byte_length()? + 1) { - return Err(Error::EcpBadInputData); + return Err(codes::EcpBadInputData.into()); } let x = Mpi::from_binary(&bin[1..]).unwrap(); @@ -317,7 +317,7 @@ impl EcPoint { match unsafe { ecp_is_zero(&self.inner as *const ecp_point as *mut ecp_point) } { 0 => Ok(false), 1 => Ok(true), - _ => Err(Error::EcpInvalidKey), + _ => Err(codes::EcpInvalidKey.into()), } } @@ -402,11 +402,11 @@ Please use `mul_with_rng` instead." let mut ret = Self::init(); if group.contains_point(&pt1)? == false { - return Err(Error::EcpInvalidKey); + return Err(codes::EcpInvalidKey.into()); } if group.contains_point(&pt2)? == false { - return Err(Error::EcpInvalidKey); + return Err(codes::EcpInvalidKey.into()); } unsafe { @@ -430,7 +430,7 @@ Please use `mul_with_rng` instead." match r { 0 => Ok(true), ERR_ECP_BAD_INPUT_DATA => Ok(false), - x => Err(Error::from_mbedtls_code(x)), + x => Err(x.into()), } } diff --git a/mbedtls/src/error.rs b/mbedtls/src/error.rs index 690a7804e..935538a45 100644 --- a/mbedtls/src/error.rs +++ b/mbedtls/src/error.rs @@ -8,6 +8,7 @@ use core::convert::Infallible; use core::fmt; +use core::ops::BitOr; use core::str::Utf8Error; #[cfg(feature = "std")] use std::error::Error as StdError; @@ -23,6 +24,11 @@ pub trait IntoResult: Sized { } } +pub mod codes { + pub use super::HiError::*; + pub use super::LoError::*; +} + // This is intended not to overlap with mbedtls error codes. Utf8Error is // generated in the bindings when converting to rust UTF-8 strings. Only in rare // circumstances (callbacks from mbedtls to rust) do we need to pass a Utf8Error @@ -30,59 +36,102 @@ pub trait IntoResult: Sized { pub const ERR_UTF8_INVALID: c_int = -0x10000; macro_rules! error_enum { - {enum $n:ident {$($rust:ident = $c:ident,)*}} => { - #[derive(Debug, Eq, PartialEq)] - pub enum $n { + { + const MASK: c_int = $mask:literal; + enum $error:ident {$($rust:ident = $c:ident,)*} + } => { + #[non_exhaustive] + #[derive(Debug, Eq, PartialEq, Copy, Clone)] + pub enum $error { $($rust,)* - Other(c_int), - Utf8Error(Option), - // Stable-Rust equivalent of `#[non_exhaustive]` attribute. This - // value should never be used by users of this crate! - #[doc(hidden)] - __Nonexhaustive, + Unknown(c_int) } - impl IntoResult for c_int { - fn into_result(self) -> Result { - let err_code = match self { - _ if self >= 0 => return Ok(self), - ERR_UTF8_INVALID => return Err(Error::Utf8Error(None)), - _ => -self, - }; - let (high_level_code, low_level_code) = (err_code & 0xFF80, err_code & 0x7F); - Err($n::from_mbedtls_code(if high_level_code > 0 { -high_level_code } else { -low_level_code })) + impl From for $error { + fn from(code: c_int) -> $error { + // check against mask here (not in match block) to make it compile-time + $(const $c: c_int = $error::assert_in_mask(::mbedtls_sys::$c);)* + match -code { + $($c => return $error::$rust),*, + _ => return $error::Unknown(-code) + } } } - impl $n { - pub fn from_mbedtls_code(code: c_int) -> Self { - match code { - $(::mbedtls_sys::$c => $n::$rust),*, - _ => $n::Other(code) + impl From<$error> for c_int { + fn from(error: $error) -> c_int { + match error { + $($error::$rust => return ::mbedtls_sys::$c,)* + $error::Unknown(code) => return code, } } + } - pub fn as_str(&self) -> &'static str { - match self { - $(&$n::$rust => concat!("mbedTLS error ",stringify!($n::$rust)),)* - &$n::Other(_) => "mbedTLS unknown error", - &$n::Utf8Error(_) => "error converting to UTF-8", - &$n::__Nonexhaustive => unreachable!("__Nonexhaustive value should not be instantiated"), - } + impl $error { + const fn mask() -> c_int { + $mask } - pub fn to_int(&self) -> c_int { - match *self { - $($n::$rust => ::mbedtls_sys::$c,)* - $n::Other(code) => code, - $n::Utf8Error(_) => ERR_UTF8_INVALID, - $n::__Nonexhaustive => unreachable!("__Nonexhaustive value should not be instantiated"), + const fn assert_in_mask(val: c_int) -> c_int { + assert!((-val & !Self::mask()) == 0); + val + } + + pub fn as_str(&self)-> &'static str { + match self { + $($error::$rust => concat!("mbedTLS error ", stringify!($error::$rust)),)* + $error::Unknown(_) => concat!("mbedTLS unknown ", stringify!($error), " error") } } } }; } +#[derive(Debug, Eq, PartialEq)] +pub enum Error { + HighLevel(HiError), + LowLevel(LoError), + HighAndLowLevel(HiError, LoError), + Other(c_int), + Utf8Error(Option), +} + +impl Error { + pub fn low_level(&self) -> Option { + match self { + Error::LowLevel(error) | Error::HighAndLowLevel(_, error) => Some(*error), + _ => None, + } + } + + pub fn high_level(&self) -> Option { + match self { + Error::HighLevel(error) | Error::HighAndLowLevel(error, _) => Some(*error), + _ => None, + } + } + + pub fn as_str(&self) -> &'static str { + match &self { + &Error::HighLevel(e) => e.as_str(), + &Error::LowLevel(e) => e.as_str(), + &Error::HighAndLowLevel(e, _) => e.as_str(), + &Error::Other(_) => "mbedTLS unknown error", + &Error::Utf8Error(_) => "error converting to UTF-8", + } + } + + pub fn to_int(&self) -> c_int { + match self { + &Error::HighLevel(error) => error.into(), + &Error::LowLevel(error) => error.into(), + &Error::HighAndLowLevel(hl_error, ll_error) => c_int::from(hl_error) + c_int::from(ll_error), + &Error::Other(error) => error, + &Error::Utf8Error(_) => ERR_UTF8_INVALID, + } + } +} + impl From for Error { fn from(e: Utf8Error) -> Error { Error::Utf8Error(Some(e)) @@ -95,57 +144,77 @@ impl From for Error { } } +impl From for Error { + fn from(error: LoError) -> Error { + Error::LowLevel(error) + } +} + +impl From for Error { + fn from(error: HiError) -> Error { + Error::HighLevel(error) + } +} + +impl BitOr for HiError { + type Output = Error; + fn bitor(self, rhs: LoError) -> Self::Output { + Error::HighAndLowLevel(self, rhs) + } +} +impl BitOr for LoError { + type Output = Error; + fn bitor(self, rhs: HiError) -> Self::Output { + Error::HighAndLowLevel(rhs, self) + } +} + +impl From for Error { + fn from(x: c_int) -> Error { + let (high_level_code, low_level_code) = (-x & HiError::mask(), -x & LoError::mask()); + if -x & (HiError::mask() | LoError::mask()) != -x || x >= 0 { + Error::Other(x) + } else if high_level_code == 0 { + Error::LowLevel(low_level_code.into()) + } else if low_level_code == 0 { + Error::HighLevel(high_level_code.into()) + } else { + Error::HighAndLowLevel(high_level_code.into(), low_level_code.into()) + } + } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - &Error::Utf8Error(Some(ref e)) => f.write_fmt(format_args!("Error converting to UTF-8: {}", e)), - &Error::Utf8Error(None) => f.write_fmt(format_args!("Error converting to UTF-8")), - &Error::Other(i) => f.write_fmt(format_args!("mbedTLS unknown error ({})", i)), - &Error::__Nonexhaustive => unreachable!("__Nonexhaustive value should not be instantiated"), - e @ _ => f.write_fmt(format_args!("mbedTLS error {:?}", e)), + &Error::Utf8Error(Some(ref e)) => { + write!(f, "Error converting to UTF-8: {}", e) + } + &Error::Utf8Error(None) => write!(f, "Error converting to UTF-8"), + &Error::LowLevel(e) => write!(f, "{}", e.as_str()), + &Error::HighLevel(e) => write!(f, "{}", e.as_str()), + &Error::HighAndLowLevel(hi, lo) => write!(f, "({}, {})", hi.as_str(), lo.as_str()), + &Error::Other(code) => write!(f, "mbedTLS unknown error code {}", code), } } } #[cfg(feature = "std")] -impl StdError for Error { - fn description(&self) -> &str { - self.as_str() +impl StdError for Error {} + +impl IntoResult for c_int { + fn into_result(self) -> Result { + match self { + 0.. => return Ok(self), + ERR_UTF8_INVALID => return Err(Error::Utf8Error(None)), + _ => return Err(Error::from(self)), + }; } } error_enum!( - enum Error { - AesBadInputData = ERR_AES_BAD_INPUT_DATA, - AesFeatureUnavailable = ERR_AES_FEATURE_UNAVAILABLE, - AesHwAccelFailed = ERR_AES_HW_ACCEL_FAILED, - AesInvalidInputLength = ERR_AES_INVALID_INPUT_LENGTH, - AesInvalidKeyLength = ERR_AES_INVALID_KEY_LENGTH, - Arc4HwAccelFailed = ERR_ARC4_HW_ACCEL_FAILED, - AriaFeatureUnavailable = ERR_ARIA_FEATURE_UNAVAILABLE, - AriaHwAccelFailed = ERR_ARIA_HW_ACCEL_FAILED, - AriaInvalidInputLength = ERR_ARIA_INVALID_INPUT_LENGTH, - Asn1AllocFailed = ERR_ASN1_ALLOC_FAILED, - Asn1BufTooSmall = ERR_ASN1_BUF_TOO_SMALL, - Asn1InvalidData = ERR_ASN1_INVALID_DATA, - Asn1InvalidLength = ERR_ASN1_INVALID_LENGTH, - Asn1LengthMismatch = ERR_ASN1_LENGTH_MISMATCH, - Asn1OutOfData = ERR_ASN1_OUT_OF_DATA, - Asn1UnexpectedTag = ERR_ASN1_UNEXPECTED_TAG, - Base64BufferTooSmall = ERR_BASE64_BUFFER_TOO_SMALL, - Base64InvalidCharacter = ERR_BASE64_INVALID_CHARACTER, - BlowfishHwAccelFailed = ERR_BLOWFISH_HW_ACCEL_FAILED, - BlowfishInvalidInputLength = ERR_BLOWFISH_INVALID_INPUT_LENGTH, - CamelliaHwAccelFailed = ERR_CAMELLIA_HW_ACCEL_FAILED, - CamelliaInvalidInputLength = ERR_CAMELLIA_INVALID_INPUT_LENGTH, - CcmAuthFailed = ERR_CCM_AUTH_FAILED, - CcmBadInput = ERR_CCM_BAD_INPUT, - CcmHwAccelFailed = ERR_CCM_HW_ACCEL_FAILED, - Chacha20BadInputData = ERR_CHACHA20_BAD_INPUT_DATA, - Chacha20FeatureUnavailable = ERR_CHACHA20_FEATURE_UNAVAILABLE, - Chacha20HwAccelFailed = ERR_CHACHA20_HW_ACCEL_FAILED, - ChachapolyAuthFailed = ERR_CHACHAPOLY_AUTH_FAILED, - ChachapolyBadState = ERR_CHACHAPOLY_BAD_STATE, + const MASK: c_int = 0x7F80; + enum HiError { CipherAllocFailed = ERR_CIPHER_ALLOC_FAILED, CipherAuthFailed = ERR_CIPHER_AUTH_FAILED, CipherBadInputData = ERR_CIPHER_BAD_INPUT_DATA, @@ -154,13 +223,6 @@ error_enum!( CipherHwAccelFailed = ERR_CIPHER_HW_ACCEL_FAILED, CipherInvalidContext = ERR_CIPHER_INVALID_CONTEXT, CipherInvalidPadding = ERR_CIPHER_INVALID_PADDING, - CmacHwAccelFailed = ERR_CMAC_HW_ACCEL_FAILED, - CtrDrbgEntropySourceFailed = ERR_CTR_DRBG_ENTROPY_SOURCE_FAILED, - CtrDrbgFileIoError = ERR_CTR_DRBG_FILE_IO_ERROR, - CtrDrbgInputTooBig = ERR_CTR_DRBG_INPUT_TOO_BIG, - CtrDrbgRequestTooBig = ERR_CTR_DRBG_REQUEST_TOO_BIG, - DesHwAccelFailed = ERR_DES_HW_ACCEL_FAILED, - DesInvalidInputLength = ERR_DES_INVALID_INPUT_LENGTH, DhmAllocFailed = ERR_DHM_ALLOC_FAILED, DhmBadInputData = ERR_DHM_BAD_INPUT_DATA, DhmCalcSecretFailed = ERR_DHM_CALC_SECRET_FAILED, @@ -181,51 +243,12 @@ error_enum!( EcpRandomFailed = ERR_ECP_RANDOM_FAILED, EcpSigLenMismatch = ERR_ECP_SIG_LEN_MISMATCH, EcpVerifyFailed = ERR_ECP_VERIFY_FAILED, - EntropyFileIoError = ERR_ENTROPY_FILE_IO_ERROR, - EntropyMaxSources = ERR_ENTROPY_MAX_SOURCES, - EntropyNoSourcesDefined = ERR_ENTROPY_NO_SOURCES_DEFINED, - EntropyNoStrongSource = ERR_ENTROPY_NO_STRONG_SOURCE, - EntropySourceFailed = ERR_ENTROPY_SOURCE_FAILED, - GcmAuthFailed = ERR_GCM_AUTH_FAILED, - GcmBadInput = ERR_GCM_BAD_INPUT, - GcmHwAccelFailed = ERR_GCM_HW_ACCEL_FAILED, HkdfBadInputData = ERR_HKDF_BAD_INPUT_DATA, - HmacDrbgEntropySourceFailed = ERR_HMAC_DRBG_ENTROPY_SOURCE_FAILED, - HmacDrbgFileIoError = ERR_HMAC_DRBG_FILE_IO_ERROR, - HmacDrbgInputTooBig = ERR_HMAC_DRBG_INPUT_TOO_BIG, - HmacDrbgRequestTooBig = ERR_HMAC_DRBG_REQUEST_TOO_BIG, - Md2HwAccelFailed = ERR_MD2_HW_ACCEL_FAILED, - Md4HwAccelFailed = ERR_MD4_HW_ACCEL_FAILED, - Md5HwAccelFailed = ERR_MD5_HW_ACCEL_FAILED, MdAllocFailed = ERR_MD_ALLOC_FAILED, MdBadInputData = ERR_MD_BAD_INPUT_DATA, MdFeatureUnavailable = ERR_MD_FEATURE_UNAVAILABLE, MdFileIoError = ERR_MD_FILE_IO_ERROR, MdHwAccelFailed = ERR_MD_HW_ACCEL_FAILED, - MpiAllocFailed = ERR_MPI_ALLOC_FAILED, - MpiBadInputData = ERR_MPI_BAD_INPUT_DATA, - MpiBufferTooSmall = ERR_MPI_BUFFER_TOO_SMALL, - MpiDivisionByZero = ERR_MPI_DIVISION_BY_ZERO, - MpiFileIoError = ERR_MPI_FILE_IO_ERROR, - MpiInvalidCharacter = ERR_MPI_INVALID_CHARACTER, - MpiNegativeValue = ERR_MPI_NEGATIVE_VALUE, - MpiNotAcceptable = ERR_MPI_NOT_ACCEPTABLE, - NetAcceptFailed = ERR_NET_ACCEPT_FAILED, - NetBadInputData = ERR_NET_BAD_INPUT_DATA, - NetBindFailed = ERR_NET_BIND_FAILED, - NetBufferTooSmall = ERR_NET_BUFFER_TOO_SMALL, - NetConnReset = ERR_NET_CONN_RESET, - NetConnectFailed = ERR_NET_CONNECT_FAILED, - NetInvalidContext = ERR_NET_INVALID_CONTEXT, - NetListenFailed = ERR_NET_LISTEN_FAILED, - NetPollFailed = ERR_NET_POLL_FAILED, - NetRecvFailed = ERR_NET_RECV_FAILED, - NetSendFailed = ERR_NET_SEND_FAILED, - NetSocketFailed = ERR_NET_SOCKET_FAILED, - NetUnknownHost = ERR_NET_UNKNOWN_HOST, - OidBufTooSmall = ERR_OID_BUF_TOO_SMALL, - OidNotFound = ERR_OID_NOT_FOUND, - PadlockDataMisaligned = ERR_PADLOCK_DATA_MISALIGNED, PemAllocFailed = ERR_PEM_ALLOC_FAILED, PemBadInputData = ERR_PEM_BAD_INPUT_DATA, PemFeatureUnavailable = ERR_PEM_FEATURE_UNAVAILABLE, @@ -258,10 +281,6 @@ error_enum!( Pkcs5FeatureUnavailable = ERR_PKCS5_FEATURE_UNAVAILABLE, Pkcs5InvalidFormat = ERR_PKCS5_INVALID_FORMAT, Pkcs5PasswordMismatch = ERR_PKCS5_PASSWORD_MISMATCH, - Poly1305BadInputData = ERR_POLY1305_BAD_INPUT_DATA, - Poly1305FeatureUnavailable = ERR_POLY1305_FEATURE_UNAVAILABLE, - Poly1305HwAccelFailed = ERR_POLY1305_HW_ACCEL_FAILED, - Ripemd160HwAccelFailed = ERR_RIPEMD160_HW_ACCEL_FAILED, RsaBadInputData = ERR_RSA_BAD_INPUT_DATA, RsaHwAccelFailed = ERR_RSA_HW_ACCEL_FAILED, RsaInvalidPadding = ERR_RSA_INVALID_PADDING, @@ -273,9 +292,6 @@ error_enum!( RsaRngFailed = ERR_RSA_RNG_FAILED, RsaUnsupportedOperation = ERR_RSA_UNSUPPORTED_OPERATION, RsaVerifyFailed = ERR_RSA_VERIFY_FAILED, - Sha1HwAccelFailed = ERR_SHA1_HW_ACCEL_FAILED, - Sha256HwAccelFailed = ERR_SHA256_HW_ACCEL_FAILED, - Sha512HwAccelFailed = ERR_SHA512_HW_ACCEL_FAILED, SslAllocFailed = ERR_SSL_ALLOC_FAILED, SslAsyncInProgress = ERR_SSL_ASYNC_IN_PROGRESS, SslBadHsCertificate = ERR_SSL_BAD_HS_CERTIFICATE, @@ -349,7 +365,135 @@ error_enum!( X509UnknownOid = ERR_X509_UNKNOWN_OID, X509UnknownSigAlg = ERR_X509_UNKNOWN_SIG_ALG, X509UnknownVersion = ERR_X509_UNKNOWN_VERSION, - XteaHwAccelFailed = ERR_XTEA_HW_ACCEL_FAILED, - XteaInvalidInputLength = ERR_XTEA_INVALID_INPUT_LENGTH, } ); + +error_enum!( + const MASK: c_int = 0x7F; + enum LoError { + AesBadInputData = ERR_AES_BAD_INPUT_DATA, + AesInvalidInputLength = ERR_AES_INVALID_INPUT_LENGTH, + AesInvalidKeyLength = ERR_AES_INVALID_KEY_LENGTH, + AriaBadInputData = ERR_ARIA_BAD_INPUT_DATA, + AriaInvalidInputLength = ERR_ARIA_INVALID_INPUT_LENGTH, + Asn1AllocFailed = ERR_ASN1_ALLOC_FAILED, + Asn1BufTooSmall = ERR_ASN1_BUF_TOO_SMALL, + Asn1InvalidData = ERR_ASN1_INVALID_DATA, + Asn1InvalidLength = ERR_ASN1_INVALID_LENGTH, + Asn1LengthMismatch = ERR_ASN1_LENGTH_MISMATCH, + Asn1OutOfData = ERR_ASN1_OUT_OF_DATA, + Asn1UnexpectedTag = ERR_ASN1_UNEXPECTED_TAG, + Base64BufferTooSmall = ERR_BASE64_BUFFER_TOO_SMALL, + Base64InvalidCharacter = ERR_BASE64_INVALID_CHARACTER, + CamelliaBadInputData = ERR_CAMELLIA_BAD_INPUT_DATA, + CamelliaInvalidInputLength = ERR_CAMELLIA_INVALID_INPUT_LENGTH, + CcmAuthFailed = ERR_CCM_AUTH_FAILED, + CcmBadInput = ERR_CCM_BAD_INPUT, + Chacha20BadInputData = ERR_CHACHA20_BAD_INPUT_DATA, + ChachapolyAuthFailed = ERR_CHACHAPOLY_AUTH_FAILED, + ChachapolyBadState = ERR_CHACHAPOLY_BAD_STATE, + CtrDrbgEntropySourceFailed = ERR_CTR_DRBG_ENTROPY_SOURCE_FAILED, + CtrDrbgFileIoError = ERR_CTR_DRBG_FILE_IO_ERROR, + CtrDrbgInputTooBig = ERR_CTR_DRBG_INPUT_TOO_BIG, + CtrDrbgRequestTooBig = ERR_CTR_DRBG_REQUEST_TOO_BIG, + DesInvalidInputLength = ERR_DES_INVALID_INPUT_LENGTH, + EntropyFileIoError = ERR_ENTROPY_FILE_IO_ERROR, + EntropyMaxSources = ERR_ENTROPY_MAX_SOURCES, + EntropyNoSourcesDefined = ERR_ENTROPY_NO_SOURCES_DEFINED, + EntropyNoStrongSource = ERR_ENTROPY_NO_STRONG_SOURCE, + EntropySourceFailed = ERR_ENTROPY_SOURCE_FAILED, + ErrorCorruptionDetected = ERR_ERROR_CORRUPTION_DETECTED, + ErrorGenericError = ERR_ERROR_GENERIC_ERROR, + GcmAuthFailed = ERR_GCM_AUTH_FAILED, + GcmBadInput = ERR_GCM_BAD_INPUT, + HmacDrbgEntropySourceFailed = ERR_HMAC_DRBG_ENTROPY_SOURCE_FAILED, + HmacDrbgFileIoError = ERR_HMAC_DRBG_FILE_IO_ERROR, + HmacDrbgInputTooBig = ERR_HMAC_DRBG_INPUT_TOO_BIG, + HmacDrbgRequestTooBig = ERR_HMAC_DRBG_REQUEST_TOO_BIG, + MpiAllocFailed = ERR_MPI_ALLOC_FAILED, + MpiBadInputData = ERR_MPI_BAD_INPUT_DATA, + MpiBufferTooSmall = ERR_MPI_BUFFER_TOO_SMALL, + MpiDivisionByZero = ERR_MPI_DIVISION_BY_ZERO, + MpiFileIoError = ERR_MPI_FILE_IO_ERROR, + MpiInvalidCharacter = ERR_MPI_INVALID_CHARACTER, + MpiNegativeValue = ERR_MPI_NEGATIVE_VALUE, + MpiNotAcceptable = ERR_MPI_NOT_ACCEPTABLE, + NetAcceptFailed = ERR_NET_ACCEPT_FAILED, + NetBadInputData = ERR_NET_BAD_INPUT_DATA, + NetBindFailed = ERR_NET_BIND_FAILED, + NetBufferTooSmall = ERR_NET_BUFFER_TOO_SMALL, + NetConnectFailed = ERR_NET_CONNECT_FAILED, + NetConnReset = ERR_NET_CONN_RESET, + NetInvalidContext = ERR_NET_INVALID_CONTEXT, + NetListenFailed = ERR_NET_LISTEN_FAILED, + NetPollFailed = ERR_NET_POLL_FAILED, + NetRecvFailed = ERR_NET_RECV_FAILED, + NetSendFailed = ERR_NET_SEND_FAILED, + NetSocketFailed = ERR_NET_SOCKET_FAILED, + NetUnknownHost = ERR_NET_UNKNOWN_HOST, + OidBufTooSmall = ERR_OID_BUF_TOO_SMALL, + OidNotFound = ERR_OID_NOT_FOUND, + PlatformFeatureUnsupported = ERR_PLATFORM_FEATURE_UNSUPPORTED, + PlatformHwAccelFailed = ERR_PLATFORM_HW_ACCEL_FAILED, + Poly1305BadInputData = ERR_POLY1305_BAD_INPUT_DATA, + Sha1BadInputData = ERR_SHA1_BAD_INPUT_DATA, + Sha256BadInputData = ERR_SHA256_BAD_INPUT_DATA, + Sha512BadInputData = ERR_SHA512_BAD_INPUT_DATA, + ThreadingBadInputData = ERR_THREADING_BAD_INPUT_DATA, + ThreadingMutexError = ERR_THREADING_MUTEX_ERROR, + } +); + +#[cfg(test)] +mod tests { + use super::{codes, Error, HiError, LoError}; + + #[test] + fn test_common_error_operations() { + let (hi, lo) = (codes::CipherAllocFailed, codes::AesBadInputData); + let (hi_only_error, lo_only_error, combined_error) = + (Error::HighLevel(hi), Error::LowLevel(lo), Error::HighAndLowLevel(hi, lo)); + assert_eq!(combined_error.high_level().unwrap(), hi); + assert_eq!(combined_error.low_level().unwrap(), lo); + assert_eq!(hi_only_error.to_int(), -24960); + assert_eq!(lo_only_error.to_int(), -33); + assert_eq!(combined_error.to_int(), hi_only_error.to_int() + lo_only_error.to_int()); + assert_eq!(codes::CipherAllocFailed | codes::AesBadInputData, combined_error); + assert_eq!(codes::AesBadInputData | codes::CipherAllocFailed, combined_error); + } + + #[test] + fn test_error_display() { + let (hi, lo) = (HiError::CipherAllocFailed, LoError::AesBadInputData); + let (hi_only_error, lo_only_error, combined_error) = + (Error::HighLevel(hi), Error::LowLevel(lo), Error::HighAndLowLevel(hi, lo)); + assert_eq!(format!("{}", hi_only_error), "mbedTLS error HiError :: CipherAllocFailed"); + assert_eq!(format!("{}", lo_only_error), "mbedTLS error LoError :: AesBadInputData"); + assert_eq!( + format!("{}", combined_error), + "(mbedTLS error HiError :: CipherAllocFailed, mbedTLS error LoError :: AesBadInputData)" + ); + } + + #[test] + fn test_error_from_int() { + // positive error code + assert_eq!(Error::from(0), Error::Other(0)); + assert_eq!(Error::from(1), Error::Other(1)); + // Lo, Hi, HiAndLo cases + assert_eq!(Error::from(-1), Error::LowLevel(LoError::ErrorGenericError)); + assert_eq!(Error::from(-0x80), Error::HighLevel(HiError::Unknown(-0x80))); + assert_eq!( + Error::from(-0x81), + Error::HighAndLowLevel(HiError::Unknown(-0x80), LoError::ErrorGenericError) + ); + assert_eq!( + Error::from(-24993), + Error::HighAndLowLevel(HiError::CipherAllocFailed, LoError::AesBadInputData) + ); + assert_eq!(Error::from(-24960), Error::HighLevel(HiError::CipherAllocFailed)); + assert_eq!(Error::from(-33), Error::LowLevel(LoError::AesBadInputData)); + // error code out of boundaries + assert_eq!(Error::from(-0x01FFFF), Error::Other(-0x01FFFF)); + } +} diff --git a/mbedtls/src/hash/mod.rs b/mbedtls/src/hash/mod.rs index c50f09911..0b71e1119 100644 --- a/mbedtls/src/hash/mod.rs +++ b/mbedtls/src/hash/mod.rs @@ -6,7 +6,7 @@ * option. This file may not be copied, modified, or distributed except * according to those terms. */ -use crate::error::{Error, IntoResult, Result}; +use crate::error::{codes, IntoResult, Result}; use mbedtls_sys::*; define!( @@ -97,7 +97,7 @@ impl Md { pub fn new(md: Type) -> Result { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; let mut ctx = Md::init(); @@ -117,7 +117,7 @@ impl Md { unsafe { let olen = (*self.inner.md_info).size as usize; if out.len() < olen { - return Err(Error::MdBadInputData); + return Err(codes::MdBadInputData.into()); } md_finish(&mut self.inner, out.as_mut_ptr()).into_result()?; Ok(olen) @@ -127,13 +127,13 @@ impl Md { pub fn hash(mdt: Type, data: &[u8], out: &mut [u8]) -> Result { let mdinfo: MdInfo = match mdt.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { let olen = mdinfo.inner.size as usize; if out.len() < olen { - return Err(Error::MdBadInputData); + return Err(codes::MdBadInputData.into()); } md(mdinfo.inner, data.as_ptr(), data.len(), out.as_mut_ptr()).into_result()?; Ok(olen) @@ -150,7 +150,7 @@ impl Hmac { pub fn new(md: Type, key: &[u8]) -> Result { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; let mut ctx = Md::init(); @@ -170,7 +170,7 @@ impl Hmac { unsafe { let olen = (*self.ctx.inner.md_info).size as usize; if out.len() < olen { - return Err(Error::MdBadInputData); + return Err(codes::MdBadInputData.into()); } md_hmac_finish(&mut self.ctx.inner, out.as_mut_ptr()).into_result()?; Ok(olen) @@ -180,13 +180,13 @@ impl Hmac { pub fn hmac(md: Type, key: &[u8], data: &[u8], out: &mut [u8]) -> Result { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { let olen = md.inner.size as usize; if out.len() < olen { - return Err(Error::MdBadInputData); + return Err(codes::MdBadInputData.into()); } md_hmac(md.inner, key.as_ptr(), key.len(), data.as_ptr(), data.len(), out.as_mut_ptr()).into_result()?; Ok(olen) @@ -221,7 +221,7 @@ impl Hkdf { pub fn hkdf(md: Type, salt: &[u8], ikm: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { @@ -265,7 +265,7 @@ impl Hkdf { pub fn hkdf_optional_salt(md: Type, maybe_salt: Option<&[u8]>, ikm: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { @@ -314,7 +314,7 @@ impl Hkdf { pub fn hkdf_extract(md: Type, maybe_salt: Option<&[u8]>, ikm: &[u8], prk: &mut [u8]) -> Result<()> { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { @@ -360,7 +360,7 @@ impl Hkdf { pub fn hkdf_expand(md: Type, prk: &[u8], info: &[u8], okm: &mut [u8]) -> Result<()> { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { @@ -382,7 +382,7 @@ impl Hkdf { pub fn pbkdf2_hmac(md: Type, password: &[u8], salt: &[u8], iterations: u32, key: &mut [u8]) -> Result<()> { let md: MdInfo = match md.into() { Some(md) => md, - None => return Err(Error::MdBadInputData), + None => return Err(codes::MdBadInputData.into()), }; unsafe { diff --git a/mbedtls/src/lib.rs b/mbedtls/src/lib.rs index 7a8620528..e43eadc52 100644 --- a/mbedtls/src/lib.rs +++ b/mbedtls/src/lib.rs @@ -7,7 +7,7 @@ * according to those terms. */ #![deny(warnings)] -#![allow(unused_doc_comments)] +#![allow(unused_doc_comments, ambiguous_glob_reexports)] // allow ambiguous glob reexports for now in autogenerated bindings. #![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(nightly, feature(doc_auto_cfg))] @@ -27,7 +27,7 @@ mod wrapper_macros; // API // ============== pub mod bignum; -mod error; +pub mod error; pub use crate::error::{Error, Result}; pub mod cipher; pub mod ecp; diff --git a/mbedtls/src/pk/dsa/mod.rs b/mbedtls/src/pk/dsa/mod.rs index 1fc1d82e8..98b7f1fcc 100644 --- a/mbedtls/src/pk/dsa/mod.rs +++ b/mbedtls/src/pk/dsa/mod.rs @@ -9,9 +9,9 @@ use crate::bignum::Mpi; use crate::hash::{MdInfo, Type as MdType}; use crate::pk::rfc6979::generate_rfc6979_nonce; -use crate::rng::Random; -use crate::{Error, Result}; +use crate::{error::codes, Result}; +use crate::rng::Random; use bit_vec::BitVec; use num_bigint::BigUint; use yasna::models::ObjectIdentifier; @@ -27,11 +27,11 @@ pub struct DsaParams { impl DsaParams { pub fn from_components(p: Mpi, q: Mpi, g: Mpi) -> Result { if g > p || q > p { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } if p.modulo(&q)? != Mpi::new(1)? { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } Ok(Self { p, q, g }) @@ -65,11 +65,11 @@ const DSA_OBJECT_IDENTIFIER: &[u64] = &[1, 2, 840, 10040, 4, 1]; impl DsaPublicKey { pub fn from_components(params: DsaParams, y: Mpi) -> Result { if y < Mpi::new(1)? || y >= params.p { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } // Verify that y is of order q modulo p if y.mod_exp(¶ms.q, ¶ms.p)? != Mpi::new(1)? { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } Ok(Self { params, y }) } @@ -93,9 +93,9 @@ impl DsaPublicKey { Ok((p, q, g, y)) }) }) - .map_err(|_| Error::PkInvalidPubkey)?; + .map_err(|_| codes::PkInvalidPubkey)?; - let y = yasna::parse_der(&y.to_bytes(), |r| r.read_biguint()).map_err(|_| Error::PkInvalidPubkey)?; + let y = yasna::parse_der(&y.to_bytes(), |r| r.read_biguint()).map_err(|_| codes::PkInvalidPubkey)?; let p = Mpi::from_binary(&p.to_bytes_be()).expect("Success"); let q = Mpi::from_binary(&q.to_bytes_be()).expect("Success"); @@ -140,7 +140,7 @@ impl DsaPublicKey { Ok((r, s)) }) }) - .map_err(|_| Error::X509InvalidSignature)?; + .map_err(|_| codes::X509InvalidSignature)?; let r = Mpi::from_binary(&r.to_bytes_be()).expect("Success"); let s = Mpi::from_binary(&s.to_bytes_be()).expect("Success"); @@ -152,14 +152,14 @@ impl DsaPublicKey { let zero = Mpi::new(0)?; if r <= &zero || s <= &zero { - return Err(Error::X509InvalidSignature); + return Err(codes::X509InvalidSignature.into()); } let p = &self.params.p; let q = &self.params.q; if r >= q || s >= q { - return Err(Error::X509InvalidSignature); + return Err(codes::X509InvalidSignature.into()); } let m = reduce_mod_q(pre_hashed_message, q)?; @@ -176,7 +176,7 @@ impl DsaPublicKey { let gsm_ysr = (&gsm * &ysr)?.modulo(p)?; if &gsm_ysr.modulo(q)? != r { - return Err(Error::X509InvalidSignature); + return Err(codes::X509InvalidSignature.into()); } Ok(()) @@ -225,7 +225,7 @@ fn encode_dsa_signature(r: &Mpi, s: &Mpi) -> Result> { impl DsaPrivateKey { pub fn from_components(params: DsaParams, x: Mpi) -> Result { if x <= Mpi::new(1)? || x >= params.q { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } Ok(Self { params, x }) } @@ -257,9 +257,9 @@ impl DsaPrivateKey { Ok((p, q, g, x)) }) }) - .map_err(|_| Error::PkInvalidPubkey)?; + .map_err(|_| codes::PkInvalidPubkey)?; - let x = yasna::parse_der(&x, |r| r.read_biguint()).map_err(|_| Error::PkInvalidPubkey)?; + let x = yasna::parse_der(&x, |r| r.read_biguint()).map_err(|_| codes::PkInvalidPubkey)?; let p = Mpi::from_binary(&p.to_bytes_be()).expect("Success"); let q = Mpi::from_binary(&q.to_bytes_be()).expect("Success"); @@ -345,7 +345,7 @@ impl DsaPrivateKey { let zero = Mpi::new(0)?; if r == zero || s == zero { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } encode_dsa_signature(&r, &s) } diff --git a/mbedtls/src/pk/ec.rs b/mbedtls/src/pk/ec.rs index 2cc8a0151..b9025870b 100644 --- a/mbedtls/src/pk/ec.rs +++ b/mbedtls/src/pk/ec.rs @@ -9,7 +9,7 @@ use mbedtls_sys::ECDSA_MAX_LEN as MBEDTLS_ECDSA_MAX_LEN; use mbedtls_sys::*; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{codes, IntoResult, Result}; define!( #[c_ty(ecp_group_id)] @@ -80,7 +80,7 @@ define!( impl Ecdh { pub fn from_keys(private: &EcpKeypair, public: &EcpKeypair) -> Result { if public.inner.grp.id == ECP_DP_NONE || public.inner.grp.id != private.inner.grp.id { - return Err(Error::EcpBadInputData); + return Err(codes::EcpBadInputData.into()); } let mut ret = Self::init(); diff --git a/mbedtls/src/pk/mod.rs b/mbedtls/src/pk/mod.rs index 63c77575e..f2e93332f 100644 --- a/mbedtls/src/pk/mod.rs +++ b/mbedtls/src/pk/mod.rs @@ -12,7 +12,7 @@ use mbedtls_sys::*; use mbedtls_sys::types::raw_types::c_void; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{codes, Error, IntoResult, Result}; use crate::hash::Type as MdType; use crate::private::UnsafeFrom; use crate::rng::Random; @@ -463,7 +463,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn custom_algo_id(&self) -> Result<&[u64]> { if self.pk_type() != Type::Custom { - return Err(Error::PkInvalidAlg); + return Err(codes::PkInvalidAlg.into()); } unsafe { @@ -474,7 +474,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn custom_public_key(&self) -> Result<&[u8]> { if self.pk_type() != Type::Custom { - return Err(Error::PkInvalidAlg); + return Err(codes::PkInvalidAlg.into()); } let ctx = self.inner.pk_ctx as *const CustomPkContext; @@ -483,13 +483,13 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn custom_private_key(&self) -> Result<&[u8]> { if self.pk_type() != Type::Custom { - return Err(Error::PkInvalidAlg); + return Err(codes::PkInvalidAlg.into()); } let ctx = self.inner.pk_ctx as *const CustomPkContext; unsafe { if (*ctx).sk.len() == 0 { - return Err(Error::PkTypeMismatch); + return Err(codes::PkTypeMismatch.into()); } Ok(&(*ctx).sk) } @@ -537,7 +537,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn curve(&self) -> Result { match self.pk_type() { Type::Eckey | Type::EckeyDh | Type::Ecdsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } unsafe { Ok((*(self.inner.pk_ctx as *const ecp_keypair)).grp.id.into()) } @@ -556,14 +556,14 @@ Please use `private_from_ec_scalar_with_rng` instead." EcGroupId::SecP256R1 => Ok(vec![1, 2, 840, 10045, 3, 1, 7]), EcGroupId::SecP384R1 => Ok(vec![1, 3, 132, 0, 34]), EcGroupId::SecP521R1 => Ok(vec![1, 3, 132, 0, 35]), - _ => Err(Error::OidNotFound), + _ => Err(codes::OidNotFound.into()), } } pub fn ec_group(&self) -> Result { match self.pk_type() { Type::Eckey | Type::EckeyDh | Type::Ecdsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } match self.curve()? { @@ -588,7 +588,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn ec_public(&self) -> Result { match self.pk_type() { Type::Eckey | Type::EckeyDh | Type::Ecdsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let q = &unsafe { (*(self.inner.pk_ctx as *const ecp_keypair)).Q }; @@ -598,7 +598,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn ec_private(&self) -> Result { match self.pk_type() { Type::Eckey | Type::EckeyDh | Type::Ecdsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let d = &unsafe { (*(self.inner.pk_ctx as *const ecp_keypair)).d }; @@ -608,7 +608,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_public_modulus(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut n = Mpi::new(0)?; @@ -631,7 +631,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_private_prime1(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut p = Mpi::new(0)?; @@ -654,7 +654,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_private_prime2(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut q = Mpi::new(0)?; @@ -677,7 +677,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_private_exponent(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut d = Mpi::new(0)?; @@ -700,7 +700,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_crt_dp(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut dp = Mpi::new(0)?; @@ -721,7 +721,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_crt_dq(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut dq = Mpi::new(0)?; @@ -742,7 +742,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_crt_qp(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut qp = Mpi::new(0)?; @@ -763,7 +763,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn rsa_public_exponent(&self) -> Result { match self.pk_type() { Type::Rsa => {} - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } let mut e: [u8; 4] = [0, 0, 0, 0]; @@ -797,14 +797,14 @@ Please use `private_from_ec_scalar_with_rng` instead." if unsafe { (*ctx).padding == RAW_RSA_DECRYPT } { let olen = self.len() / 8; if plain.len() < olen { - return Err(Error::RsaOutputTooLarge); + return Err(codes::RsaOutputTooLarge.into()); } // Don't process outside of {2, ..., n-2} let nm1 = self.rsa_public_modulus()?.sub(&Mpi::new(1)?)?; let c_mpi = Mpi::from_binary(cipher)?; if c_mpi <= Mpi::new(1).unwrap() || c_mpi >= nm1 { - return Err(Error::MpiBadInputData); + return Err(codes::MpiBadInputData.into()); } unsafe { @@ -843,11 +843,11 @@ Please use `private_from_ec_scalar_with_rng` instead." label: &[u8], ) -> Result { if self.pk_type() != Type::Rsa { - return Err(Error::PkTypeMismatch); + return Err(codes::PkTypeMismatch.into()); } let ctx = self.inner.pk_ctx as *mut rsa_context; if unsafe { (*ctx).padding != RSA_PKCS_V21 } { - return Err(Error::RsaInvalidPadding); + return Err(codes::RsaInvalidPadding.into()); } let mut ret = 0usize; @@ -899,15 +899,15 @@ Please use `private_from_ec_scalar_with_rng` instead." label: &[u8], ) -> Result { if self.pk_type() != Type::Rsa { - return Err(Error::PkTypeMismatch); + return Err(codes::PkTypeMismatch.into()); } let ctx = self.inner.pk_ctx as *mut rsa_context; if unsafe { (*ctx).padding != RSA_PKCS_V21 } { - return Err(Error::RsaInvalidPadding); + return Err(codes::RsaInvalidPadding.into()); } let olen = self.len() / 8; if cipher.len() < olen { - return Err(Error::RsaOutputTooLarge); + return Err(codes::RsaOutputTooLarge.into()); } unsafe { @@ -943,21 +943,21 @@ Please use `private_from_ec_scalar_with_rng` instead." // If hash or sig are allowed with size 0 (&[]) then mbedtls will attempt to // auto-detect size and cause an invalid write. if hash.len() == 0 || sig.len() == 0 { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } match self.pk_type() { Type::Rsa | Type::RsaAlt | Type::RsassaPss => { if sig.len() < (self.len() / 8) { - return Err(Error::PkSigLenMismatch); + return Err(codes::PkSigLenMismatch.into()); } } Type::Eckey | Type::Ecdsa => { if sig.len() < ECDSA_MAX_LEN { - return Err(Error::PkSigLenMismatch); + return Err(codes::PkSigLenMismatch.into()); } } - _ => return Err(Error::PkSigLenMismatch), + _ => return Err(codes::PkSigLenMismatch.into()), } let mut ret = 0usize; unsafe { @@ -980,14 +980,14 @@ Please use `private_from_ec_scalar_with_rng` instead." // If hash or sig are allowed with size 0 (&[]) then mbedtls will attempt to // auto-detect size and cause an invalid write. if hash.len() == 0 || sig.len() == 0 { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } use crate::rng::RngCallbackMut; if self.pk_type() == Type::Ecdsa || self.pk_type() == Type::Eckey { if sig.len() < ECDSA_MAX_LEN { - return Err(Error::PkSigLenMismatch); + return Err(codes::PkSigLenMismatch.into()); } // RFC 6979 signature scheme @@ -1017,7 +1017,7 @@ Please use `private_from_ec_scalar_with_rng` instead." } else if self.pk_type() == Type::Rsa { // Reject sign_deterministic being use for PSS if unsafe { (*(self.inner.pk_ctx as *mut rsa_context)).padding } != RSA_PKCS_V15 { - return Err(Error::PkInvalidAlg); + return Err(codes::PkInvalidAlg.into()); } // This is a PKCSv1.5 signature which is already deterministic; just pass it to @@ -1025,7 +1025,7 @@ Please use `private_from_ec_scalar_with_rng` instead." return self.sign(md, hash, sig, rng); } else { // Some non-deterministic scheme - return Err(Error::PkInvalidAlg); + return Err(codes::PkInvalidAlg.into()); } } @@ -1033,7 +1033,7 @@ Please use `private_from_ec_scalar_with_rng` instead." // If hash or sig are allowed with size 0 (&[]) then mbedtls will attempt to // auto-detect size and cause an invalid write. if hash.len() == 0 || sig.len() == 0 { - return Err(Error::PkBadInputData); + return Err(codes::PkBadInputData.into()); } unsafe { @@ -1056,13 +1056,13 @@ Please use `private_from_ec_scalar_with_rng` instead." )?; ecdh.calc_secret(shared, rng) }, - _ => return Err(Error::PkTypeMismatch), + _ => return Err(codes::PkTypeMismatch.into()), } } pub fn write_private_der<'buf>(&mut self, buf: &'buf mut [u8]) -> Result> { match unsafe { pk_write_key_der(&mut self.inner, buf.as_mut_ptr(), buf.len()).into_result() } { - Err(Error::Asn1BufTooSmall) => Ok(None), + Err(e) if e.low_level() == Some(codes::Asn1BufTooSmall) => Ok(None), Err(e) => Err(e), Ok(n) => Ok(Some(&buf[buf.len() - (n as usize)..])), } @@ -1074,7 +1074,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn write_private_pem<'buf>(&mut self, buf: &'buf mut [u8]) -> Result> { match unsafe { pk_write_key_pem(&mut self.inner, buf.as_mut_ptr(), buf.len()).into_result() } { - Err(Error::Base64BufferTooSmall) => Ok(None), + Err(e) if e.low_level() == Some(codes::Base64BufferTooSmall) => Ok(None), Err(e) => Err(e), Ok(n) => Ok(Some(&buf[buf.len() - (n as usize)..])), } @@ -1091,7 +1091,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn write_public_der<'buf>(&mut self, buf: &'buf mut [u8]) -> Result> { match unsafe { pk_write_pubkey_der(&mut self.inner, buf.as_mut_ptr(), buf.len()).into_result() } { - Err(Error::Asn1BufTooSmall) => Ok(None), + Err(e) if e.low_level() == Some(codes::Asn1BufTooSmall) => Ok(None), Err(e) => Err(e), Ok(n) => Ok(Some(&buf[buf.len() - (n as usize)..])), } @@ -1103,7 +1103,7 @@ Please use `private_from_ec_scalar_with_rng` instead." pub fn write_public_pem<'buf>(&mut self, buf: &'buf mut [u8]) -> Result> { match unsafe { pk_write_pubkey_pem(&mut self.inner, buf.as_mut_ptr(), buf.len()).into_result() } { - Err(Error::Base64BufferTooSmall) => Ok(None), + Err(e) if e.low_level() == Some(codes::Base64BufferTooSmall) => Ok(None), Err(e) => Err(e), Ok(n) => Ok(Some(&buf[buf.len() - (n as usize)..])), } @@ -1339,30 +1339,33 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi .unwrap(); pk.verify(digest, data, &signature[0..len]).unwrap(); - assert_eq!(pk.verify(digest, data, &[]).unwrap_err(), Error::PkBadInputData); - assert_eq!(pk.verify(digest, &[], &signature[0..len]).unwrap_err(), Error::PkBadInputData); + assert_eq!(pk.verify(digest, data, &[]).unwrap_err(), codes::PkBadInputData.into()); + assert_eq!( + pk.verify(digest, &[], &signature[0..len]).unwrap_err(), + codes::PkBadInputData.into() + ); let mut dummy_sig = []; assert_eq!( pk.sign(digest, data, &mut dummy_sig, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::PkBadInputData + codes::PkBadInputData.into() ); assert_eq!( pk.sign(digest, &[], &mut signature, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::PkBadInputData + codes::PkBadInputData.into() ); assert_eq!( pk.sign_deterministic(digest, data, &mut dummy_sig, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::PkBadInputData + codes::PkBadInputData.into() ); assert_eq!( pk.sign_deterministic(digest, &[], &mut signature, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::PkBadInputData + codes::PkBadInputData.into() ); } } @@ -1465,7 +1468,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi assert_eq!( pk.encrypt(b"test", &mut cipher, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::RsaInvalidPadding + codes::RsaInvalidPadding.into() ); } @@ -1504,7 +1507,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi b"WRONG_LABEL" ) .unwrap_err(), - Error::RsaInvalidPadding + codes::RsaInvalidPadding.into() ); } @@ -1520,7 +1523,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi assert_eq!( pk.sign(Type::Sha256, data, &mut signature, &mut crate::test_support::rand::test_rng()) .unwrap_err(), - Error::RsaInvalidPadding + codes::RsaInvalidPadding.into() ); } @@ -1543,7 +1546,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi }); assert_eq!( pk.verify(digest, data, &signature[0..len]).unwrap_err(), - Error::RsaInvalidPadding + codes::RsaInvalidPadding.into() ); } @@ -1604,7 +1607,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi const LEN: usize = 256; // Decrypting anything out of {2, n-2} should fail - let expected_err = Error::MpiBadInputData; + let expected_err = codes::MpiBadInputData; let mut pt = [0x00; LEN]; @@ -1617,7 +1620,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi for c in [_0, _1, nm1, n] { let ct = c.to_binary_padded(LEN).unwrap(); let l = pk.decrypt(&ct, &mut pt, rng); - assert_eq!(l.unwrap_err(), expected_err); + assert_eq!(l.unwrap_err(), expected_err.into()); } for c in [_2, nm2] { let ct = c.to_binary_padded(LEN).unwrap(); @@ -1669,7 +1672,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi Ok(_) => panic!("expected an error, got a Pk"), Err(e) => e, }; - assert_eq!(err, Error::RsaKeyCheckFailed); + assert_eq!(err, codes::RsaKeyCheckFailed.into()); } #[test] diff --git a/mbedtls/src/pkcs12/mod.rs b/mbedtls/src/pkcs12/mod.rs index 54e5a82fa..e33698037 100644 --- a/mbedtls/src/pkcs12/mod.rs +++ b/mbedtls/src/pkcs12/mod.rs @@ -32,10 +32,10 @@ use yasna::{ASN1Result, BERDecodable, BERReader, BERReaderSeq, Tag}; use crate::alloc::Box as MbedtlsBox; use crate::cipher::raw::{CipherId, CipherMode}; use crate::cipher::{Cipher, Decryption, Fresh, Traditional}; +use crate::error::{codes, Error as MbedtlsError}; use crate::hash::{pbkdf_pkcs12, Hmac, MdInfo, Type as MdType}; use crate::pk::Pk; use crate::x509::Certificate; -use crate::Error as MbedtlsError; // Constants for various object identifiers used in PKCS12: @@ -686,7 +686,7 @@ impl Pfx { let md_info: MdInfo = match md.into() { Some(md) => md, - None => return Err(Pkcs12Error::from(MbedtlsError::MdBadInputData)), + None => return Err(Pkcs12Error::from(MbedtlsError::from(codes::MdBadInputData))), }; if stored_mac.len() != md_info.size() { @@ -861,6 +861,7 @@ impl BERDecodable for Pfx { #[cfg(test)] mod tests { + use crate::error::{codes, Error}; use crate::mbedtls::pkcs12::{ASN1Error, ASN1ErrorKind, Pfx, Pkcs12Error}; #[test] @@ -1024,7 +1025,10 @@ mod tests { let pfx = parsed_pfx.decrypt(&wrong_password, None); assert!(pfx.is_err()); - assert_eq!(pfx.unwrap_err(), Pkcs12Error::Crypto(crate::Error::CipherInvalidPadding)); + assert_eq!( + pfx.unwrap_err(), + Pkcs12Error::Crypto(Error::from(codes::CipherInvalidPadding)) + ); let pfx = parsed_pfx.decrypt(&wrong_password_correct_padding, None); assert!(pfx.is_err()); diff --git a/mbedtls/src/private.rs b/mbedtls/src/private.rs index 33df9b6f9..186b552e7 100644 --- a/mbedtls/src/private.rs +++ b/mbedtls/src/private.rs @@ -13,7 +13,7 @@ use mbedtls_sys::types::raw_types::c_char; use mbedtls_sys::types::raw_types::{c_int, c_uchar}; use mbedtls_sys::types::size_t; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{Error, HiError, IntoResult, LoError, Result}; pub trait UnsafeFrom where @@ -34,18 +34,25 @@ where const MAX_VECTOR_ALLOCATION: usize = 4 * 1024 * 1024; let mut vec = Vec::with_capacity(2048 /* big because of bug in x509write */); + + let is_buf_too_small = |e: &Error| match (e.high_level(), e.low_level()) { + (Some(HiError::EcpBufferTooSmall | HiError::SslBufferTooSmall | HiError::X509BufferTooSmall), _) + | ( + _, + Some( + LoError::Asn1BufTooSmall + | LoError::Base64BufferTooSmall + | LoError::MpiBufferTooSmall + | LoError::NetBufferTooSmall + | LoError::OidBufTooSmall, + ), + ) => true, + _ => false, + }; + loop { match f(vec.as_mut_ptr(), vec.capacity()).into_result() { - Err(Error::Asn1BufTooSmall) - | Err(Error::Base64BufferTooSmall) - | Err(Error::EcpBufferTooSmall) - | Err(Error::MpiBufferTooSmall) - | Err(Error::NetBufferTooSmall) - | Err(Error::OidBufTooSmall) - | Err(Error::SslBufferTooSmall) - | Err(Error::X509BufferTooSmall) - if vec.capacity() < MAX_VECTOR_ALLOCATION => - { + Err(e) if is_buf_too_small(&e) && vec.capacity() < MAX_VECTOR_ALLOCATION => { let cap = vec.capacity(); vec.reserve(cap * 2) } diff --git a/mbedtls/src/ssl/async_io.rs b/mbedtls/src/ssl/async_io.rs index eca3d84ae..24997e9f1 100644 --- a/mbedtls/src/ssl/async_io.rs +++ b/mbedtls/src/ssl/async_io.rs @@ -9,7 +9,7 @@ #![cfg(all(feature = "std", feature = "async"))] use crate::{ - error::{Error, Result}, + error::{codes, Error, Result}, ssl::{ context::Context, io::{IoCallback, IoCallbackUnsafe}, @@ -47,17 +47,17 @@ impl<'a, 'b, 'c, IO: AsyncRead + AsyncWrite + std::marker::Unpin + 'static> IoCa let io = Pin::new(&mut self.1); match io.poll_read(self.0, &mut buf) { Poll::Ready(Ok(())) => Ok(buf.filled().len()), - Poll::Ready(Err(_)) => Err(Error::NetRecvFailed), - Poll::Pending => Err(Error::SslWantRead), + Poll::Ready(Err(_)) => Err(codes::NetRecvFailed.into()), + Poll::Pending => Err(codes::SslWantRead.into()), } } fn send(&mut self, buf: &[u8]) -> Result { let io = Pin::new(&mut self.1); match io.poll_write(self.0, buf) { - Poll::Ready(Err(_)) => Err(Error::NetSendFailed), + Poll::Ready(Err(_)) => Err(codes::NetSendFailed.into()), Poll::Ready(Ok(n)) => Ok(n), - Poll::Pending => Err(Error::SslWantWrite), + Poll::Pending => Err(codes::SslWantWrite.into()), } } } @@ -76,11 +76,11 @@ impl Context { fn poll(mut self: Pin<&mut Self>, ctx: &mut TaskContext) -> std::task::Poll { self.0 .with_bio_async(ctx, |ssl_ctx| match ssl_ctx.handshake() { - Err(Error::SslWantRead) | Err(Error::SslWantWrite) => Poll::Pending, + Err(e) if matches!(e.high_level(), Some(codes::SslWantRead | codes::SslWantWrite)) => Poll::Pending, Err(e) => Poll::Ready(Err(e)), Ok(()) => Poll::Ready(Ok(())), }) - .unwrap_or(Poll::Ready(Err(Error::NetSendFailed))) + .unwrap_or(Poll::Ready(Err(codes::NetSendFailed.into()))) } } @@ -100,15 +100,15 @@ where } self.with_bio_async(cx, |ssl_ctx| match ssl_ctx.recv(buf.initialize_unfilled()) { - Err(Error::SslPeerCloseNotify) => Poll::Ready(Ok(())), - Err(Error::SslWantRead) => Poll::Pending, + Err(e) if e.high_level() == Some(codes::SslPeerCloseNotify) => Poll::Ready(Ok(())), + Err(e) if e.high_level() == Some(codes::SslWantRead) => Poll::Pending, Err(e) => Poll::Ready(Err(crate::private::error_to_io_error(e))), Ok(i) => { buf.advance(i); Poll::Ready(Ok(())) } }) - .unwrap_or_else(|| Poll::Ready(Err(crate::private::error_to_io_error(Error::NetRecvFailed)))) + .unwrap_or_else(|| Poll::Ready(Err(crate::private::error_to_io_error(Error::from(codes::NetRecvFailed))))) } } @@ -122,12 +122,12 @@ where } self.with_bio_async(cx, |ssl_ctx| match ssl_ctx.async_write(buf) { - Err(Error::SslPeerCloseNotify) => Poll::Ready(Ok(0)), - Err(Error::SslWantWrite) => Poll::Pending, + Err(e) if e.high_level() == Some(codes::SslPeerCloseNotify) => Poll::Ready(Ok(0)), + Err(e) if e.high_level() == Some(codes::SslWantWrite) => Poll::Pending, Err(e) => Poll::Ready(Err(crate::private::error_to_io_error(e))), Ok(i) => Poll::Ready(Ok(i)), }) - .unwrap_or_else(|| Poll::Ready(Err(crate::private::error_to_io_error(Error::NetSendFailed)))) + .unwrap_or_else(|| Poll::Ready(Err(crate::private::error_to_io_error(Error::from(codes::NetSendFailed))))) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { @@ -137,9 +137,9 @@ where match self .with_bio_async(cx, Context::flush_output) - .unwrap_or(Err(Error::NetSendFailed)) + .unwrap_or(Err(codes::NetSendFailed.into())) { - Err(Error::SslWantWrite) => Poll::Pending, + Err(e) if e.high_level() == Some(codes::SslWantWrite) => Poll::Pending, Err(e) => Poll::Ready(Err(crate::private::error_to_io_error(e))), Ok(()) => Poll::Ready(Ok(())), } @@ -152,9 +152,9 @@ where match self .with_bio_async(cx, Context::close_notify) - .unwrap_or(Err(Error::NetSendFailed)) + .unwrap_or(Err(codes::NetSendFailed.into())) { - Err(Error::SslWantRead) | Err(Error::SslWantWrite) => Poll::Pending, + Err(e) if matches!(e.high_level(), Some(codes::SslWantRead | codes::SslWantWrite)) => Poll::Pending, Err(e) => { self.drop_io(); Poll::Ready(Err(crate::private::error_to_io_error(e))) diff --git a/mbedtls/src/ssl/config.rs b/mbedtls/src/ssl/config.rs index 4b6ab4cb2..4f9d05e2f 100644 --- a/mbedtls/src/ssl/config.rs +++ b/mbedtls/src/ssl/config.rs @@ -20,7 +20,7 @@ use mbedtls_sys::*; use crate::alloc::List as MbedtlsList; #[cfg(not(feature = "std"))] use crate::alloc_prelude::*; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{IntoResult, Result}; use crate::pk::dhparam::Dhm; use crate::pk::Pk; use crate::private::UnsafeFrom; @@ -119,7 +119,7 @@ impl NullTerminatedStrList { for item in list { ret.c.push( ::std::ffi::CString::new(*item) - .map_err(|_| Error::SslBadInputData)? + .map_err(|_| crate::error::codes::SslBadInputData)? .into_raw(), ); } @@ -267,7 +267,7 @@ impl Config { Version::Tls1_1 => 2, Version::Tls1_2 => 3, _ => { - return Err(Error::SslBadHsProtocolVersion); + return Err(crate::error::codes::SslBadHsProtocolVersion.into()); } }; @@ -282,7 +282,7 @@ impl Config { Version::Tls1_1 => 2, Version::Tls1_2 => 3, _ => { - return Err(Error::SslBadHsProtocolVersion); + return Err(crate::error::codes::SslBadHsProtocolVersion.into()); } }; unsafe { ssl_conf_max_version(self.into(), 3, minor) }; @@ -323,7 +323,7 @@ impl Config { pub fn push_cert(&mut self, own_cert: Arc>, own_pk: Arc) -> Result<()> { if own_cert.is_empty() { - return Err(Error::SslBadInputData); + return Err(crate::error::codes::SslBadInputData.into()); } // Need to ensure own_cert/pk_key outlive the config. self.own_cert.push(own_cert.clone()); diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index 349a34e99..2f180e8eb 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -17,7 +17,7 @@ use mbedtls_sys::*; use crate::alloc::List as MbedtlsList; #[cfg(not(feature = "std"))] use crate::alloc_prelude::*; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{codes, Error, IntoResult, Result}; use crate::pk::Pk; use crate::private::UnsafeFrom; use crate::ssl::config::{AuthMode, Config, Version}; @@ -243,7 +243,7 @@ impl Context { // c-mbedtls's buffer, so we need to return size of bytes that has been buffered. // Since we know before this call `out_left` was 0, all buffer (with in the MBEDTLS_SSL_OUT_CONTENT_LEN part) is // buffered - Err(Error::SslWantWrite) => Ok(std::cmp::min( + Err(e) if e.high_level() == Some(codes::SslWantWrite) => Ok(std::cmp::min( unsafe { ssl_get_max_out_record_payload((&*self).into()).into_result()? as usize }, buf.len(), )), @@ -305,9 +305,8 @@ impl Context { pub fn handshake(&mut self) -> Result<()> { match self.inner_handshake() { Ok(()) => Ok(()), - Err(Error::SslWantRead) => Err(Error::SslWantRead), - Err(Error::SslWantWrite) => Err(Error::SslWantWrite), - Err(Error::SslHelloVerifyRequired) => { + Err(e) if matches!(e.high_level(), Some(codes::SslWantRead | codes::SslWantWrite)) => Err(e), + Err(e) if matches!(e.high_level(), Some(codes::SslHelloVerifyRequired)) => { unsafe { // `ssl_session_reset` resets the client ID but the user will call handshake // again in this case and the client ID is required for a DTLS connection setup @@ -325,7 +324,7 @@ impl Context { self.set_client_transport_id(&client_id)?; } } - Err(Error::SslHelloVerifyRequired) + Err(codes::SslHelloVerifyRequired.into()) } Err(e) => { self.close(); @@ -349,7 +348,7 @@ impl Context { #[cfg(not(feature = "std"))] fn set_hostname(&mut self, hostname: Option<&str>) -> Result<()> { match hostname { - Some(_) => Err(Error::SslBadInputData), + Some(_) => Err(codes::SslBadInputData.into()), None => Ok(()), } } @@ -357,7 +356,7 @@ impl Context { #[cfg(feature = "std")] fn set_hostname(&mut self, hostname: Option<&str>) -> Result<()> { if let Some(s) = hostname { - let cstr = ::std::ffi::CString::new(s).map_err(|_| Error::SslBadInputData)?; + let cstr = ::std::ffi::CString::new(s).map_err(|_| Error::from(codes::SslBadInputData))?; unsafe { ssl_set_hostname(self.into(), cstr.as_ptr()).into_result().map(|_| ()) } } else { Ok(()) @@ -442,7 +441,7 @@ impl Context { /// pub fn ciphersuite(&self) -> Result { if self.handle().session.is_null() { - return Err(Error::SslBadInputData); + return Err(codes::SslBadInputData.into()); } Ok(unsafe { self.handle().session.as_ref().unwrap().ciphersuite as u16 }) @@ -450,7 +449,7 @@ impl Context { pub fn peer_cert(&self) -> Result>> { if self.handle().session.is_null() { - return Err(Error::SslBadInputData); + return Err(codes::SslBadInputData.into()); } unsafe { @@ -459,7 +458,7 @@ impl Context { // variable for that. let peer_cert: &MbedtlsList = UnsafeFrom::from(&((*self.handle().session).peer_cert) as *const *mut x509_crt as *const *const x509_crt) - .ok_or(Error::SslBadInputData)?; + .ok_or::(codes::SslBadInputData.into())?; Ok(Some(peer_cert)) } } @@ -559,7 +558,7 @@ impl HandshakeContext { pub fn set_authmode(&mut self, am: AuthMode) -> Result<()> { if self.inner.handshake as *const _ == ::core::ptr::null() { - return Err(Error::SslBadInputData); + return Err(codes::SslBadInputData.into()); } unsafe { ssl_set_hs_authmode(self.into(), am as i32) } @@ -569,7 +568,7 @@ impl HandshakeContext { pub fn set_ca_list(&mut self, chain: Option>>, crl: Option>) -> Result<()> { // mbedtls_ssl_set_hs_ca_chain does not check for NULL handshake. if self.inner.handshake as *const _ == ::core::ptr::null() { - return Err(Error::SslBadInputData); + return Err(codes::SslBadInputData.into()); } // This will override current handshake CA chain. @@ -596,7 +595,7 @@ impl HandshakeContext { pub fn push_cert(&mut self, chain: Arc>, key: Arc) -> Result<()> { // mbedtls_ssl_set_hs_own_cert does not check for NULL handshake. if self.inner.handshake as *const _ == ::core::ptr::null() || chain.is_empty() { - return Err(Error::SslBadInputData); + return Err(codes::SslBadInputData.into()); } // This will append provided certificate pointers in internal structures. diff --git a/mbedtls/src/ssl/io.rs b/mbedtls/src/ssl/io.rs index 237468e6d..4c3b7d2ce 100644 --- a/mbedtls/src/ssl/io.rs +++ b/mbedtls/src/ssl/io.rs @@ -25,9 +25,9 @@ use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void}; use mbedtls_sys::types::size_t; use super::context::Context; -#[cfg(feature = "std")] -use crate::error::Error; use crate::error::Result; +#[cfg(feature = "std")] +use crate::error::{codes, Error}; /// A direct representation of the `mbedtls_ssl_send_t` and `mbedtls_ssl_recv_t` /// callback function pointers. @@ -122,15 +122,15 @@ impl IoCallback for IO { impl IoCallback for IO { fn recv(&mut self, buf: &mut [u8]) -> Result { self.read(buf).map_err(|e| match e { - ref e if e.kind() == std::io::ErrorKind::WouldBlock => Error::SslWantRead, - _ => Error::NetRecvFailed, + ref e if e.kind() == std::io::ErrorKind::WouldBlock => Error::from(codes::SslWantRead), + _ => Error::from(codes::NetRecvFailed), }) } fn send(&mut self, buf: &[u8]) -> Result { self.write(buf).map_err(|e| match e { - ref e if e.kind() == std::io::ErrorKind::WouldBlock => Error::SslWantWrite, - _ => Error::NetSendFailed, + ref e if e.kind() == std::io::ErrorKind::WouldBlock => Error::from(codes::SslWantWrite), + _ => Error::from(codes::NetSendFailed), }) } } @@ -162,13 +162,13 @@ impl Io for ConnectedUdpSocket { fn recv(&mut self, buf: &mut [u8]) -> Result { match self.socket.recv(buf) { Ok(i) => Ok(i), - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(Error::SslWantRead), - Err(_) => Err(Error::NetRecvFailed), + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(codes::SslWantRead.into()), + Err(_) => Err(codes::NetRecvFailed.into()), } } fn send(&mut self, buf: &[u8]) -> Result { - self.socket.send(buf).map_err(|_| Error::NetSendFailed) + self.socket.send(buf).map_err(|_| codes::NetSendFailed.into()) } } @@ -191,8 +191,10 @@ impl> Io for Context { impl> Read for Context { fn read(&mut self, buf: &mut [u8]) -> IoResult { match self.recv(buf) { - Err(Error::SslPeerCloseNotify) => Ok(0), - Err(Error::SslWantRead) | Err(Error::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()), + Err(e) if e.high_level() == Some(codes::SslPeerCloseNotify) => Ok(0), + Err(e) if matches!(e.high_level(), Some(codes::SslWantRead | codes::SslWantWrite)) => { + Err(IoErrorKind::WouldBlock.into()) + } Err(e) => Err(crate::private::error_to_io_error(e)), Ok(i) => Ok(i), } @@ -208,8 +210,10 @@ impl> Read for Context { impl> Write for Context { fn write(&mut self, buf: &[u8]) -> IoResult { match self.send(buf) { - Err(Error::SslPeerCloseNotify) => Ok(0), - Err(Error::SslWantRead) | Err(Error::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()), + Err(e) if e.high_level() == Some(codes::SslPeerCloseNotify) => Ok(0), + Err(e) if matches!(e.high_level(), Some(codes::SslWantRead | codes::SslWantWrite)) => { + Err(IoErrorKind::WouldBlock.into()) + } Err(e) => Err(crate::private::error_to_io_error(e)), Ok(i) => Ok(i), } diff --git a/mbedtls/src/x509/certificate.rs b/mbedtls/src/x509/certificate.rs index 80cb94754..10b11c584 100644 --- a/mbedtls/src/x509/certificate.rs +++ b/mbedtls/src/x509/certificate.rs @@ -16,7 +16,7 @@ use mbedtls_sys::*; use crate::alloc::{mbedtls_calloc, Box as MbedtlsBox, CString, List as MbedtlsList}; #[cfg(not(feature = "std"))] use crate::alloc_prelude::*; -use crate::error::{Error, IntoResult, Result}; +use crate::error::{codes, Error, IntoResult, Result}; use crate::hash::Type as MdType; use crate::pk::Pk; use crate::private::UnsafeFrom; @@ -78,7 +78,7 @@ fn x509_buf_to_vec(buf: &x509_buf) -> Vec { fn x509_time_to_time(tm: &x509_time) -> Result