Improve error handling #4

Merged
src/entropy.rs (48 additions, 49 deletions)
@@ -1,5 +1,5 @@
//! Information theory (e.g. entropy, KL divergence, etc.).
use crate::errors::ShapeMismatch;
use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
use num_traits::Float;

@@ -19,7 +19,7 @@ where
/// i=1
/// ```
///
/// If the array is empty, `None` is returned.
/// If the array is empty, `Err(EmptyInput)` is returned.
///
/// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
///
@@ -38,7 +38,7 @@ where
///
/// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
fn entropy(&self) -> Option<A>
fn entropy(&self) -> Result<A, EmptyInput>
where
A: Float;

@@ -53,8 +53,9 @@
/// i=1
/// ```
///
/// If the arrays are empty, Ok(`None`) is returned.
/// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
/// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
/// If the array shapes are not identical,
/// `Err(MultiInputError::ShapeMismatch)` is returned.
///
/// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
/// *ln(qᵢ/pᵢ)* is a panic cause for `A`.
@@ -73,7 +74,7 @@ where
///
/// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float;
@@ -89,8 +90,9 @@ where
/// i=1
/// ```
///
/// If the arrays are empty, Ok(`None`) is returned.
/// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
/// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
/// If the array shapes are not identical,
/// `Err(MultiInputError::ShapeMismatch)` is returned.
///
/// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
/// is a panic cause for `A`.
@@ -114,7 +116,7 @@ where
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
/// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
/// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float;
@@ -125,14 +127,14 @@ where
S: Data<Elem = A>,
D: Dimension,
{
fn entropy(&self) -> Option<A>
fn entropy(&self) -> Result<A, EmptyInput>
where
A: Float,
{
if self.len() == 0 {
None
Err(EmptyInput)
} else {
let entropy = self
let entropy = -self
.mapv(|x| {
if x == A::zero() {
A::zero()
Expand All @@ -141,23 +143,24 @@ where
}
})
.sum();
Some(-entropy)
Ok(entropy)
}
}

fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
A: Float,
S2: Data<Elem = A>,
{
if self.len() == 0 {
return Ok(None);
return Err(MultiInputError::EmptyInput);
}
if self.shape() != q.shape() {
return Err(ShapeMismatch {
first_shape: self.shape().to_vec(),
second_shape: q.shape().to_vec(),
});
}
.into());
}

let mut temp = Array::zeros(self.raw_dim());
@@ -174,22 +177,23 @@ where
}
});
let kl_divergence = -temp.sum();
Ok(Some(kl_divergence))
Ok(kl_divergence)
}

fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float,
{
if self.len() == 0 {
return Ok(None);
return Err(MultiInputError::EmptyInput);
}
if self.shape() != q.shape() {
return Err(ShapeMismatch {
first_shape: self.shape().to_vec(),
second_shape: q.shape().to_vec(),
});
}
.into());
}

let mut temp = Array::zeros(self.raw_dim());
@@ -206,15 +210,15 @@ where
}
});
let cross_entropy = -temp.sum();
Ok(Some(cross_entropy))
Ok(cross_entropy)
}
}

#[cfg(test)]
mod tests {
use super::EntropyExt;
use approx::assert_abs_diff_eq;
use errors::ShapeMismatch;
use errors::{EmptyInput, MultiInputError};
use ndarray::{array, Array1};
use noisy_float::types::n64;
use std::f64;
@@ -228,7 +232,7 @@ mod tests {
#[test]
fn test_entropy_with_empty_array_of_floats() {
let a: Array1<f64> = array![];
assert!(a.entropy().is_none());
assert_eq!(a.entropy(), Err(EmptyInput));
}

#[test]
@@ -251,13 +255,13 @@
}

#[test]
fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> {
let a = array![f64::NAN, 1.];
let b = array![2., 1.];
assert!(a.cross_entropy(&b)?.unwrap().is_nan());
assert!(b.cross_entropy(&a)?.unwrap().is_nan());
assert!(a.kl_divergence(&b)?.unwrap().is_nan());
assert!(b.kl_divergence(&a)?.unwrap().is_nan());
assert!(a.cross_entropy(&b)?.is_nan());
assert!(b.cross_entropy(&a)?.is_nan());
assert!(a.kl_divergence(&b)?.is_nan());
assert!(b.kl_divergence(&a)?.is_nan());
Ok(())
}

@@ -284,20 +288,19 @@ mod tests {
}

#[test]
fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_empty_array_of_floats() {
let p: Array1<f64> = array![];
let q: Array1<f64> = array![];
assert!(p.cross_entropy(&q)?.is_none());
assert!(p.kl_divergence(&q)?.is_none());
Ok(())
assert!(p.cross_entropy(&q).unwrap_err().is_empty_input());
assert!(p.kl_divergence(&q).unwrap_err().is_empty_input());
}

#[test]
fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> {
let p = array![1.];
let q = array![-1.];
let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap();
let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap();
let cross_entropy: f64 = p.cross_entropy(&q)?;
let kl_divergence: f64 = p.kl_divergence(&q)?;
assert!(cross_entropy.is_nan());
assert!(kl_divergence.is_nan());
Ok(())
@@ -320,26 +323,26 @@ mod tests {
}

#[test]
fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> {
let p = array![0., 0.];
let q = array![0., 0.5];
assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.);
assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.);
assert_eq!(p.cross_entropy(&q)?, 0.);
assert_eq!(p.kl_divergence(&q)?, 0.);
Ok(())
}

#[test]
fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
) -> Result<(), ShapeMismatch> {
) -> Result<(), MultiInputError> {
let p = array![0.5, 0.5];
let mut q = array![0.5, 0.];
assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY);
assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY);
assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY);
Ok(())
}

#[test]
fn test_cross_entropy() -> Result<(), ShapeMismatch> {
fn test_cross_entropy() -> Result<(), MultiInputError> {
// Arrays of probability values - normalized and positive.
let p: Array1<f64> = array![
0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
@@ -356,16 +359,12 @@
// Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
let expected_cross_entropy = 3.385347705020779;

assert_abs_diff_eq!(
p.cross_entropy(&q)?.unwrap(),
expected_cross_entropy,
epsilon = 1e-6
);
assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6);
Ok(())
}

#[test]
fn test_kl() -> Result<(), ShapeMismatch> {
fn test_kl() -> Result<(), MultiInputError> {
// Arrays of probability values - normalized and positive.
let p: Array1<f64> = array![
0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
@@ -390,7 +389,7 @@ mod tests {
// Computed using scipy.stats.entropy(p, q)
let expected_kl = 0.3555862567800096;

assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6);
assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6);
Ok(())
}
}
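Taken together, the entropy.rs changes replace `Option` in the return types with dedicated error types: `entropy` now reports an empty input through `Err(EmptyInput)`, and the two-sample methods report both failure modes through `MultiInputError`. Below is a minimal usage sketch of the new API, assuming the crate is `ndarray_stats` and that `EntropyExt` and the `errors` module are exposed at the crate root (the probability values are illustrative):

```rust
use ndarray::array;
use ndarray_stats::errors::MultiInputError;
use ndarray_stats::EntropyExt;

fn main() -> Result<(), MultiInputError> {
    // Probability distributions: normalized and non-negative.
    let p = array![0.25, 0.25, 0.25, 0.25];
    let q = array![0.5, 0.25, 0.125, 0.125];

    // `entropy` now returns `Result<A, EmptyInput>` rather than `Option<A>`.
    let h = p.entropy().expect("input is non-empty");
    println!("H(p) = {:.4}", h); // natural-log entropy, ln(4) ≈ 1.3863 for a uniform 4-way distribution

    // The two-sample methods return `Result<A, MultiInputError>`, so `?`
    // propagates both the empty-input and the shape-mismatch cases.
    let kl = p.kl_divergence(&q)?;
    let ce = p.cross_entropy(&q)?;
    println!("KL(p || q) = {:.4}, H(p, q) = {:.4}", kl, ce);
    Ok(())
}
```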
src/errors.rs (91 additions, 1 deletion)
@@ -2,10 +2,50 @@
use std::error::Error;
use std::fmt;

#[derive(Debug)]
/// An error that indicates that the input array was empty.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EmptyInput;

impl fmt::Display for EmptyInput {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Empty input.")
}
}

impl Error for EmptyInput {}

/// An error computing a minimum/maximum value.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MinMaxError {
/// The input was empty.
EmptyInput,
/// The ordering between a tested pair of values was undefined.
UndefinedOrder,
}

impl fmt::Display for MinMaxError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
MinMaxError::EmptyInput => write!(f, "Empty input."),
MinMaxError::UndefinedOrder => {
write!(f, "Undefined ordering between a tested pair of values.")
}
}
}
}

impl Error for MinMaxError {}

impl From<EmptyInput> for MinMaxError {
fn from(_: EmptyInput) -> MinMaxError {
MinMaxError::EmptyInput
}
}

/// An error used by methods and functions that take two arrays as arguments and
/// expect them to have exactly the same shape
/// (e.g. `ShapeMismatch` is returned when `a.shape() == b.shape()` evaluates to `false`).
#[derive(Clone, Debug)]
pub struct ShapeMismatch {
pub first_shape: Vec<usize>,
pub second_shape: Vec<usize>,
@@ -22,3 +62,53 @@ impl fmt::Display for ShapeMismatch {
}

impl Error for ShapeMismatch {}

/// An error for methods that take multiple non-empty array inputs.
#[derive(Clone, Debug)]
pub enum MultiInputError {
/// One or more of the arrays were empty.
EmptyInput,
/// The arrays did not have the same shape.
ShapeMismatch(ShapeMismatch),
}

impl MultiInputError {
/// Returns whether `self` is the `EmptyInput` variant.
pub fn is_empty_input(&self) -> bool {
match self {
MultiInputError::EmptyInput => true,
_ => false,
}
}

/// Returns whether `self` is the `ShapeMismatch` variant.
pub fn is_shape_mismatch(&self) -> bool {
match self {
MultiInputError::ShapeMismatch(_) => true,
_ => false,
}
}
}

impl fmt::Display for MultiInputError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
MultiInputError::EmptyInput => write!(f, "Empty input."),
MultiInputError::ShapeMismatch(e) => write!(f, "Shape mismatch: {}", e),
}
}
}

impl Error for MultiInputError {}

impl From<EmptyInput> for MultiInputError {
fn from(_: EmptyInput) -> Self {
MultiInputError::EmptyInput
}
}

impl From<ShapeMismatch> for MultiInputError {
fn from(err: ShapeMismatch) -> Self {
MultiInputError::ShapeMismatch(err)
}
}
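The two `From` impls at the bottom are what make the `.into()` calls in the entropy.rs hunks above work: a `ShapeMismatch` constructed at the check site converts into the broader `MultiInputError`, and an `EmptyInput` from a single-input method can be widened the same way. A self-contained sketch of that pattern follows; the error types are trimmed copies of the ones in this diff (the `Display` and `Error` impls are omitted), and `validate` is a hypothetical helper, not part of the crate.

```rust
// Trimmed copies of the error types from this diff.
#[derive(Clone, Debug)]
pub struct ShapeMismatch {
    pub first_shape: Vec<usize>,
    pub second_shape: Vec<usize>,
}

#[derive(Clone, Debug)]
pub enum MultiInputError {
    EmptyInput,
    ShapeMismatch(ShapeMismatch),
}

impl MultiInputError {
    /// Mirrors the helper defined in the diff above.
    pub fn is_empty_input(&self) -> bool {
        match self {
            MultiInputError::EmptyInput => true,
            _ => false,
        }
    }
}

impl From<ShapeMismatch> for MultiInputError {
    fn from(err: ShapeMismatch) -> Self {
        MultiInputError::ShapeMismatch(err)
    }
}

// Hypothetical validation helper in the style of the `kl_divergence` and
// `cross_entropy` bodies: each failure mode maps to a `MultiInputError`
// variant, with `.into()` doing the wrapping for the shape check.
fn validate(a: &[f64], b: &[f64]) -> Result<(), MultiInputError> {
    if a.is_empty() || b.is_empty() {
        return Err(MultiInputError::EmptyInput);
    }
    if a.len() != b.len() {
        return Err(ShapeMismatch {
            first_shape: vec![a.len()],
            second_shape: vec![b.len()],
        }
        .into());
    }
    Ok(())
}

fn main() {
    assert!(validate(&[], &[1.0]).unwrap_err().is_empty_input());

    let err = validate(&[1.0, 2.0], &[1.0]).unwrap_err();
    assert!(!err.is_empty_input());
    println!("{:?}", err); // ShapeMismatch(ShapeMismatch { first_shape: [2], second_shape: [1] })
}
```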