diff --git a/src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java b/src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java
new file mode 100644
index 000000000..9c209b3b0
--- /dev/null
+++ b/src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java
@@ -0,0 +1,248 @@
+package net.imglib2.algorithm.stats;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import net.imglib2.Cursor;
+import net.imglib2.IterableInterval;
+import net.imglib2.histogram.BinMapper1d;
+import net.imglib2.histogram.Histogram1d;
+import net.imglib2.histogram.HistogramNd;
+import net.imglib2.histogram.Real1dBinMapper;
+import net.imglib2.type.numeric.RealType;
+import net.imglib2.type.numeric.integer.LongType;
+import net.imglib2.view.IntervalView;
+import net.imglib2.view.Views;
+
+/**
+ * This class provides methods for computing information metrics (entropy,
+ * mutual information) for imglib2. All quantities use the natural logarithm.
+ *
+ * @author John Bogovic
+ */
+public class InformationMetrics
+{
+
+    /**
+     * Returns the normalized mutual information of the inputs, computed from
+     * their joint histogram.
+     *
+     * @param dataA the first data set
+     * @param dataB the second data set
+     * @param histmin the minimum value passed to the histogram bin mapper
+     * @param histmax the maximum value passed to the histogram bin mapper
+     * @param numBins the number of bins along each histogram dimension
+     * @return the normalized mutual information
+     */
+    public static < T extends RealType< T > > double normalizedMutualInformation(
+            final IterableInterval< T > dataA,
+            final IterableInterval< T > dataB,
+            double histmin, double histmax, int numBins )
+    {
+        final HistogramNd< T > jointHist = jointHistogram( dataA, dataB, histmin, histmax, numBins );
+        return normalizedMutualInformation( jointHist );
+    }
+
+    /**
+     * Returns the normalized mutual information of a joint histogram, defined
+     * here as {@code 2 * ( H(A) + H(B) - H(A,B) ) / ( H(A) + H(B) )}. The
+     * result is NaN when both marginal entropies are zero.
+     *
+     * @param jointHist the joint histogram
+     * @return the normalized mutual information
+     */
+    public static < T extends RealType< T > > double normalizedMutualInformation(
+            final HistogramNd< T > jointHist )
+    {
+        final double HA = marginalEntropy( jointHist, 0 );
+        final double HB = marginalEntropy( jointHist, 1 );
+        final double HAB = entropy( jointHist );
+        return 2 * ( HA + HB - HAB ) / ( HA + HB );
+    }
+
+    /**
+     * Returns the mutual information of the inputs, computed from their joint
+     * histogram.
+     *
+     * @param dataA the first data set
+     * @param dataB the second data set
+     * @param histmin the minimum value passed to the histogram bin mapper
+     * @param histmax the maximum value passed to the histogram bin mapper
+     * @param numBins the number of bins along each histogram dimension
+     * @return the mutual information
+     */
+    public static < T extends RealType< T > > double mutualInformation(
+            IterableInterval< T > dataA,
+            IterableInterval< T > dataB,
+            double histmin, double histmax, int numBins )
+    {
+        // Delegate to the histogram overload so the formula lives in one place.
+        return mutualInformation( jointHistogram( dataA, dataB, histmin, histmax, numBins ) );
+    }
+
+    /**
+     * Returns the mutual information {@code H(A) + H(B) - H(A,B)} of a joint
+     * histogram.
+     *
+     * @param jointHist the joint histogram
+     * @return the mutual information
+     */
+    public static < T extends RealType< T > > double mutualInformation(
+            final HistogramNd< T > jointHist )
+    {
+        final double HA = marginalEntropy( jointHist, 0 );
+        final double HB = marginalEntropy( jointHist, 1 );
+        final double HAB = entropy( jointHist );
+        return HA + HB - HAB;
+    }
+
+    /**
+     * Computes the joint histogram of two data sets. The same bin mapper is
+     * used for both dimensions.
+     *
+     * @param dataA the first data set
+     * @param dataB the second data set
+     * @param histmin the minimum value passed to the histogram bin mapper
+     * @param histmax the maximum value passed to the histogram bin mapper
+     * @param numBins the number of bins along each histogram dimension
+     * @return the joint histogram
+     */
+    public static < T extends RealType< T > > HistogramNd< T > jointHistogram(
+            final IterableInterval< T > dataA,
+            final IterableInterval< T > dataB,
+            final double histmin,
+            final double histmax,
+            final int numBins )
+    {
+        final Real1dBinMapper< T > binMapper = new Real1dBinMapper< T >( histmin, histmax, numBins, false );
+        final List< BinMapper1d< T > > binMappers = new ArrayList< BinMapper1d< T > >( 2 );
+        binMappers.add( binMapper );
+        binMappers.add( binMapper );
+
+        final List< Iterable< T > > data = new ArrayList< Iterable< T > >( 2 );
+        data.add( dataA );
+        data.add( dataB );
+        return new HistogramNd< T >( data, binMappers );
+    }
+
+    /**
+     * Returns the joint entropy of the inputs, computed from their joint
+     * histogram.
+     *
+     * @param dataA the first data set
+     * @param dataB the second data set
+     * @param histmin the minimum value passed to the histogram bin mapper
+     * @param histmax the maximum value passed to the histogram bin mapper
+     * @param numBins the number of bins along each histogram dimension
+     * @return the joint entropy
+     */
+    public static < T extends RealType< T > > double jointEntropy(
+            IterableInterval< T > dataA,
+            IterableInterval< T > dataB,
+            double histmin, double histmax, int numBins )
+    {
+        return entropy( jointHistogram( dataA, dataB, histmin, histmax, numBins ) );
+    }
+
+    /**
+     * Returns the entropy of the input, computed from its 1d histogram.
+     *
+     * @param data the data
+     * @param histmin the minimum value passed to the histogram bin mapper
+     * @param histmax the maximum value passed to the histogram bin mapper
+     * @param numBins the number of histogram bins
+     * @return the entropy
+     */
+    public static < T extends RealType< T > > double entropy(
+            IterableInterval< T > data,
+            double histmin, double histmax, int numBins )
+    {
+        final Real1dBinMapper< T > binMapper = new Real1dBinMapper< T >( histmin, histmax, numBins, false );
+        final Histogram1d< T > hist = new Histogram1d< T >( binMapper );
+        hist.countData( data );
+
+        return entropy( hist );
+    }
+
+    /**
+     * Computes the entropy of the input 1d histogram.
+     *
+     * @param hist the histogram
+     * @return the entropy
+     */
+    public static < T > double entropy( Histogram1d< T > hist )
+    {
+        double entropy = 0.0;
+        for ( long i = 0; i < hist.getBinCount(); i++ )
+        {
+            final double p = hist.relativeFrequency( i, false );
+            // Empty bins contribute nothing; skipping them avoids log( 0 ).
+            if ( p > 0 )
+                entropy -= p * Math.log( p );
+        }
+        return entropy;
+    }
+
+    /**
+     * Computes the entropy of the input nd histogram.
+     *
+     * @param hist the histogram
+     * @return the entropy
+     */
+    public static < T > double entropy( HistogramNd< T > hist )
+    {
+        double entropy = 0.0;
+        final Cursor< LongType > hc = hist.cursor();
+        final long[] pos = new long[ hc.numDimensions() ];
+
+        while ( hc.hasNext() )
+        {
+            hc.fwd();
+            hc.localize( pos );
+            final double p = hist.relativeFrequency( pos, false );
+            if ( p > 0 )
+                entropy -= p * Math.log( p );
+        }
+        return entropy;
+    }
+
+    /**
+     * Computes the entropy of the marginal distribution of an nd histogram
+     * along one dimension.
+     *
+     * @param hist the histogram
+     * @param dim the dimension whose marginal distribution is used
+     * @return the marginal entropy
+     */
+    public static < T > double marginalEntropy( HistogramNd< T > hist, int dim )
+    {
+        final long ni = hist.dimension( dim );
+        final long total = hist.valueCount();
+        double entropy = 0.0;
+        for ( long i = 0; i < ni; i++ )
+        {
+            final long count = subHistCount( hist, dim, i );
+            final double p = 1.0 * count / total;
+
+            if ( p > 0 )
+                entropy -= p * Math.log( p );
+        }
+        return entropy;
+    }
+
+    /**
+     * Sums the counts of all bins whose position along {@code dim} equals
+     * {@code pos}.
+     */
+    private static < T > long subHistCount( HistogramNd< T > hist, int dim, long pos )
+    {
+        long count = 0;
+        final IntervalView< LongType > hs = Views.hyperSlice( hist, dim, pos );
+        final Cursor< LongType > c = hs.cursor();
+        while ( c.hasNext() )
+        {
+            count += c.next().get();
+        }
+        return count;
+    }
+}
diff --git a/src/test/java/net/imglib2/algorithm/stats/InformationMetricsTests.java b/src/test/java/net/imglib2/algorithm/stats/InformationMetricsTests.java
new file mode 100644
index 000000000..ec6d011b5
--- /dev/null
+++ b/src/test/java/net/imglib2/algorithm/stats/InformationMetricsTests.java
@@ -0,0 +1,81 @@
+package net.imglib2.algorithm.stats;
+
+import static org.junit.Assert.assertEquals;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import net.imglib2.img.Img;
+import net.imglib2.img.array.ArrayImgs;
+import net.imglib2.type.numeric.integer.IntType;
+
+public class InformationMetricsTests
+{
+
+    private Img< IntType > imgZeros;
+    private Img< IntType > img3;
+    private Img< IntType > img3Shifted;
+    private Img< IntType > img3SmallErr;
+    private Img< IntType > imgTwo;
+
+    // mutual information of a uniform three-valued set with itself: ln( 3 )
+    private double MI_id = 1.0986122886681096;
+    private double log2 = Math.log( 2.0 );
+
+    @Before
+    public void setup()
+    {
+        int[] a = new int[]{ 0, 1, 2, 0, 1, 2, 0, 1, 2 };
+        img3 = ArrayImgs.ints( a, a.length );
+
+        int[] b = new int[]{ 0, 1, 2, 1, 2, 0, 2, 0, 1 };
+        img3Shifted = ArrayImgs.ints( b, b.length );
+
+        imgZeros = ArrayImgs.ints( new int[ 9 ], 9 );
+
+        int[] c = new int[]{ 0, 1, 2, 0, 2, 1, 0, 1, 2 };
+        img3SmallErr = ArrayImgs.ints( c, c.length );
+
+        int[] twodat = new int[]{ 0, 1, 0, 1, 0, 1, 0, 1 };
+        imgTwo = ArrayImgs.ints( twodat, twodat.length );
+    }
+
+    @Test
+    public void testEntropy()
+    {
+        double entropyZeros = InformationMetrics.entropy( imgZeros, 0, 1, 2 );
+        double entropyCoinFlip = InformationMetrics.entropy( imgTwo, 0, 1, 2 );
+
+        assertEquals( "entropy zeros", 0.0, entropyZeros, 1e-6 );
+        assertEquals( "entropy fair coin", log2, entropyCoinFlip, 1e-6 );
+    }
+
+    @Test
+    public void testMutualInformation()
+    {
+        double miAA = InformationMetrics.mutualInformation( img3, img3, 0, 2, 3 );
+        double nmiAA = InformationMetrics.normalizedMutualInformation( img3, img3, 0, 2, 3 );
+
+        double miBB = InformationMetrics.mutualInformation( img3Shifted, img3Shifted, 0, 2, 3 );
+        double nmiBB = InformationMetrics.normalizedMutualInformation( img3Shifted, img3Shifted, 0, 2, 3 );
+
+        double miAB = InformationMetrics.mutualInformation( img3, img3Shifted, 0, 2, 3 );
+        double nmiAB = InformationMetrics.normalizedMutualInformation( img3, img3Shifted, 0, 2, 3 );
+
+        double miBA = InformationMetrics.mutualInformation( img3Shifted, img3, 0, 2, 3 );
+        double nmiBA = InformationMetrics.normalizedMutualInformation( img3Shifted, img3, 0, 2, 3 );
+
+        assertEquals( "self MI", MI_id, miAA, 1e-6 );
+        assertEquals( "self NMI", 1.0, nmiAA, 1e-6 );
+        assertEquals( "self MI shifted", MI_id, miBB, 1e-6 );
+        assertEquals( "self NMI shifted", 1.0, nmiBB, 1e-6 );
+
+        // JUnit expects ( message, expected, actual, delta )
+        assertEquals( "MI permutation", 0.0, miAB, 1e-6 );
+        assertEquals( "NMI permutation", 0.0, nmiAB, 1e-6 );
+        assertEquals( "MI permutation swapped", 0.0, miBA, 1e-6 );
+        assertEquals( "NMI permutation swapped", 0.0, nmiBA, 1e-6 );
+    }
+
+}