diff --git a/heed-traits/src/lib.rs b/heed-traits/src/lib.rs index af8c2102..88fb0efe 100644 --- a/heed-traits/src/lib.rs +++ b/heed-traits/src/lib.rs @@ -42,11 +42,6 @@ pub trait BytesDecode<'a> { /// with shorter keys collating before longer keys. pub trait Comparator { /// Compares the raw bytes representation of two keys. - /// - /// # Safety - /// - /// This function must never crash, this is the reason why it takes raw bytes as parameter, - /// to let you define the recovery method you want in case of a decoding error. fn compare(a: &[u8], b: &[u8]) -> Ordering; } @@ -61,10 +56,6 @@ pub trait Comparator { pub trait LexicographicComparator: Comparator { /// Compare a single byte; this function is used to implement [`Comparator::compare`] /// by definition of lexicographic ordering. - /// - /// # Safety - /// - /// This function must never crash. fn compare_elem(a: u8, b: u8) -> Ordering; /// Advances the given `elem` to its immediate lexicographic successor, if possible. diff --git a/heed/src/env.rs b/heed/src/env.rs index 74b64dd7..c69521ca 100644 --- a/heed/src/env.rs +++ b/heed/src/env.rs @@ -9,7 +9,9 @@ use std::os::unix::{ ffi::OsStrExt, io::{AsRawFd, BorrowedFd, RawFd}, }; +use std::panic::catch_unwind; use std::path::{Path, PathBuf}; +use std::process::abort; use std::sync::{Arc, RwLock}; use std::time::Duration; #[cfg(windows)] @@ -371,16 +373,17 @@ impl Drop for EnvInner { /// A helper function that transforms the LMDB types into Rust types (`MDB_val` into slices) /// and vice versa, the Rust types into C types (`Ordering` into an integer). -extern "C" fn custom_key_cmp_wrapper( +unsafe extern "C" fn custom_key_cmp_wrapper( a: *const ffi::MDB_val, b: *const ffi::MDB_val, ) -> i32 { let a = unsafe { ffi::from_val(*a) }; let b = unsafe { ffi::from_val(*b) }; - match C::compare(a, b) { - Ordering::Less => -1, - Ordering::Equal => 0, - Ordering::Greater => 1, + match catch_unwind(|| C::compare(a, b)) { + Ok(Ordering::Less) => -1, + Ok(Ordering::Equal) => 0, + Ok(Ordering::Greater) => 1, + Err(_) => abort(), } }