Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 218 additions & 0 deletions src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
package net.imglib2.algorithm.stats;

import java.util.ArrayList;
import java.util.List;

import net.imglib2.Cursor;
import net.imglib2.IterableInterval;
import net.imglib2.histogram.BinMapper1d;
import net.imglib2.histogram.Histogram1d;
import net.imglib2.histogram.HistogramNd;
import net.imglib2.histogram.Real1dBinMapper;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

/**
* This class provides method for computing information metrics (entropy, mutual information)
* for imglib2.
*
* @author John Bogovic
*/
public class InformationMetrics
{

/**
* Returns the normalized mutual information of the inputs
* @param rai the RandomAccessibleInterval
* @param ra the RandomAccessible
* @return the normalized mutual information
*/
public static <T extends RealType< T >> double normalizedMutualInformation(
final IterableInterval< T > dataA,
final IterableInterval< T > dataB,
double histmin, double histmax, int numBins )
{
HistogramNd< T > jointHist = jointHistogram( dataA, dataB, histmin, histmax, numBins );
return normalizedMutualInformation( jointHist );
}

/**
* Returns the normalized mutual information of the inputs
* @param jointHist the joint histogram
* @return the normalized mutual information
*/
public static <T extends RealType< T >> double normalizedMutualInformation(
final HistogramNd< T > jointHist )
{
double HA = marginalEntropy( jointHist, 0 );
double HB = marginalEntropy( jointHist, 1 );
double HAB = entropy( jointHist );
return 2 * ( HA + HB - HAB ) / ( HA + HB );
}

/**
* Returns the normalized mutual information of the inputs
* @param rai the RandomAccessibleInterval
* @param ra the RandomAccessible
* @return the normalized mutual information
*/
public static <T extends RealType< T >> double mutualInformation(
IterableInterval< T > dataA,
IterableInterval< T > dataB,
double histmin, double histmax, int numBins )
{
HistogramNd< T > jointHist = jointHistogram( dataA, dataB, histmin, histmax, numBins );

double HA = marginalEntropy( jointHist, 0 );
double HB = marginalEntropy( jointHist, 1 );
double HAB = entropy( jointHist );

return HA + HB - HAB;
}

/**
* Returns the normalized mutual information of the inputs
* @param ra the RandomAccessible
* @return the normalized mutual information
*/
public static <T extends RealType< T >> double mutualInformation(
final HistogramNd< T > jointHist )
{
double HA = marginalEntropy( jointHist, 0 );
double HB = marginalEntropy( jointHist, 1 );
double HAB = entropy( jointHist );
return HA + HB - HAB;
}

/**
* Compute a the joint histogram.
* @param dataA
* @param dataB
* @param histmin
* @param histmax
* @param numBins
* @return the joint histogram
*/
public static < T extends RealType< T > > HistogramNd< T > jointHistogram(
final IterableInterval< T > dataA,
final IterableInterval< T > dataB,
final double histmin,
final double histmax,
final int numBins )
{
Real1dBinMapper< T > binMapper = new Real1dBinMapper< T >( histmin, histmax, numBins, false );
ArrayList< BinMapper1d< T > > binMappers = new ArrayList< BinMapper1d< T > >( 2 );
binMappers.add( binMapper );
binMappers.add( binMapper );

List< Iterable< T > > data = new ArrayList< Iterable< T > >( 2 );
data.add( dataA );
data.add( dataB );
return new HistogramNd< T >( data, binMappers );
}

/**
* Returns the joint entropy of the inputs
* @param rai the RandomAccessibleInterval
* @param ra the RandomAccessible
* @return the joint entropy
*/
public static <T extends RealType< T >> double jointEntropy(
IterableInterval< T > dataA,
IterableInterval< T > dataB,
double histmin, double histmax, int numBins )
{
return entropy( jointHistogram( dataA, dataB, histmin, histmax, numBins ));
}

/**
* Returns the entropy of the input.
*
* @param data the data
* @return the entropy
*/
public static <T extends RealType< T >> double entropy(
IterableInterval< T > data,
double histmin, double histmax, int numBins )
{
Real1dBinMapper< T > binMapper = new Real1dBinMapper< T >( histmin, histmax, numBins, false );
final Histogram1d< T > hist = new Histogram1d< T >( binMapper );
hist.countData( data );

return entropy( hist );
}

/**
* Computes the entropy of the input 1d histogram.
* @param hist the histogram
* @return the entropy
*/
public static < T > double entropy( Histogram1d< T > hist )
{
double entropy = 0.0;
for( int i = 0; i < hist.getBinCount(); i++ )
{
double p = hist.relativeFrequency( i, false );
if( p > 0 )
entropy -= p * Math.log( p );

}
return entropy;
}

/**
* Computes the entropy of the input nd histogram.
* @param hist the histogram
* @return the entropy
*/
public static < T > double entropy( HistogramNd< T > hist )
{
double entropy = 0.0;
Cursor< LongType > hc = hist.cursor();
long[] pos = new long[ hc.numDimensions() ];

while( hc.hasNext() )
{
hc.fwd();
hc.localize( pos );
double p = hist.relativeFrequency( pos, false );
if( p > 0 )
entropy -= p * Math.log( p );

}
return entropy;
}

public static < T > double marginalEntropy( HistogramNd< T > hist, int dim )
{
final long ni = hist.dimension( dim );
final long total = hist.valueCount();
long count = 0;
double entropy = 0.0;
long ctot = 0;
for( int i = 0; i < ni; i++ )
{
count = subHistCount( hist, dim, i );
ctot += count;
double p = 1.0 * count / total;

if( p > 0 )
entropy -= p * Math.log( p );
}
return entropy;
}

private static < T > long subHistCount( HistogramNd< T > hist, int dim, int pos )
{
long count = 0;
IntervalView< LongType > hs = Views.hyperSlice( hist, dim, pos );
Cursor< LongType > c = hs.cursor();
while( c.hasNext() )
{
count += c.next().get();
}
return count;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package net.imglib2.algorithm.stats;

import static org.junit.Assert.assertEquals;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.numeric.integer.IntType;

public class InformationMetricsTests
{

private Img< IntType > imgZeros;
private Img< IntType > img3;
private Img< IntType > img3Shifted;
private Img< IntType > img3SmallErr;
private Img< IntType > imgTwo;

// mutual info of a set with itself
private double MI_id = 1.0986122886681096;
private double log2 = Math.log( 2.0 );

@Before
public void setup()
{
int[] a = new int[]{ 0, 1, 2, 0, 1, 2, 0, 1, 2};
img3 = ArrayImgs.ints( a, a.length );

int[] b = new int[]{ 0, 1, 2, 1, 2, 0, 2, 0, 1};
img3Shifted = ArrayImgs.ints( b, b.length );

imgZeros = ArrayImgs.ints( new int[ 9 ], 9 );

int[] c = new int[]{ 0, 1, 2, 0, 2, 1, 0, 1, 2 };
img3SmallErr = ArrayImgs.ints( c, c.length );

int[] twodat = new int[]{ 0, 1, 0, 1, 0, 1, 0, 1 };
imgTwo = ArrayImgs.ints( twodat, twodat.length );
}

@Test
public void testEntropy()
{
double entropyZeros = InformationMetrics.entropy( imgZeros, 0, 1, 2 );
double entropyCoinFlip = InformationMetrics.entropy( imgTwo, 0, 1, 2 );

assertEquals( "entropy zeros", 0.0, entropyZeros, 1e-6 );
assertEquals( "entropy fair coin", log2, entropyCoinFlip, 1e-6 );
}

@Test
public void testMutualInformation()
{
double miAA = InformationMetrics.mutualInformation( img3, img3, 0, 2, 3 );
double nmiAA = InformationMetrics.normalizedMutualInformation( img3, img3, 0, 2, 3 );
double nmiBB = InformationMetrics.normalizedMutualInformation( img3Shifted, img3Shifted, 0, 2, 3 );

double miAB = InformationMetrics.mutualInformation( img3, img3Shifted, 0, 2, 3 );
double nmiAB = InformationMetrics.normalizedMutualInformation( img3, img3Shifted, 0, 2, 3 );

double miBA = InformationMetrics.mutualInformation( img3Shifted, img3, 0, 2, 3 );
double nmiBA = InformationMetrics.normalizedMutualInformation( img3Shifted, img3, 0, 2, 3 );

double miBB = InformationMetrics.mutualInformation( img3Shifted, img3Shifted, 0, 2, 3 );

assertEquals( "self MI", MI_id, miAA, 1e-6 );
assertEquals( "self NMI", 1.0, nmiAA, 1e-6 );

assertEquals( "MI permutation", miAB, 0.0, 1e-6 );
assertEquals( "NMI permutation", nmiAB, 0.0, 1e-6 );

}

}