Histogram error handling #25

Merged
Changes from all commits (32 commits)
bfb0db7  Use Option as return type where things might fail (LukeMathWalker, Jan 29, 2019)
4057a72  Test suite aligned with docs (LukeMathWalker, Jan 29, 2019)
fa1eb34  Equispaced does not panic anymore (Jan 30, 2019)
669a33f  Fixed some tests (Jan 30, 2019)
0e5eb6b  Fixed FD tests (Jan 30, 2019)
04cf0a7  Fixed wrong condition in IF (Jan 30, 2019)
4f429bc  Fixed wrong test (Jan 30, 2019)
4e74c48  Added new test for EquiSpaced and fixed old one (Jan 30, 2019)
56b7e45  Fixed doc tests (Jan 30, 2019)
12906fd  Fix docs. (LukeMathWalker, Feb 2, 2019)
9d1862f  Fix docs. (LukeMathWalker, Feb 2, 2019)
64789d6  Fix docs. (LukeMathWalker, Feb 2, 2019)
fe150d1  Fmt (LukeMathWalker, Mar 26, 2019)
facd4c4  Merge master (LukeMathWalker, Mar 26, 2019)
b28c35a  Create StrategyError (LukeMathWalker, Mar 26, 2019)
c06f382  Fmt (LukeMathWalker, Mar 26, 2019)
4a24f5a  Return Result. Fix Equispaced, Sqrt and Rice (LukeMathWalker, Mar 26, 2019)
f708a17  Fix Rice (LukeMathWalker, Mar 26, 2019)
58788db  Fixed Sturges (LukeMathWalker, Mar 26, 2019)
3014f77  Fix strategies (LukeMathWalker, Mar 26, 2019)
17e5efc  Fix match (LukeMathWalker, Mar 26, 2019)
63abed5  Tests compile (LukeMathWalker, Mar 26, 2019)
4a4b489  Fix assertion (LukeMathWalker, Mar 26, 2019)
f692887  Fmt (LukeMathWalker, Mar 26, 2019)
a8ad4b1  Add more error types (jturner314, Mar 31, 2019)
29f56f3  Rename StrategyError to BinsBuildError (jturner314, Apr 1, 2019)
bca2dc9  Make GridBuilder::from_array return Result (jturner314, Apr 1, 2019)
f41b317  Make BinsBuildError enum non-exhaustive (jturner314, Apr 1, 2019)
308e0e7  Merge pull request #4 from jturner314/histogram-error-handling (LukeMathWalker, Apr 1, 2019)
c280c6b  Use lazy OR operator. (LukeMathWalker, Apr 1, 2019)
6481509  Merge remote-tracking branch 'origin/histogram-error-handling' into h… (LukeMathWalker, Apr 1, 2019)
701842d  Use lazy OR operator. (LukeMathWalker, Apr 1, 2019)
97 changes: 48 additions & 49 deletions src/entropy.rs
@@ -1,5 +1,5 @@
//! Information theory (e.g. entropy, KL divergence, etc.).
-use crate::errors::ShapeMismatch;
+use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
use num_traits::Float;

@@ -19,7 +19,7 @@ where
/// i=1
/// ```
///
-    /// If the array is empty, `None` is returned.
+    /// If the array is empty, `Err(EmptyInput)` is returned.
///
/// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
///
@@ -38,7 +38,7 @@ where
///
/// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
-    fn entropy(&self) -> Option<A>
+    fn entropy(&self) -> Result<A, EmptyInput>
where
A: Float;

@@ -53,8 +53,9 @@ where
/// i=1
/// ```
///
-    /// If the arrays are empty, Ok(`None`) is returned.
-    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
+    /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
+    /// If the array shapes are not identical,
+    /// `Err(MultiInputError::ShapeMismatch)` is returned.
///
/// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
/// *ln(qᵢ/pᵢ)* is a panic cause for `A`.
@@ -73,7 +74,7 @@ where
///
/// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
-    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float;
@@ -89,8 +90,9 @@ where
/// i=1
/// ```
///
-    /// If the arrays are empty, Ok(`None`) is returned.
-    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
+    /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
+    /// If the array shapes are not identical,
+    /// `Err(MultiInputError::ShapeMismatch)` is returned.
///
/// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
/// is a panic cause for `A`.
@@ -114,7 +116,7 @@ where
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
/// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
/// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
-    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float;
@@ -125,14 +127,14 @@ where
S: Data<Elem = A>,
D: Dimension,
{
-    fn entropy(&self) -> Option<A>
+    fn entropy(&self) -> Result<A, EmptyInput>
where
A: Float,
{
if self.len() == 0 {
-            None
+            Err(EmptyInput)
} else {
-            let entropy = self
+            let entropy = -self
.mapv(|x| {
if x == A::zero() {
A::zero()
@@ -141,23 +143,24 @@ where
}
})
.sum();
-            Some(-entropy)
+            Ok(entropy)
}
}

-    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
A: Float,
S2: Data<Elem = A>,
{
if self.len() == 0 {
-            return Ok(None);
+            return Err(MultiInputError::EmptyInput);
}
if self.shape() != q.shape() {
return Err(ShapeMismatch {
first_shape: self.shape().to_vec(),
second_shape: q.shape().to_vec(),
-            });
+            }
+            .into());
}

let mut temp = Array::zeros(self.raw_dim());
@@ -174,22 +177,23 @@ where
}
});
let kl_divergence = -temp.sum();
-        Ok(Some(kl_divergence))
+        Ok(kl_divergence)
}

-    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float,
{
if self.len() == 0 {
-            return Ok(None);
+            return Err(MultiInputError::EmptyInput);
}
if self.shape() != q.shape() {
return Err(ShapeMismatch {
first_shape: self.shape().to_vec(),
second_shape: q.shape().to_vec(),
-            });
+            }
+            .into());
}

let mut temp = Array::zeros(self.raw_dim());
@@ -206,15 +210,15 @@ where
}
});
let cross_entropy = -temp.sum();
-        Ok(Some(cross_entropy))
+        Ok(cross_entropy)
}
}

#[cfg(test)]
mod tests {
use super::EntropyExt;
use approx::assert_abs_diff_eq;
-    use errors::ShapeMismatch;
+    use errors::{EmptyInput, MultiInputError};
use ndarray::{array, Array1};
use noisy_float::types::n64;
use std::f64;
@@ -228,7 +232,7 @@ mod tests {
#[test]
fn test_entropy_with_empty_array_of_floats() {
let a: Array1<f64> = array![];
-        assert!(a.entropy().is_none());
+        assert_eq!(a.entropy(), Err(EmptyInput));
}

#[test]
@@ -251,13 +255,13 @@
}

#[test]
-    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> {
let a = array![f64::NAN, 1.];
let b = array![2., 1.];
-        assert!(a.cross_entropy(&b)?.unwrap().is_nan());
-        assert!(b.cross_entropy(&a)?.unwrap().is_nan());
-        assert!(a.kl_divergence(&b)?.unwrap().is_nan());
-        assert!(b.kl_divergence(&a)?.unwrap().is_nan());
+        assert!(a.cross_entropy(&b)?.is_nan());
+        assert!(b.cross_entropy(&a)?.is_nan());
+        assert!(a.kl_divergence(&b)?.is_nan());
+        assert!(b.kl_divergence(&a)?.is_nan());
Ok(())
}

@@ -284,20 +288,19 @@ mod tests {
}

#[test]
-    fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_empty_array_of_floats() {
let p: Array1<f64> = array![];
let q: Array1<f64> = array![];
-        assert!(p.cross_entropy(&q)?.is_none());
-        assert!(p.kl_divergence(&q)?.is_none());
-        Ok(())
+        assert!(p.cross_entropy(&q).unwrap_err().is_empty_input());
+        assert!(p.kl_divergence(&q).unwrap_err().is_empty_input());
}

#[test]
-    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> {
let p = array![1.];
let q = array![-1.];
-        let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap();
-        let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap();
+        let cross_entropy: f64 = p.cross_entropy(&q)?;
+        let kl_divergence: f64 = p.kl_divergence(&q)?;
assert!(cross_entropy.is_nan());
assert!(kl_divergence.is_nan());
Ok(())
@@ -320,26 +323,26 @@ mod tests {
}

#[test]
-    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> {
let p = array![0., 0.];
let q = array![0., 0.5];
-        assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.);
-        assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.);
+        assert_eq!(p.cross_entropy(&q)?, 0.);
+        assert_eq!(p.kl_divergence(&q)?, 0.);
Ok(())
}

#[test]
fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
-    ) -> Result<(), ShapeMismatch> {
+    ) -> Result<(), MultiInputError> {
let p = array![0.5, 0.5];
let mut q = array![0.5, 0.];
-        assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
-        assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY);
+        assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY);
+        assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY);
Ok(())
}

#[test]
-    fn test_cross_entropy() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy() -> Result<(), MultiInputError> {
// Arrays of probability values - normalized and positive.
let p: Array1<f64> = array![
0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
@@ -356,16 +359,12 @@ mod tests {
// Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
let expected_cross_entropy = 3.385347705020779;

-        assert_abs_diff_eq!(
-            p.cross_entropy(&q)?.unwrap(),
-            expected_cross_entropy,
-            epsilon = 1e-6
-        );
+        assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6);
Ok(())
}

#[test]
-    fn test_kl() -> Result<(), ShapeMismatch> {
+    fn test_kl() -> Result<(), MultiInputError> {
// Arrays of probability values - normalized and positive.
let p: Array1<f64> = array![
0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
@@ -390,7 +389,7 @@ mod tests {
// Computed using scipy.stats.entropy(p, q)
let expected_kl = 0.3555862567800096;

-        assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6);
+        assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6);
Ok(())
}
}
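
The net effect of the src/entropy.rs changes on callers: all three methods now report failure through Result instead of Option. Below is a minimal sketch of the reworked API in use; it assumes the crate-root re-export of EntropyExt and the public errors module, neither of which is shown in this diff:

use ndarray::array;
use ndarray_stats::errors::MultiInputError;
use ndarray_stats::EntropyExt;

fn main() -> Result<(), MultiInputError> {
    let p = array![0.25, 0.75];
    let q = array![0.5, 0.5];

    // `entropy` returns `Err(EmptyInput)` for an empty array; the
    // `From<EmptyInput> for MultiInputError` impl lets `?` promote it here.
    let h = p.entropy()?;

    // Both divergences share `MultiInputError`, which covers empty inputs
    // and shape mismatches with a single error type.
    let kl = p.kl_divergence(&q)?;
    let ce = p.cross_entropy(&q)?;

    println!("H(p) = {:.4}, KL(p || q) = {:.4}, H(p, q) = {:.4}", h, kl, ce);
    Ok(())
}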
92 changes: 91 additions & 1 deletion src/errors.rs
@@ -2,10 +2,50 @@
use std::error::Error;
use std::fmt;

-#[derive(Debug)]
+/// An error that indicates that the input array was empty.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct EmptyInput;
+
+impl fmt::Display for EmptyInput {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "Empty input.")
+    }
+}
+
+impl Error for EmptyInput {}
+
+/// An error computing a minimum/maximum value.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum MinMaxError {
+    /// The input was empty.
+    EmptyInput,
+    /// The ordering between a tested pair of values was undefined.
+    UndefinedOrder,
+}
+
+impl fmt::Display for MinMaxError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            MinMaxError::EmptyInput => write!(f, "Empty input."),
+            MinMaxError::UndefinedOrder => {
+                write!(f, "Undefined ordering between a tested pair of values.")
+            }
+        }
+    }
+}
+
+impl Error for MinMaxError {}
+
+impl From<EmptyInput> for MinMaxError {
+    fn from(_: EmptyInput) -> MinMaxError {
+        MinMaxError::EmptyInput
+    }
+}
+
+/// An error used by methods and functions that take two arrays as argument and
+/// expect them to have exactly the same shape
+/// (e.g. `ShapeMismatch` is raised when `a.shape() == b.shape()` evaluates to `false`).
+#[derive(Clone, Debug)]
pub struct ShapeMismatch {
pub first_shape: Vec<usize>,
pub second_shape: Vec<usize>,
@@ -22,3 +62,53 @@ impl fmt::Display for ShapeMismatch {
}

impl Error for ShapeMismatch {}

+/// An error for methods that take multiple non-empty array inputs.
+#[derive(Clone, Debug)]
+pub enum MultiInputError {
+    /// One or more of the arrays were empty.
+    EmptyInput,
+    /// The arrays did not have the same shape.
+    ShapeMismatch(ShapeMismatch),
+}
+
+impl MultiInputError {
+    /// Returns whether `self` is the `EmptyInput` variant.
+    pub fn is_empty_input(&self) -> bool {
+        match self {
+            MultiInputError::EmptyInput => true,
+            _ => false,
+        }
+    }
+
+    /// Returns whether `self` is the `ShapeMismatch` variant.
+    pub fn is_shape_mismatch(&self) -> bool {
+        match self {
+            MultiInputError::ShapeMismatch(_) => true,
+            _ => false,
+        }
+    }
+}
+
+impl fmt::Display for MultiInputError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            MultiInputError::EmptyInput => write!(f, "Empty input."),
+            MultiInputError::ShapeMismatch(e) => write!(f, "Shape mismatch: {}", e),
+        }
+    }
+}
+
+impl Error for MultiInputError {}
+
+impl From<EmptyInput> for MultiInputError {
+    fn from(_: EmptyInput) -> Self {
+        MultiInputError::EmptyInput
+    }
+}
+
+impl From<ShapeMismatch> for MultiInputError {
+    fn from(err: ShapeMismatch) -> Self {
+        MultiInputError::ShapeMismatch(err)
+    }
+}
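
Taken together, the two From impls above let a call site bubble either failure mode into MultiInputError with `?` or `.into()`, exactly as the src/entropy.rs changes do. A small sketch under that assumption, using a hypothetical validate helper (not part of this diff) to exercise the conversions and the new variant predicates:

use ndarray_stats::errors::{EmptyInput, MultiInputError, ShapeMismatch};

// Hypothetical check mirroring the guards in `kl_divergence` and
// `cross_entropy`: reject empty inputs first, then mismatched shapes.
fn validate(a: &[f64], b: &[f64]) -> Result<(), MultiInputError> {
    if a.is_empty() || b.is_empty() {
        return Err(EmptyInput.into()); // via From<EmptyInput>
    }
    if a.len() != b.len() {
        return Err(ShapeMismatch {
            first_shape: vec![a.len()],
            second_shape: vec![b.len()],
        }
        .into()); // via From<ShapeMismatch>
    }
    Ok(())
}

fn main() {
    let err = validate(&[], &[1.0]).unwrap_err();
    assert!(err.is_empty_input());

    let err = validate(&[1.0], &[1.0, 2.0]).unwrap_err();
    assert!(err.is_shape_mismatch());
}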