Skip to content

Commit

Permalink
Merge #103
Browse files Browse the repository at this point in the history
103: Trusted Certificate Callback support r=Pagten a=Pagten

This PR adds support to rust-mbedtls for the `mbedtls_ssl_conf_ca_cb()` function. This function allows users to register the set of trusted certificates through a callback, instead of via a linked list as configured by `mbedtls_ssl_conf_ca_chain()`.

Co-authored-by: Pieter Agten <[email protected]>
  • Loading branch information
bors[bot] and Pagten authored May 9, 2020
2 parents c619009 + c7b2bdf commit 7231e15
Show file tree
Hide file tree
Showing 11 changed files with 352 additions and 42 deletions.
13 changes: 10 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions mbedtls-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ aesni = []
padlock = []
legacy_protocols = []
mpi_force_c_code = []
trusted_cert_callback = []
59 changes: 30 additions & 29 deletions mbedtls-sys/build/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,35 +379,36 @@ pub const DEFAULT_DEFINES: &'static [CDefine] = &[

#[cfg_attr(rustfmt, rustfmt_skip)]
pub const FEATURE_DEFINES: &'static [(&'static str, CDefine)] = &[
("time", ("MBEDTLS_HAVE_TIME", Defined)),
("time", ("MBEDTLS_HAVE_TIME_DATE", Defined)),
("time", ("MBEDTLS_TIMING_C", Defined)),
("custom_time", ("MBEDTLS_PLATFORM_TIME_MACRO", DefinedAs("mbedtls_time"))),
("custom_gmtime_r", ("MBEDTLS_PLATFORM_GMTIME_R_ALT", Defined)),
("havege", ("MBEDTLS_HAVEGE_C", Defined)),
("threading", ("MBEDTLS_THREADING_C", Defined)),
("pthread", ("MBEDTLS_THREADING_PTHREAD", Defined)),
("custom_threading", ("MBEDTLS_THREADING_IMPL", Defined)),
("pkcs11", ("MBEDTLS_PKCS11_C", Defined)),
("zlib", ("MBEDTLS_ZLIB_SUPPORT", Defined)),
("std", ("MBEDTLS_NET_C", Defined)),
("std", ("MBEDTLS_FS_IO", Defined)),
("std", ("MBEDTLS_NO_PLATFORM_ENTROPY", Undefined)),
("std", ("MBEDTLS_DEBUG_C", Defined)),
("std", ("MBEDTLS_ENTROPY_C", Defined)),
("custom_printf", ("MBEDTLS_PLATFORM_C", Defined)),
("custom_printf", ("MBEDTLS_PLATFORM_PRINTF_MACRO", DefinedAs("mbedtls_printf"))),
("aesni", ("MBEDTLS_AESNI_C", Defined)),
("padlock", ("MBEDTLS_PADLOCK_C", Defined)),
("custom_has_support", ("MBEDTLS_CUSTOM_HAS_AESNI", Defined)),
("custom_has_support", ("MBEDTLS_CUSTOM_HAS_PADLOCK", Defined)),
("legacy_protocols", ("MBEDTLS_SSL_PROTO_SSL3", Defined)),
("legacy_protocols", ("MBEDTLS_SSL_PROTO_TLS1", Defined)),
("legacy_protocols", ("MBEDTLS_SSL_PROTO_TLS1_1", Defined)),
("legacy_protocols", ("MBEDTLS_SSL_CBC_RECORD_SPLITTING",Defined)),
("aes_alt", ("MBEDTLS_AES_ENCRYPT_ALT", Defined)),
("aes_alt", ("MBEDTLS_AES_DECRYPT_ALT", Defined)),
("mpi_force_c_code", ("MBEDTLS_MPI_FORCE_C_CODE", Defined)),
("time", ("MBEDTLS_HAVE_TIME", Defined)),
("time", ("MBEDTLS_HAVE_TIME_DATE", Defined)),
("time", ("MBEDTLS_TIMING_C", Defined)),
("custom_time", ("MBEDTLS_PLATFORM_TIME_MACRO", DefinedAs("mbedtls_time"))),
("custom_gmtime_r", ("MBEDTLS_PLATFORM_GMTIME_R_ALT", Defined)),
("havege", ("MBEDTLS_HAVEGE_C", Defined)),
("threading", ("MBEDTLS_THREADING_C", Defined)),
("pthread", ("MBEDTLS_THREADING_PTHREAD", Defined)),
("custom_threading", ("MBEDTLS_THREADING_IMPL", Defined)),
("pkcs11", ("MBEDTLS_PKCS11_C", Defined)),
("zlib", ("MBEDTLS_ZLIB_SUPPORT", Defined)),
("std", ("MBEDTLS_NET_C", Defined)),
("std", ("MBEDTLS_FS_IO", Defined)),
("std", ("MBEDTLS_NO_PLATFORM_ENTROPY", Undefined)),
("std", ("MBEDTLS_DEBUG_C", Defined)),
("std", ("MBEDTLS_ENTROPY_C", Defined)),
("custom_printf", ("MBEDTLS_PLATFORM_C", Defined)),
("custom_printf", ("MBEDTLS_PLATFORM_PRINTF_MACRO", DefinedAs("mbedtls_printf"))),
("aesni", ("MBEDTLS_AESNI_C", Defined)),
("padlock", ("MBEDTLS_PADLOCK_C", Defined)),
("custom_has_support", ("MBEDTLS_CUSTOM_HAS_AESNI", Defined)),
("custom_has_support", ("MBEDTLS_CUSTOM_HAS_PADLOCK", Defined)),
("legacy_protocols", ("MBEDTLS_SSL_PROTO_SSL3", Defined)),
("legacy_protocols", ("MBEDTLS_SSL_PROTO_TLS1", Defined)),
("legacy_protocols", ("MBEDTLS_SSL_PROTO_TLS1_1", Defined)),
("legacy_protocols", ("MBEDTLS_SSL_CBC_RECORD_SPLITTING", Defined)),
("aes_alt", ("MBEDTLS_AES_ENCRYPT_ALT", Defined)),
("aes_alt", ("MBEDTLS_AES_DECRYPT_ALT", Defined)),
("mpi_force_c_code", ("MBEDTLS_MPI_FORCE_C_CODE", Defined)),
("trusted_cert_callback", ("MBEDTLS_X509_TRUSTED_CERTIFICATE_CALLBACK", Defined)),
];

pub const SUFFIX: &'static str = r#"
Expand Down
10 changes: 8 additions & 2 deletions mbedtls/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mbedtls"
version = "0.5.3"
version = "0.5.4"
authors = ["Jethro Beekman <[email protected]>"]
build = "build.rs"
edition = "2018"
Expand Down Expand Up @@ -35,14 +35,15 @@ rs-libc = "0.1.0"
[dependencies.mbedtls-sys-auto]
version = "2.18.0"
default-features = false
features = ["custom_printf"]
features = ["custom_printf", "trusted_cert_callback"]
path = "../mbedtls-sys"

[dev-dependencies]
libc = "0.2.0"
rand = "0.4.0"
serde_cbor = "0.6"
hex = "0.3"
matches = "0.1.8"

[build-dependencies]
cc = "1.0"
Expand Down Expand Up @@ -109,6 +110,11 @@ path = "tests/rsa.rs"
name = "save_restore"
path = "tests/save_restore.rs"

[[test]]
name = "ssl_conf_ca_cb"
path = "tests/ssl_conf_ca_cb.rs"
required-features = ["std"]

[[test]]
name = "ssl_conf_verify"
path = "tests/ssl_conf_verify.rs"
Expand Down
1 change: 1 addition & 0 deletions mbedtls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* according to those terms. */

#![deny(warnings)]
#![allow(unused_doc_comments)]
#![cfg_attr(not(feature = "std"), no_std)]

#[cfg(all(not(feature = "std"), not(feature = "core_io")))]
Expand Down
90 changes: 90 additions & 0 deletions mbedtls/src/ssl/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ use crate::ssl::context::HandshakeContext;
use crate::ssl::ticket::TicketCallback;
use crate::x509::{certificate, Crl, LinkedCertificate, Profile, VerifyError};

extern "C" {
fn calloc(n: usize, size: usize) -> *mut c_void;
fn free(ptr: *mut c_void);
}

#[allow(non_camel_case_types)]
#[derive(Eq, PartialEq, PartialOrd, Ord, Debug, Copy, Clone)]
pub enum Version {
Expand Down Expand Up @@ -291,6 +296,91 @@ impl<'c> Config<'c> {
)
}
}

pub fn set_ca_callback<F>(&mut self, cb: &'c mut F)
where
F: FnMut(&LinkedCertificate, &mut ForeignOwnedCertListBuilder) -> Result<()>,
{
unsafe extern "C" fn ca_callback<F>(
closure: *mut c_void,
child: *const x509_crt,
candidate_cas: *mut *mut x509_crt
) -> c_int
where
F: FnMut(&LinkedCertificate, &mut ForeignOwnedCertListBuilder) -> Result<()>,
{
let cb = &mut *(closure as *mut F);
let child: &LinkedCertificate = UnsafeFrom::from(child).expect("valid child certificate");
let mut cert_builder = ForeignOwnedCertListBuilder::new();
match cb(child, &mut cert_builder) {
Ok(()) => {
*candidate_cas = cert_builder.to_x509_crt_ptr();
0
},
Err(e) => e.to_int(),
}
}

unsafe {
ssl_conf_ca_cb(
&mut self.inner,
Some(ca_callback::<F>),
cb as *mut F as _,
)
}
}
}

/// Builds a linked list of x509_crt instances, all of which are owned by mbedtls. That is, the
/// memory for these certificates has been allocated by mbedtls, on the C heap. This is needed for
/// situations in which an mbedtls function takes ownership of a list of certs. The problem with
/// handing such functions a "normal" cert list such as certificate::LinkedCertificate or
/// certificate::List, is that those lists (at least partly) consist of memory allocated on the
/// rust-side and hence cannot be freed on the c-side.
pub struct ForeignOwnedCertListBuilder {
cert_list: *mut x509_crt,
}

impl ForeignOwnedCertListBuilder {
pub(crate) fn new() -> Self {
let cert_list = unsafe { calloc(1, core::mem::size_of::<x509_crt>()) } as *mut x509_crt;
if cert_list == ::core::ptr::null_mut() {
panic!("Out of memory");
}
unsafe { ::mbedtls_sys::x509_crt_init(cert_list); }

Self {
cert_list
}
}

pub fn push_back(&mut self, cert: &LinkedCertificate) {
self.try_push_back(cert.as_der()).expect("cert is a valid DER-encoded certificate");
}

pub fn try_push_back(&mut self, cert: &[u8]) -> Result<()> {
// x509_crt_parse_der will allocate memory for the cert on the C heap
unsafe { x509_crt_parse_der(self.cert_list, cert.as_ptr(), cert.len()) }.into_result()?;
Ok(())
}

// The memory pointed to by the return value is managed by mbedtls. If the return value is
// dropped without handing it to an mbedtls-function that takes ownership of it, that memory
// will be leaked.
pub(crate) fn to_x509_crt_ptr(mut self) -> *mut x509_crt {
let res = self.cert_list;
self.cert_list = ::core::ptr::null_mut();
res
}
}

impl Drop for ForeignOwnedCertListBuilder {
fn drop(&mut self) {
unsafe {
::mbedtls_sys::x509_crt_free(self.cert_list);
free(self.cert_list as *mut c_void);
}
}
}

setter_callback!(Config<'c>::set_rng(f: crate::rng::Random) = ssl_conf_rng);
Expand Down
4 changes: 2 additions & 2 deletions mbedtls/src/x509/certificate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -688,8 +688,8 @@ mod tests {
impl Test {
fn new() -> Self {
Test {
key1: Pk::from_private_key(crate::test_support::keys::PEM_KEY, None).unwrap(),
key2: Pk::from_private_key(crate::test_support::keys::PEM_KEY, None).unwrap(),
key1: Pk::from_private_key(crate::test_support::keys::PEM_SELF_SIGNED_KEY, None).unwrap(),
key2: Pk::from_private_key(crate::test_support::keys::PEM_SELF_SIGNED_KEY, None).unwrap(),
}
}

Expand Down
2 changes: 1 addition & 1 deletion mbedtls/src/x509/csr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ mod tests {
impl Test {
fn new() -> Self {
Test {
key: Pk::from_private_key(crate::test_support::keys::PEM_KEY, None).unwrap(),
key: Pk::from_private_key(crate::test_support::keys::PEM_SELF_SIGNED_KEY, None).unwrap(),
}
}

Expand Down
4 changes: 2 additions & 2 deletions mbedtls/tests/mbedtls_self_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ fn rand() -> mbedtls_sys::types::raw_types::c_int {

#[cfg(any(not(feature = "std"), target_env = "sgx"))]
fn enable_self_test() {
use std::sync::{Once, ONCE_INIT};
use std::sync::Once;

static START: Once = ONCE_INIT;
static START: Once = Once::new();

START.call_once(|| {
// safe because synchronized
Expand Down
Loading

0 comments on commit 7231e15

Please sign in to comment.