Skip to content

Commit

Permalink
Histogram error handling (#25)
Browse files Browse the repository at this point in the history
* Use Option as return type where things might fail

* Test suite aligned with docs

* Equispaced does not panic anymore

* Fixed some tests

* Fixed FD tests

* Fixed wrong condition in IF

* Fixed wrong test

* Added new test for EquiSpaced and fixed old one

* Fixed doc tests

* Fix docs.

* Fix docs.

* Fix docs.

* Fmt

* Create StrategyError

* Fmt

* Return Result. Fix Equispaced, Sqrt and Rice

* Fix Rice

* Fixed Sturges

* Fix strategies

* Fix match

* Tests compile

* Fix assertion

* Fmt

* Add more error types

* Rename StrategyError to BinsBuildError

* Make GridBuilder::from_array return Result

This is nice because it doesn't lose information. (Returning an
`Option` combines the two error variants into a single case.)

* Make BinsBuildError enum non-exhaustive

Once the `#[non_exhaustive]` attribute is stable, we should replace
the hidden enum variant with that attribute on the enum.

* Use lazy OR operator.

* Use lazy OR operator.
  • Loading branch information
LukeMathWalker committed Apr 2, 2019
1 parent d838ee7 commit 86e5ca4
Show file tree
Hide file tree
Showing 10 changed files with 423 additions and 211 deletions.
97 changes: 48 additions & 49 deletions src/entropy.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Information theory (e.g. entropy, KL divergence, etc.).
use crate::errors::ShapeMismatch;
use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
use num_traits::Float;

Expand All @@ -19,7 +19,7 @@ where
/// i=1
/// ```
///
/// If the array is empty, `None` is returned.
/// If the array is empty, `Err(EmptyInput)` is returned.
///
/// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
///
Expand All @@ -38,7 +38,7 @@ where
///
/// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
fn entropy(&self) -> Option<A>
fn entropy(&self) -> Result<A, EmptyInput>
where
A: Float;

Expand All @@ -53,8 +53,9 @@ where
/// i=1
/// ```
///
/// If the arrays are empty, Ok(`None`) is returned.
/// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
/// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
/// If the array shapes are not identical,
/// `Err(MultiInputError::ShapeMismatch)` is returned.
///
/// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
/// *ln(qᵢ/pᵢ)* is a panic cause for `A`.
Expand All @@ -73,7 +74,7 @@ where
///
/// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float;
Expand All @@ -89,8 +90,9 @@ where
/// i=1
/// ```
///
/// If the arrays are empty, Ok(`None`) is returned.
/// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
/// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
/// If the array shapes are not identical,
/// `Err(MultiInputError::ShapeMismatch)` is returned.
///
/// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
/// is a panic cause for `A`.
Expand All @@ -114,7 +116,7 @@ where
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
/// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
/// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float;
Expand All @@ -125,14 +127,14 @@ where
S: Data<Elem = A>,
D: Dimension,
{
fn entropy(&self) -> Option<A>
fn entropy(&self) -> Result<A, EmptyInput>
where
A: Float,
{
if self.len() == 0 {
None
Err(EmptyInput)
} else {
let entropy = self
let entropy = -self
.mapv(|x| {
if x == A::zero() {
A::zero()
Expand All @@ -141,23 +143,24 @@ where
}
})
.sum();
Some(-entropy)
Ok(entropy)
}
}

fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
A: Float,
S2: Data<Elem = A>,
{
if self.len() == 0 {
return Ok(None);
return Err(MultiInputError::EmptyInput);
}
if self.shape() != q.shape() {
return Err(ShapeMismatch {
first_shape: self.shape().to_vec(),
second_shape: q.shape().to_vec(),
});
}
.into());
}

let mut temp = Array::zeros(self.raw_dim());
Expand All @@ -174,22 +177,23 @@ where
}
});
let kl_divergence = -temp.sum();
Ok(Some(kl_divergence))
Ok(kl_divergence)
}

fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float,
{
if self.len() == 0 {
return Ok(None);
return Err(MultiInputError::EmptyInput);
}
if self.shape() != q.shape() {
return Err(ShapeMismatch {
first_shape: self.shape().to_vec(),
second_shape: q.shape().to_vec(),
});
}
.into());
}

let mut temp = Array::zeros(self.raw_dim());
Expand All @@ -206,15 +210,15 @@ where
}
});
let cross_entropy = -temp.sum();
Ok(Some(cross_entropy))
Ok(cross_entropy)
}
}

#[cfg(test)]
mod tests {
use super::EntropyExt;
use approx::assert_abs_diff_eq;
use errors::ShapeMismatch;
use errors::{EmptyInput, MultiInputError};
use ndarray::{array, Array1};
use noisy_float::types::n64;
use std::f64;
Expand All @@ -228,7 +232,7 @@ mod tests {
#[test]
fn test_entropy_with_empty_array_of_floats() {
let a: Array1<f64> = array![];
assert!(a.entropy().is_none());
assert_eq!(a.entropy(), Err(EmptyInput));
}

#[test]
Expand All @@ -251,13 +255,13 @@ mod tests {
}

#[test]
fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> {
let a = array![f64::NAN, 1.];
let b = array![2., 1.];
assert!(a.cross_entropy(&b)?.unwrap().is_nan());
assert!(b.cross_entropy(&a)?.unwrap().is_nan());
assert!(a.kl_divergence(&b)?.unwrap().is_nan());
assert!(b.kl_divergence(&a)?.unwrap().is_nan());
assert!(a.cross_entropy(&b)?.is_nan());
assert!(b.cross_entropy(&a)?.is_nan());
assert!(a.kl_divergence(&b)?.is_nan());
assert!(b.kl_divergence(&a)?.is_nan());
Ok(())
}

Expand All @@ -284,20 +288,19 @@ mod tests {
}

#[test]
fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_empty_array_of_floats() {
let p: Array1<f64> = array![];
let q: Array1<f64> = array![];
assert!(p.cross_entropy(&q)?.is_none());
assert!(p.kl_divergence(&q)?.is_none());
Ok(())
assert!(p.cross_entropy(&q).unwrap_err().is_empty_input());
assert!(p.kl_divergence(&q).unwrap_err().is_empty_input());
}

#[test]
fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> {
let p = array![1.];
let q = array![-1.];
let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap();
let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap();
let cross_entropy: f64 = p.cross_entropy(&q)?;
let kl_divergence: f64 = p.kl_divergence(&q)?;
assert!(cross_entropy.is_nan());
assert!(kl_divergence.is_nan());
Ok(())
Expand All @@ -320,26 +323,26 @@ mod tests {
}

#[test]
fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> {
let p = array![0., 0.];
let q = array![0., 0.5];
assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.);
assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.);
assert_eq!(p.cross_entropy(&q)?, 0.);
assert_eq!(p.kl_divergence(&q)?, 0.);
Ok(())
}

#[test]
fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
) -> Result<(), ShapeMismatch> {
) -> Result<(), MultiInputError> {
let p = array![0.5, 0.5];
let mut q = array![0.5, 0.];
assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY);
assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY);
assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY);
Ok(())
}

#[test]
fn test_cross_entropy() -> Result<(), ShapeMismatch> {
fn test_cross_entropy() -> Result<(), MultiInputError> {
// Arrays of probability values - normalized and positive.
let p: Array1<f64> = array![
0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
Expand All @@ -356,16 +359,12 @@ mod tests {
// Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
let expected_cross_entropy = 3.385347705020779;

assert_abs_diff_eq!(
p.cross_entropy(&q)?.unwrap(),
expected_cross_entropy,
epsilon = 1e-6
);
assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6);
Ok(())
}

#[test]
fn test_kl() -> Result<(), ShapeMismatch> {
fn test_kl() -> Result<(), MultiInputError> {
// Arrays of probability values - normalized and positive.
let p: Array1<f64> = array![
0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
Expand All @@ -390,7 +389,7 @@ mod tests {
// Computed using scipy.stats.entropy(p, q)
let expected_kl = 0.3555862567800096;

assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6);
assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6);
Ok(())
}
}
92 changes: 91 additions & 1 deletion src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,50 @@
use std::error::Error;
use std::fmt;

#[derive(Debug)]
/// Error signalling that the input array contained no elements.
///
/// Its `Display` output is the fixed message `Empty input.`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EmptyInput;

impl fmt::Display for EmptyInput {
    /// Writes the fixed message `Empty input.` to `formatter`.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("Empty input.")
    }
}

// The default `Error` methods are used as-is; nothing is overridden.
impl Error for EmptyInput {}

/// Failure modes of a minimum/maximum computation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MinMaxError {
    /// The input was empty.
    EmptyInput,
    /// A tested pair of values had no defined ordering.
    UndefinedOrder,
}

impl fmt::Display for MinMaxError {
    /// Writes the human-readable message associated with the variant.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        // Pick the message first, then emit it with a single write.
        let message = match *self {
            MinMaxError::EmptyInput => "Empty input.",
            MinMaxError::UndefinedOrder => "Undefined ordering between a tested pair of values.",
        };
        formatter.write_str(message)
    }
}

// The default `Error` methods are used as-is; nothing is overridden.
impl Error for MinMaxError {}

// Lets `?` turn an `EmptyInput` into the matching `MinMaxError` variant.
impl From<EmptyInput> for MinMaxError {
    /// Maps any `EmptyInput` to `MinMaxError::EmptyInput`.
    fn from(_empty: EmptyInput) -> Self {
        MinMaxError::EmptyInput
    }
}

/// An error used by methods and functions that take two arrays as arguments and
/// expect them to have exactly the same shape
/// (e.g. `ShapeMismatch` is returned when `a.shape() == b.shape()` evaluates to `false`).
#[derive(Clone, Debug)]
pub struct ShapeMismatch {
pub first_shape: Vec<usize>,
pub second_shape: Vec<usize>,
Expand All @@ -22,3 +62,53 @@ impl fmt::Display for ShapeMismatch {
}

impl Error for ShapeMismatch {}

/// Error returned by methods that operate on several non-empty arrays.
#[derive(Clone, Debug)]
pub enum MultiInputError {
    /// At least one of the arrays was empty.
    EmptyInput,
    /// The arrays' shapes were not identical; carries the mismatch details.
    ShapeMismatch(ShapeMismatch),
}

impl MultiInputError {
    /// Returns `true` if `self` is the `EmptyInput` variant, `false` otherwise.
    pub fn is_empty_input(&self) -> bool {
        // Every variant is listed so that adding one forces a decision here.
        match *self {
            MultiInputError::EmptyInput => true,
            MultiInputError::ShapeMismatch(_) => false,
        }
    }

    /// Returns `true` if `self` is the `ShapeMismatch` variant, `false` otherwise.
    pub fn is_shape_mismatch(&self) -> bool {
        // Every variant is listed so that adding one forces a decision here.
        match *self {
            MultiInputError::ShapeMismatch(_) => true,
            MultiInputError::EmptyInput => false,
        }
    }
}

impl fmt::Display for MultiInputError {
    /// Writes `Empty input.` for the `EmptyInput` variant, or
    /// `Shape mismatch: ` followed by the wrapped error's own `Display`
    /// output for the `ShapeMismatch` variant.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            MultiInputError::EmptyInput => formatter.write_str("Empty input."),
            MultiInputError::ShapeMismatch(ref mismatch) => {
                write!(formatter, "Shape mismatch: {}", mismatch)
            }
        }
    }
}

// The default `Error` methods are used as-is; the wrapped `ShapeMismatch`
// is surfaced through the `Display` text above.
impl Error for MultiInputError {}

// Lets `?` turn an `EmptyInput` into the matching `MultiInputError` variant.
impl From<EmptyInput> for MultiInputError {
    /// Maps any `EmptyInput` to `MultiInputError::EmptyInput`.
    fn from(_empty: EmptyInput) -> Self {
        MultiInputError::EmptyInput
    }
}

// Enables `.into()` / `?` on a `ShapeMismatch` where a `MultiInputError` is expected.
impl From<ShapeMismatch> for MultiInputError {
    /// Wraps the mismatch details in `MultiInputError::ShapeMismatch`.
    fn from(mismatch: ShapeMismatch) -> Self {
        MultiInputError::ShapeMismatch(mismatch)
    }
}
Loading

0 comments on commit 86e5ca4

Please sign in to comment.