diff --git a/src/entropy.rs b/src/entropy.rs index 8daac6d8..3841cb02 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -1,5 +1,5 @@ //! Information theory (e.g. entropy, KL divergence, etc.). -use crate::errors::ShapeMismatch; +use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch}; use ndarray::{Array, ArrayBase, Data, Dimension, Zip}; use num_traits::Float; @@ -19,7 +19,7 @@ where /// i=1 /// ``` /// - /// If the array is empty, `None` is returned. + /// If the array is empty, `Err(EmptyInput)` is returned. /// /// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`). /// @@ -38,7 +38,7 @@ where /// /// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory) /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory - fn entropy(&self) -> Option + fn entropy(&self) -> Result where A: Float; @@ -53,8 +53,9 @@ where /// i=1 /// ``` /// - /// If the arrays are empty, Ok(`None`) is returned. - /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned. + /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned. + /// If the array shapes are not identical, + /// `Err(MultiInputError::ShapeMismatch)` is returned. /// /// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing /// *ln(qᵢ/pᵢ)* is a panic cause for `A`. @@ -73,7 +74,7 @@ where /// /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory - fn kl_divergence(&self, q: &ArrayBase) -> Result, ShapeMismatch> + fn kl_divergence(&self, q: &ArrayBase) -> Result where S2: Data, A: Float; @@ -89,8 +90,9 @@ where /// i=1 /// ``` /// - /// If the arrays are empty, Ok(`None`) is returned. - /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned. + /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned. 
+ /// If the array shapes are not identical, + /// `Err(MultiInputError::ShapeMismatch)` is returned. /// /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number /// is a panic cause for `A`. @@ -114,7 +116,7 @@ where /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression - fn cross_entropy(&self, q: &ArrayBase) -> Result, ShapeMismatch> + fn cross_entropy(&self, q: &ArrayBase) -> Result where S2: Data, A: Float; @@ -125,14 +127,14 @@ where S: Data, D: Dimension, { - fn entropy(&self) -> Option + fn entropy(&self) -> Result where A: Float, { if self.len() == 0 { - None + Err(EmptyInput) } else { - let entropy = self + let entropy = -self .mapv(|x| { if x == A::zero() { A::zero() @@ -141,23 +143,24 @@ where } }) .sum(); - Some(-entropy) + Ok(entropy) } } - fn kl_divergence(&self, q: &ArrayBase) -> Result, ShapeMismatch> + fn kl_divergence(&self, q: &ArrayBase) -> Result where A: Float, S2: Data, { if self.len() == 0 { - return Ok(None); + return Err(MultiInputError::EmptyInput); } if self.shape() != q.shape() { return Err(ShapeMismatch { first_shape: self.shape().to_vec(), second_shape: q.shape().to_vec(), - }); + } + .into()); } let mut temp = Array::zeros(self.raw_dim()); @@ -174,22 +177,23 @@ where } }); let kl_divergence = -temp.sum(); - Ok(Some(kl_divergence)) + Ok(kl_divergence) } - fn cross_entropy(&self, q: &ArrayBase) -> Result, ShapeMismatch> + fn cross_entropy(&self, q: &ArrayBase) -> Result where S2: Data, A: Float, { if self.len() == 0 { - return Ok(None); + return Err(MultiInputError::EmptyInput); } if self.shape() != q.shape() { return Err(ShapeMismatch { first_shape: self.shape().to_vec(), second_shape: q.shape().to_vec(), - }); + } + .into()); } let mut temp = Array::zeros(self.raw_dim()); @@ 
-206,7 +210,7 @@ where } }); let cross_entropy = -temp.sum(); - Ok(Some(cross_entropy)) + Ok(cross_entropy) } } @@ -214,7 +218,7 @@ where mod tests { use super::EntropyExt; use approx::assert_abs_diff_eq; - use errors::ShapeMismatch; + use errors::{EmptyInput, MultiInputError}; use ndarray::{array, Array1}; use noisy_float::types::n64; use std::f64; @@ -228,7 +232,7 @@ mod tests { #[test] fn test_entropy_with_empty_array_of_floats() { let a: Array1 = array![]; - assert!(a.entropy().is_none()); + assert_eq!(a.entropy(), Err(EmptyInput)); } #[test] @@ -251,13 +255,13 @@ mod tests { } #[test] - fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> { + fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> { let a = array![f64::NAN, 1.]; let b = array![2., 1.]; - assert!(a.cross_entropy(&b)?.unwrap().is_nan()); - assert!(b.cross_entropy(&a)?.unwrap().is_nan()); - assert!(a.kl_divergence(&b)?.unwrap().is_nan()); - assert!(b.kl_divergence(&a)?.unwrap().is_nan()); + assert!(a.cross_entropy(&b)?.is_nan()); + assert!(b.cross_entropy(&a)?.is_nan()); + assert!(a.kl_divergence(&b)?.is_nan()); + assert!(b.kl_divergence(&a)?.is_nan()); Ok(()) } @@ -284,20 +288,19 @@ mod tests { } #[test] - fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> { + fn test_cross_entropy_and_kl_with_empty_array_of_floats() { let p: Array1 = array![]; let q: Array1 = array![]; - assert!(p.cross_entropy(&q)?.is_none()); - assert!(p.kl_divergence(&q)?.is_none()); - Ok(()) + assert!(p.cross_entropy(&q).unwrap_err().is_empty_input()); + assert!(p.kl_divergence(&q).unwrap_err().is_empty_input()); } #[test] - fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> { + fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> { let p = array![1.]; let q = array![-1.]; - let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap(); - let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap(); + let 
cross_entropy: f64 = p.cross_entropy(&q)?; + let kl_divergence: f64 = p.kl_divergence(&q)?; assert!(cross_entropy.is_nan()); assert!(kl_divergence.is_nan()); Ok(()) @@ -320,26 +323,26 @@ mod tests { } #[test] - fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> { + fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> { let p = array![0., 0.]; let q = array![0., 0.5]; - assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.); - assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.); + assert_eq!(p.cross_entropy(&q)?, 0.); + assert_eq!(p.kl_divergence(&q)?, 0.); Ok(()) } #[test] fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership( - ) -> Result<(), ShapeMismatch> { + ) -> Result<(), MultiInputError> { let p = array![0.5, 0.5]; let mut q = array![0.5, 0.]; - assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY); - assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY); + assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY); + assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY); Ok(()) } #[test] - fn test_cross_entropy() -> Result<(), ShapeMismatch> { + fn test_cross_entropy() -> Result<(), MultiInputError> { // Arrays of probability values - normalized and positive. let p: Array1 = array![ 0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189, @@ -356,16 +359,12 @@ mod tests { // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q) let expected_cross_entropy = 3.385347705020779; - assert_abs_diff_eq!( - p.cross_entropy(&q)?.unwrap(), - expected_cross_entropy, - epsilon = 1e-6 - ); + assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6); Ok(()) } #[test] - fn test_kl() -> Result<(), ShapeMismatch> { + fn test_kl() -> Result<(), MultiInputError> { // Arrays of probability values - normalized and positive. 
let p: Array1 = array![ 0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516, @@ -390,7 +389,7 @@ mod tests { // Computed using scipy.stats.entropy(p, q) let expected_kl = 0.3555862567800096; - assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6); + assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6); Ok(()) } } diff --git a/src/errors.rs b/src/errors.rs index 4bbeea46..d89112a5 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -2,10 +2,50 @@ use std::error::Error; use std::fmt; -#[derive(Debug)] +/// An error that indicates that the input array was empty. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct EmptyInput; + +impl fmt::Display for EmptyInput { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Empty input.") + } +} + +impl Error for EmptyInput {} + +/// An error computing a minimum/maximum value. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum MinMaxError { + /// The input was empty. + EmptyInput, + /// The ordering between a tested pair of values was undefined. + UndefinedOrder, +} + +impl fmt::Display for MinMaxError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + MinMaxError::EmptyInput => write!(f, "Empty input."), + MinMaxError::UndefinedOrder => { + write!(f, "Undefined ordering between a tested pair of values.") + } + } + } +} + +impl Error for MinMaxError {} + +impl From for MinMaxError { + fn from(_: EmptyInput) -> MinMaxError { + MinMaxError::EmptyInput + } +} + /// An error used by methods and functions that take two arrays as argument and /// expect them to have exactly the same shape /// (e.g. `ShapeMismatch` is raised when `a.shape() == b.shape()` evaluates to `False`). 
+#[derive(Clone, Debug)] pub struct ShapeMismatch { pub first_shape: Vec, pub second_shape: Vec, @@ -22,3 +62,53 @@ impl fmt::Display for ShapeMismatch { } impl Error for ShapeMismatch {} + +/// An error for methods that take multiple non-empty array inputs. +#[derive(Clone, Debug)] +pub enum MultiInputError { + /// One or more of the arrays were empty. + EmptyInput, + /// The arrays did not have the same shape. + ShapeMismatch(ShapeMismatch), +} + +impl MultiInputError { + /// Returns whether `self` is the `EmptyInput` variant. + pub fn is_empty_input(&self) -> bool { + match self { + MultiInputError::EmptyInput => true, + _ => false, + } + } + + /// Returns whether `self` is the `ShapeMismatch` variant. + pub fn is_shape_mismatch(&self) -> bool { + match self { + MultiInputError::ShapeMismatch(_) => true, + _ => false, + } + } +} + +impl fmt::Display for MultiInputError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + MultiInputError::EmptyInput => write!(f, "Empty input."), + MultiInputError::ShapeMismatch(e) => write!(f, "Shape mismatch: {}", e), + } + } +} + +impl Error for MultiInputError {} + +impl From for MultiInputError { + fn from(_: EmptyInput) -> Self { + MultiInputError::EmptyInput + } +} + +impl From for MultiInputError { + fn from(err: ShapeMismatch) -> Self { + MultiInputError::ShapeMismatch(err) + } +} diff --git a/src/histogram/errors.rs b/src/histogram/errors.rs index 7afaea1f..8deb6218 100644 --- a/src/histogram/errors.rs +++ b/src/histogram/errors.rs @@ -1,3 +1,4 @@ +use crate::errors::{EmptyInput, MinMaxError}; use std::error; use std::fmt; @@ -15,9 +16,60 @@ impl error::Error for BinNotFound { fn description(&self) -> &str { "No bin has been found." } +} + +/// Error computing the set of histogram bins. +#[derive(Debug, Clone)] +pub enum BinsBuildError { + /// The input array was empty. + EmptyInput, + /// The strategy for computing appropriate bins failed. 
+ Strategy, + #[doc(hidden)] + __NonExhaustive, +} + +impl BinsBuildError { + /// Returns whether `self` is the `EmptyInput` variant. + pub fn is_empty_input(&self) -> bool { + match self { + BinsBuildError::EmptyInput => true, + _ => false, + } + } + + /// Returns whether `self` is the `Strategy` variant. + pub fn is_strategy(&self) -> bool { + match self { + BinsBuildError::Strategy => true, + _ => false, + } + } +} + +impl fmt::Display for BinsBuildError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "The strategy failed to determine a non-zero bin width.") + } +} + +impl error::Error for BinsBuildError { + fn description(&self) -> &str { + "The strategy failed to determine a non-zero bin width." + } +} + +impl From for BinsBuildError { + fn from(_: EmptyInput) -> Self { + BinsBuildError::EmptyInput + } +} - fn cause(&self) -> Option<&error::Error> { - // Generic error, underlying cause isn't tracked. - None +impl From for BinsBuildError { + fn from(err: MinMaxError) -> BinsBuildError { + match err { + MinMaxError::EmptyInput => BinsBuildError::EmptyInput, + MinMaxError::UndefinedOrder => BinsBuildError::Strategy, + } } } diff --git a/src/histogram/grid.rs b/src/histogram/grid.rs index bfa5afc1..91caab03 100644 --- a/src/histogram/grid.rs +++ b/src/histogram/grid.rs @@ -1,4 +1,5 @@ use super::bins::Bins; +use super::errors::BinsBuildError; use super::strategies::BinsBuildingStrategy; use itertools::izip; use ndarray::{ArrayBase, Axis, Data, Ix1, Ix2}; @@ -54,7 +55,7 @@ use std::ops::Range; /// /// // The optimal grid layout is inferred from the data, /// // specifying a strategy (Auto in this case) -/// let grid = GridBuilder::>::from_array(&observations).build(); +/// let grid = GridBuilder::>::from_array(&observations).unwrap().build(); /// let expected_grid = Grid::from(vec![Bins::new(Edges::from(vec![1, 20, 39, 58, 77, 96, 115]))]); /// assert_eq!(grid, expected_grid); /// @@ -169,17 +170,20 @@ where /// it returns a `GridBuilder` 
instance that has learned the required parameter /// to build a [`Grid`] according to the specified [`strategy`]. /// + /// It returns `Err` if it is not possible to build a [`Grid`] given + /// the observed data according to the chosen [`strategy`]. + /// /// [`Grid`]: struct.Grid.html /// [`strategy`]: strategies/index.html - pub fn from_array(array: &ArrayBase) -> Self + pub fn from_array(array: &ArrayBase) -> Result where S: Data, { let bin_builders = array .axis_iter(Axis(1)) .map(|data| B::from_array(&data)) - .collect(); - Self { bin_builders } + .collect::, BinsBuildError>>()?; + Ok(Self { bin_builders }) } /// Returns a [`Grid`] instance, built accordingly to the specified [`strategy`] diff --git a/src/histogram/histograms.rs b/src/histogram/histograms.rs index 9bfe2724..4eaeaad4 100644 --- a/src/histogram/histograms.rs +++ b/src/histogram/histograms.rs @@ -123,7 +123,7 @@ where /// [n64(-1.), n64(-0.5)], /// [n64(0.5), n64(-1.)] /// ]; - /// let grid = GridBuilder::>::from_array(&observations).build(); + /// let grid = GridBuilder::>::from_array(&observations).unwrap().build(); /// let expected_grid = Grid::from( /// vec![ /// Bins::new(Edges::from(vec![n64(-1.), n64(0.), n64(1.), n64(2.)])), diff --git a/src/histogram/strategies.rs b/src/histogram/strategies.rs index eeaee686..93d75a9b 100644 --- a/src/histogram/strategies.rs +++ b/src/histogram/strategies.rs @@ -20,6 +20,7 @@ //! [`NumPy`]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram_bin_edges.html#numpy.histogram_bin_edges use super::super::interpolate::Nearest; use super::super::{Quantile1dExt, QuantileExt}; +use super::errors::BinsBuildError; use super::{Bins, Edges}; use ndarray::prelude::*; use ndarray::Data; @@ -40,10 +41,14 @@ pub trait BinsBuildingStrategy { /// Given some observations in a 1-dimensional array it returns a `BinsBuildingStrategy` /// that has learned the required parameter to build a collection of [`Bins`]. 
/// + /// It returns `Err` if it is not possible to build a collection of + /// [`Bins`] given the observed data according to the chosen strategy. + /// /// [`Bins`]: ../struct.Bins.html - fn from_array(array: &ArrayBase) -> Self + fn from_array(array: &ArrayBase) -> Result where - S: Data; + S: Data, + Self: std::marker::Sized; /// Returns a [`Bins`] instance, built accordingly to the parameters /// inferred from observations in [`from_array`]. @@ -59,6 +64,7 @@ pub trait BinsBuildingStrategy { fn n_bins(&self) -> usize; } +#[derive(Debug)] struct EquiSpaced { bin_width: T, min: T, @@ -71,6 +77,7 @@ struct EquiSpaced { /// Let `n` be the number of observations. Then /// /// `n_bins` = `sqrt(n)` +#[derive(Debug)] pub struct Sqrt { builder: EquiSpaced, } @@ -84,6 +91,7 @@ pub struct Sqrt { /// /// `n_bins` is only proportional to cube root of `n`. It tends to overestimate /// the `n_bins` and it does not take into account data variability. +#[derive(Debug)] pub struct Rice { builder: EquiSpaced, } @@ -96,6 +104,7 @@ pub struct Rice { /// is too conservative for larger, non-normal datasets. /// /// This is the default method in R’s hist method. +#[derive(Debug)] pub struct Sturges { builder: EquiSpaced, } @@ -114,10 +123,12 @@ pub struct Sturges { /// The [`IQR`] is very robust to outliers. /// /// [`IQR`]: https://en.wikipedia.org/wiki/Interquartile_range +#[derive(Debug)] pub struct FreedmanDiaconis { builder: EquiSpaced, } +#[derive(Debug)] enum SturgesOrFD { Sturges(Sturges), FreedmanDiaconis(FreedmanDiaconis), @@ -133,6 +144,7 @@ enum SturgesOrFD { /// /// [`Sturges`]: struct.Sturges.html /// [`FreedmanDiaconis`]: struct.FreedmanDiaconis.html +#[derive(Debug)] pub struct Auto { builder: SturgesOrFD, } @@ -141,13 +153,17 @@ impl EquiSpaced where T: Ord + Clone + FromPrimitive + NumOps + Zero, { - /// **Panics** if `bin_width<=0`. 
- fn new(bin_width: T, min: T, max: T) -> Self { - assert!(bin_width > T::zero()); - Self { - bin_width, - min, - max, + /// Returns `Err(BinsBuildError::Strategy)` if `bin_width<=0` or `min` >= `max`. + /// Returns `Ok(Self)` otherwise. + fn new(bin_width: T, min: T, max: T) -> Result { + if (bin_width <= T::zero()) || (min >= max) { + Err(BinsBuildError::Strategy) + } else { + Ok(Self { + bin_width, + min, + max, + }) } } @@ -182,18 +198,20 @@ where { type Elem = T; - /// **Panics** if the array is constant or if `a.len()==0`. - fn from_array(a: &ArrayBase) -> Self + /// Returns `Err(BinsBuildError::Strategy)` if the array is constant. + /// Returns `Err(BinsBuildError::EmptyInput)` if `a.len()==0`. + /// Returns `Ok(Self)` otherwise. + fn from_array(a: &ArrayBase) -> Result where S: Data, { let n_elems = a.len(); let n_bins = (n_elems as f64).sqrt().round() as usize; - let min = a.min().unwrap().clone(); - let max = a.max().unwrap().clone(); + let min = a.min()?; + let max = a.max()?; let bin_width = compute_bin_width(min.clone(), max.clone(), n_bins); - let builder = EquiSpaced::new(bin_width, min, max); - Self { builder } + let builder = EquiSpaced::new(bin_width, min.clone(), max.clone())?; + Ok(Self { builder }) } fn build(&self) -> Bins { @@ -221,18 +239,20 @@ where { type Elem = T; - /// **Panics** if the array is constant or if `a.len()==0`. - fn from_array(a: &ArrayBase) -> Self + /// Returns `Err(BinsBuildError::Strategy)` if the array is constant. + /// Returns `Err(BinsBuildError::EmptyInput)` if `a.len()==0`. + /// Returns `Ok(Self)` otherwise. + fn from_array(a: &ArrayBase) -> Result where S: Data, { let n_elems = a.len(); let n_bins = (2. * (n_elems as f64).powf(1. 
/ 3.)).round() as usize; - let min = a.min().unwrap().clone(); - let max = a.max().unwrap().clone(); + let min = a.min()?; + let max = a.max()?; let bin_width = compute_bin_width(min.clone(), max.clone(), n_bins); - let builder = EquiSpaced::new(bin_width, min, max); - Self { builder } + let builder = EquiSpaced::new(bin_width, min.clone(), max.clone())?; + Ok(Self { builder }) } fn build(&self) -> Bins { @@ -260,18 +280,20 @@ where { type Elem = T; - /// **Panics** if the array is constant or if `a.len()==0`. - fn from_array(a: &ArrayBase) -> Self + /// Returns `Err(BinsBuildError::Strategy)` if the array is constant. + /// Returns `Err(BinsBuildError::EmptyInput)` if `a.len()==0`. + /// Returns `Ok(Self)` otherwise. + fn from_array(a: &ArrayBase) -> Result where S: Data, { let n_elems = a.len(); let n_bins = (n_elems as f64).log2().round() as usize + 1; - let min = a.min().unwrap().clone(); - let max = a.max().unwrap().clone(); + let min = a.min()?; + let max = a.max()?; let bin_width = compute_bin_width(min.clone(), max.clone(), n_bins); - let builder = EquiSpaced::new(bin_width, min, max); - Self { builder } + let builder = EquiSpaced::new(bin_width, min.clone(), max.clone())?; + Ok(Self { builder }) } fn build(&self) -> Bins { @@ -299,12 +321,17 @@ where { type Elem = T; - /// **Panics** if `IQR==0` or if `a.len()==0`. - fn from_array(a: &ArrayBase) -> Self + /// Returns `Err(BinsBuildError::Strategy)` if `IQR==0`. + /// Returns `Err(BinsBuildError::EmptyInput)` if `a.len()==0`. + /// Returns `Ok(Self)` otherwise. 
+ fn from_array(a: &ArrayBase) -> Result where S: Data, { let n_points = a.len(); + if n_points == 0 { + return Err(BinsBuildError::EmptyInput); + } let mut a_copy = a.to_owned(); let first_quartile = a_copy.quantile_mut::(0.25).unwrap(); @@ -312,10 +339,10 @@ where let iqr = third_quartile - first_quartile; let bin_width = FreedmanDiaconis::compute_bin_width(n_points, iqr); - let min = a_copy.min().unwrap().clone(); - let max = a_copy.max().unwrap().clone(); - let builder = EquiSpaced::new(bin_width, min, max); - Self { builder } + let min = a.min()?; + let max = a.max()?; + let builder = EquiSpaced::new(bin_width, min.clone(), max.clone())?; + Ok(Self { builder }) } fn build(&self) -> Bins { @@ -349,21 +376,34 @@ where { type Elem = T; - /// **Panics** if `IQR==0`, the array is constant, or `a.len()==0`. - fn from_array(a: &ArrayBase) -> Self + /// Returns `Err(BinsBuildError::Strategy)` if neither the `Sturges` nor the `FreedmanDiaconis` strategy can be built (e.g. the array is constant). + /// Returns `Err(BinsBuildError::EmptyInput)` if `a.len()==0`. + /// Returns `Ok(Self)` otherwise.
+ fn from_array(a: &ArrayBase) -> Result where S: Data, { let fd_builder = FreedmanDiaconis::from_array(&a); let sturges_builder = Sturges::from_array(&a); - let builder = { - if fd_builder.bin_width() > sturges_builder.bin_width() { - SturgesOrFD::Sturges(sturges_builder) - } else { - SturgesOrFD::FreedmanDiaconis(fd_builder) + match (fd_builder, sturges_builder) { + (Err(_), Ok(sturges_builder)) => { + let builder = SturgesOrFD::Sturges(sturges_builder); + Ok(Self { builder }) + } + (Ok(fd_builder), Err(_)) => { + let builder = SturgesOrFD::FreedmanDiaconis(fd_builder); + Ok(Self { builder }) + } + (Ok(fd_builder), Ok(sturges_builder)) => { + let builder = if fd_builder.bin_width() > sturges_builder.bin_width() { + SturgesOrFD::Sturges(sturges_builder) + } else { + SturgesOrFD::FreedmanDiaconis(fd_builder) + }; + Ok(Self { builder }) } - }; - Self { builder } + (Err(err), Err(_)) => Err(err), + } } fn build(&self) -> Bins { @@ -416,10 +456,14 @@ where mod equispaced_tests { use super::*; - #[should_panic] #[test] fn bin_width_has_to_be_positive() { - EquiSpaced::new(0, 0, 200); + assert!(EquiSpaced::new(0, 0, 200).is_err()); + } + + #[test] + fn min_has_to_be_strictly_smaller_than_max() { + assert!(EquiSpaced::new(10, 0, 0).is_err()); } } @@ -428,16 +472,18 @@ mod sqrt_tests { use super::*; use ndarray::array; - #[should_panic] #[test] fn constant_array_are_bad() { - Sqrt::from_array(&array![1, 1, 1, 1, 1, 1, 1]); + assert!(Sqrt::from_array(&array![1, 1, 1, 1, 1, 1, 1]) + .unwrap_err() + .is_strategy()); } - #[should_panic] #[test] - fn empty_arrays_cause_panic() { - let _: Sqrt = Sqrt::from_array(&array![]); + fn empty_arrays_are_bad() { + assert!(Sqrt::::from_array(&array![]) + .unwrap_err() + .is_empty_input()); } } @@ -446,16 +492,18 @@ mod rice_tests { use super::*; use ndarray::array; - #[should_panic] #[test] fn constant_array_are_bad() { - Rice::from_array(&array![1, 1, 1, 1, 1, 1, 1]); + assert!(Rice::from_array(&array![1, 1, 1, 1, 1, 1, 1]) + 
.unwrap_err() + .is_strategy()); } - #[should_panic] #[test] - fn empty_arrays_cause_panic() { - let _: Rice = Rice::from_array(&array![]); + fn empty_arrays_are_bad() { + assert!(Rice::::from_array(&array![]) + .unwrap_err() + .is_empty_input()); } } @@ -464,16 +512,18 @@ mod sturges_tests { use super::*; use ndarray::array; - #[should_panic] #[test] fn constant_array_are_bad() { - Sturges::from_array(&array![1, 1, 1, 1, 1, 1, 1]); + assert!(Sturges::from_array(&array![1, 1, 1, 1, 1, 1, 1]) + .unwrap_err() + .is_strategy()); } - #[should_panic] #[test] - fn empty_arrays_cause_panic() { - let _: Sturges = Sturges::from_array(&array![]); + fn empty_arrays_are_bad() { + assert!(Sturges::::from_array(&array![]) + .unwrap_err() + .is_empty_input()); } } @@ -482,22 +532,27 @@ mod fd_tests { use super::*; use ndarray::array; - #[should_panic] #[test] fn constant_array_are_bad() { - FreedmanDiaconis::from_array(&array![1, 1, 1, 1, 1, 1, 1]); + assert!(FreedmanDiaconis::from_array(&array![1, 1, 1, 1, 1, 1, 1]) + .unwrap_err() + .is_strategy()); } - #[should_panic] #[test] - fn zero_iqr_causes_panic() { - FreedmanDiaconis::from_array(&array![-20, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 20]); + fn zero_iqr_is_bad() { + assert!( + FreedmanDiaconis::from_array(&array![-20, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 20]) + .unwrap_err() + .is_strategy() + ); } - #[should_panic] #[test] - fn empty_arrays_cause_panic() { - let _: FreedmanDiaconis = FreedmanDiaconis::from_array(&array![]); + fn empty_arrays_are_bad() { + assert!(FreedmanDiaconis::::from_array(&array![]) + .unwrap_err() + .is_empty_input()); } } @@ -506,21 +561,22 @@ mod auto_tests { use super::*; use ndarray::array; - #[should_panic] #[test] fn constant_array_are_bad() { - Auto::from_array(&array![1, 1, 1, 1, 1, 1, 1]); + assert!(Auto::from_array(&array![1, 1, 1, 1, 1, 1, 1]) + .unwrap_err() + .is_strategy()); } - #[should_panic] #[test] - fn zero_iqr_causes_panic() { - Auto::from_array(&array![-20, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 20]); + fn zero_iqr_is_handled_by_sturged() { + assert!(Auto::from_array(&array![-20, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 20]).is_ok()); } - #[should_panic] #[test] - fn empty_arrays_cause_panic() { - let _: Auto = Auto::from_array(&array![]); + fn empty_arrays_are_bad() { + assert!(Auto::::from_array(&array![]) + .unwrap_err() + .is_empty_input()); } } diff --git a/src/quantile.rs b/src/quantile.rs index 1b4b4fd6..626b27f9 100644 --- a/src/quantile.rs +++ b/src/quantile.rs @@ -1,3 +1,4 @@ +use crate::errors::{EmptyInput, MinMaxError, MinMaxError::UndefinedOrder}; use interpolate::Interpolate; use ndarray::prelude::*; use ndarray::{s, Data, DataMut, RemoveAxis}; @@ -184,11 +185,11 @@ where { /// Finds the index of the minimum value of the array. /// - /// Returns `None` if any of the pairwise orderings tested by the function - /// are undefined. (For example, this occurs if there are any - /// floating-point NaN values in the array.) + /// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise + /// orderings tested by the function are undefined. (For example, this + /// occurs if there are any floating-point NaN values in the array.) /// - /// Returns `None` if the array is empty. + /// Returns `Err(MinMaxError::EmptyInput)` if the array is empty. /// /// Even if there are multiple (equal) elements that are minima, only one /// index is returned. (Which one is returned is unspecified and may depend @@ -205,9 +206,9 @@ where /// /// let a = array![[1., 3., 5.], /// [2., 0., 6.]]; - /// assert_eq!(a.argmin(), Some((1, 1))); + /// assert_eq!(a.argmin(), Ok((1, 1))); /// ``` - fn argmin(&self) -> Option + fn argmin(&self) -> Result where A: PartialOrd; @@ -240,16 +241,16 @@ where /// Finds the elementwise minimum of the array. /// - /// Returns `None` if any of the pairwise orderings tested by the function - /// are undefined. (For example, this occurs if there are any - /// floating-point NaN values in the array.) 
+ /// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise + /// orderings tested by the function are undefined. (For example, this + /// occurs if there are any floating-point NaN values in the array.) /// - /// Additionally, returns `None` if the array is empty. + /// Returns `Err(MinMaxError::EmptyInput)` if the array is empty. /// /// Even if there are multiple (equal) elements that are minima, only one /// is returned. (Which one is returned is unspecified and may depend on /// the memory layout of the array.) - fn min(&self) -> Option<&A> + fn min(&self) -> Result<&A, MinMaxError> where A: PartialOrd; @@ -269,11 +270,11 @@ where /// Finds the index of the maximum value of the array. /// - /// Returns `None` if any of the pairwise orderings tested by the function - /// are undefined. (For example, this occurs if there are any - /// floating-point NaN values in the array.) + /// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise + /// orderings tested by the function are undefined. (For example, this + /// occurs if there are any floating-point NaN values in the array.) /// - /// Returns `None` if the array is empty. + /// Returns `Err(MinMaxError::EmptyInput)` if the array is empty. /// /// Even if there are multiple (equal) elements that are maxima, only one /// index is returned. (Which one is returned is unspecified and may depend @@ -290,9 +291,9 @@ where /// /// let a = array![[1., 3., 7.], /// [2., 5., 6.]]; - /// assert_eq!(a.argmax(), Some((0, 2))); + /// assert_eq!(a.argmax(), Ok((0, 2))); /// ``` - fn argmax(&self) -> Option + fn argmax(&self) -> Result where A: PartialOrd; @@ -325,16 +326,16 @@ where /// Finds the elementwise maximum of the array. /// - /// Returns `None` if any of the pairwise orderings tested by the function - /// are undefined. (For example, this occurs if there are any - /// floating-point NaN values in the array.) 
+ /// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise + /// orderings tested by the function are undefined. (For example, this + /// occurs if there are any floating-point NaN values in the array.) /// - /// Additionally, returns `None` if the array is empty. + /// Returns `Err(MinMaxError::EmptyInput)` if the array is empty. /// /// Even if there are multiple (equal) elements that are maxima, only one /// is returned. (Which one is returned is unspecified and may depend on /// the memory layout of the array.) - fn max(&self) -> Option<&A> + fn max(&self) -> Result<&A, MinMaxError> where A: PartialOrd; @@ -406,21 +407,21 @@ where S: Data, D: Dimension, { - fn argmin(&self) -> Option + fn argmin(&self) -> Result where A: PartialOrd, { - let mut current_min = self.first()?; + let mut current_min = self.first().ok_or(EmptyInput)?; let mut current_pattern_min = D::zeros(self.ndim()).into_pattern(); for (pattern, elem) in self.indexed_iter() { - if elem.partial_cmp(current_min)? == cmp::Ordering::Less { + if elem.partial_cmp(current_min).ok_or(UndefinedOrder)? == cmp::Ordering::Less { current_pattern_min = pattern; current_min = elem } } - Some(current_pattern_min) + Ok(current_pattern_min) } fn argmin_skipnan(&self) -> Option @@ -445,14 +446,17 @@ where } } - fn min(&self) -> Option<&A> + fn min(&self) -> Result<&A, MinMaxError> where A: PartialOrd, { - let first = self.first()?; - self.fold(Some(first), |acc, elem| match elem.partial_cmp(acc?)? { - cmp::Ordering::Less => Some(elem), - _ => acc, + let first = self.first().ok_or(EmptyInput)?; + self.fold(Ok(first), |acc, elem| { + let acc = acc?; + match elem.partial_cmp(acc).ok_or(UndefinedOrder)?
{ + cmp::Ordering::Less => Ok(elem), + _ => Ok(acc), + } }) } @@ -470,21 +474,21 @@ where })) } - fn argmax(&self) -> Option + fn argmax(&self) -> Result where A: PartialOrd, { - let mut current_max = self.first()?; + let mut current_max = self.first().ok_or(EmptyInput)?; let mut current_pattern_max = D::zeros(self.ndim()).into_pattern(); for (pattern, elem) in self.indexed_iter() { - if elem.partial_cmp(current_max)? == cmp::Ordering::Greater { + if elem.partial_cmp(current_max).ok_or(UndefinedOrder)? == cmp::Ordering::Greater { current_pattern_max = pattern; current_max = elem } } - Some(current_pattern_max) + Ok(current_pattern_max) } fn argmax_skipnan(&self) -> Option @@ -509,14 +513,17 @@ where } } - fn max(&self) -> Option<&A> + fn max(&self) -> Result<&A, MinMaxError> where A: PartialOrd, { - let first = self.first()?; - self.fold(Some(first), |acc, elem| match elem.partial_cmp(acc?)? { - cmp::Ordering::Greater => Some(elem), - _ => acc, + let first = self.first().ok_or(EmptyInput)?; + self.fold(Ok(first), |acc, elem| { + let acc = acc?; + match elem.partial_cmp(acc).ok_or(UndefinedOrder)? { + cmp::Ordering::Greater => Ok(elem), + _ => Ok(acc), + } }) } @@ -619,10 +626,10 @@ where /// - worst case: O(`m`^2); /// where `m` is the number of elements in the array. /// - /// Returns `None` if the array is empty. + /// Returns `Err(EmptyInput)` if the array is empty. /// /// **Panics** if `q` is not between `0.` and `1.` (inclusive). 
- fn quantile_mut(&mut self, q: f64) -> Option + fn quantile_mut(&mut self, q: f64) -> Result where A: Ord + Clone, S: DataMut, @@ -633,16 +640,16 @@ impl Quantile1dExt for ArrayBase where S: Data, { - fn quantile_mut(&mut self, q: f64) -> Option + fn quantile_mut(&mut self, q: f64) -> Result where A: Ord + Clone, S: DataMut, I: Interpolate, { if self.is_empty() { - None + Err(EmptyInput) } else { - Some(self.quantile_axis_mut::(Axis(0), q).into_scalar()) + Ok(self.quantile_axis_mut::(Axis(0), q).into_scalar()) } } } diff --git a/src/summary_statistics/means.rs b/src/summary_statistics/means.rs index f1059efd..a2fed054 100644 --- a/src/summary_statistics/means.rs +++ b/src/summary_statistics/means.rs @@ -1,4 +1,5 @@ use super::SummaryStatisticsExt; +use crate::errors::EmptyInput; use ndarray::{ArrayBase, Data, Dimension}; use num_traits::{Float, FromPrimitive, Zero}; use std::ops::{Add, Div}; @@ -8,28 +9,28 @@ where S: Data, D: Dimension, { - fn mean(&self) -> Option + fn mean(&self) -> Result where A: Clone + FromPrimitive + Add + Div + Zero, { let n_elements = self.len(); if n_elements == 0 { - None + Err(EmptyInput) } else { let n_elements = A::from_usize(n_elements) .expect("Converting number of elements to `A` must not fail."); - Some(self.sum() / n_elements) + Ok(self.sum() / n_elements) } } - fn harmonic_mean(&self) -> Option + fn harmonic_mean(&self) -> Result where A: Float + FromPrimitive, { self.map(|x| x.recip()).mean().map(|x| x.recip()) } - fn geometric_mean(&self) -> Option + fn geometric_mean(&self) -> Result where A: Float + FromPrimitive, { @@ -40,6 +41,7 @@ where #[cfg(test)] mod tests { use super::SummaryStatisticsExt; + use crate::errors::EmptyInput; use approx::abs_diff_eq; use ndarray::{array, Array1}; use noisy_float::types::N64; @@ -56,17 +58,17 @@ mod tests { #[test] fn test_means_with_empty_array_of_floats() { let a: Array1 = array![]; - assert!(a.mean().is_none()); - assert!(a.harmonic_mean().is_none()); - 
assert!(a.geometric_mean().is_none()); + assert_eq!(a.mean(), Err(EmptyInput)); + assert_eq!(a.harmonic_mean(), Err(EmptyInput)); + assert_eq!(a.geometric_mean(), Err(EmptyInput)); } #[test] fn test_means_with_empty_array_of_noisy_floats() { let a: Array1 = array![]; - assert!(a.mean().is_none()); - assert!(a.harmonic_mean().is_none()); - assert!(a.geometric_mean().is_none()); + assert_eq!(a.mean(), Err(EmptyInput)); + assert_eq!(a.harmonic_mean(), Err(EmptyInput)); + assert_eq!(a.geometric_mean(), Err(EmptyInput)); } #[test] diff --git a/src/summary_statistics/mod.rs b/src/summary_statistics/mod.rs index 6aca865f..30d20b89 100644 --- a/src/summary_statistics/mod.rs +++ b/src/summary_statistics/mod.rs @@ -1,4 +1,5 @@ //! Summary statistics (e.g. mean, variance, etc.). +use crate::errors::EmptyInput; use ndarray::{Data, Dimension}; use num_traits::{Float, FromPrimitive, Zero}; use std::ops::{Add, Div}; @@ -18,12 +19,12 @@ where /// n i=1 /// ``` /// - /// If the array is empty, `None` is returned. + /// If the array is empty, `Err(EmptyInput)` is returned. /// /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. /// /// [`arithmetic mean`]: https://en.wikipedia.org/wiki/Arithmetic_mean - fn mean(&self) -> Option + fn mean(&self) -> Result where A: Clone + FromPrimitive + Add + Div + Zero; @@ -35,12 +36,12 @@ where /// ⎝i=1 ⎠ /// ``` /// - /// If the array is empty, `None` is returned. + /// If the array is empty, `Err(EmptyInput)` is returned. /// /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. /// /// [`harmonic mean`]: https://en.wikipedia.org/wiki/Harmonic_mean - fn harmonic_mean(&self) -> Option + fn harmonic_mean(&self) -> Result where A: Float + FromPrimitive; @@ -52,12 +53,12 @@ where /// ⎝i=1 ⎠ /// ``` /// - /// If the array is empty, `None` is returned. + /// If the array is empty, `Err(EmptyInput)` is returned. 
/// /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. /// /// [`geometric mean`]: https://en.wikipedia.org/wiki/Geometric_mean - fn geometric_mean(&self) -> Option + fn geometric_mean(&self) -> Result where A: Float + FromPrimitive; } diff --git a/tests/quantile.rs b/tests/quantile.rs index 3e5ba53b..05f1c7a3 100644 --- a/tests/quantile.rs +++ b/tests/quantile.rs @@ -5,6 +5,7 @@ extern crate quickcheck; use ndarray::prelude::*; use ndarray_stats::{ + errors::MinMaxError, interpolate::{Higher, Linear, Lower, Midpoint, Nearest}, Quantile1dExt, QuantileExt, }; @@ -13,22 +14,22 @@ use quickcheck::quickcheck; #[test] fn test_argmin() { let a = array![[1, 5, 3], [2, 0, 6]]; - assert_eq!(a.argmin(), Some((1, 1))); + assert_eq!(a.argmin(), Ok((1, 1))); let a = array![[1., 5., 3.], [2., 0., 6.]]; - assert_eq!(a.argmin(), Some((1, 1))); + assert_eq!(a.argmin(), Ok((1, 1))); let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]]; - assert_eq!(a.argmin(), None); + assert_eq!(a.argmin(), Err(MinMaxError::UndefinedOrder)); let a: Array2 = array![[], []]; - assert_eq!(a.argmin(), None); + assert_eq!(a.argmin(), Err(MinMaxError::EmptyInput)); } quickcheck! { fn argmin_matches_min(data: Vec) -> bool { let a = Array1::from(data); - a.argmin().map(|i| a[i]) == a.min().cloned() + a.argmin().map(|i| &a[i]) == a.min() } } @@ -66,13 +67,13 @@ quickcheck! 
{ #[test] fn test_min() { let a = array![[1, 5, 3], [2, 0, 6]]; - assert_eq!(a.min(), Some(&0)); + assert_eq!(a.min(), Ok(&0)); let a = array![[1., 5., 3.], [2., 0., 6.]]; - assert_eq!(a.min(), Some(&0.)); + assert_eq!(a.min(), Ok(&0.)); let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]]; - assert_eq!(a.min(), None); + assert_eq!(a.min(), Err(MinMaxError::UndefinedOrder)); } #[test] @@ -93,22 +94,22 @@ fn test_min_skipnan_all_nan() { #[test] fn test_argmax() { let a = array![[1, 5, 3], [2, 0, 6]]; - assert_eq!(a.argmax(), Some((1, 2))); + assert_eq!(a.argmax(), Ok((1, 2))); let a = array![[1., 5., 3.], [2., 0., 6.]]; - assert_eq!(a.argmax(), Some((1, 2))); + assert_eq!(a.argmax(), Ok((1, 2))); let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]]; - assert_eq!(a.argmax(), None); + assert_eq!(a.argmax(), Err(MinMaxError::UndefinedOrder)); let a: Array2 = array![[], []]; - assert_eq!(a.argmax(), None); + assert_eq!(a.argmax(), Err(MinMaxError::EmptyInput)); } quickcheck! { fn argmax_matches_max(data: Vec) -> bool { let a = Array1::from(data); - a.argmax().map(|i| a[i]) == a.max().cloned() + a.argmax().map(|i| &a[i]) == a.max() } } @@ -149,13 +150,13 @@ quickcheck! { #[test] fn test_max() { let a = array![[1, 5, 7], [2, 0, 6]]; - assert_eq!(a.max(), Some(&7)); + assert_eq!(a.max(), Ok(&7)); let a = array![[1., 5., 7.], [2., 0., 6.]]; - assert_eq!(a.max(), Some(&7.)); + assert_eq!(a.max(), Ok(&7.)); let a = array![[1., 5., 7.], [2., ::std::f64::NAN, 6.]]; - assert_eq!(a.max(), None); + assert_eq!(a.max(), Err(MinMaxError::UndefinedOrder)); } #[test]