From 54f81d863242e59b08f22daf039d8f09de892b8f Mon Sep 17 00:00:00 2001
From: bogovicj
Date: Tue, 3 Nov 2020 20:20:05 -0500
Subject: [PATCH 1/2] rough start to mutual information and related methods

---
 .../algorithm/stats/InformationMetrics.java   | 185 ++++++++++++++++++
 .../stats/InformationMetricsTests.java        |  81 ++++++++
 2 files changed, 266 insertions(+)
 create mode 100644 src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java
 create mode 100644 src/test/java/net/imglib2/algorithm/stats/InformationMetricsTests.java

diff --git a/src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java b/src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java
new file mode 100644
index 000000000..cbf001cba
--- /dev/null
+++ b/src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java
@@ -0,0 +1,185 @@
+package net.imglib2.algorithm.stats;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import net.imglib2.Cursor;
+import net.imglib2.IterableInterval;
+import net.imglib2.histogram.BinMapper1d;
+import net.imglib2.histogram.Histogram1d;
+import net.imglib2.histogram.HistogramNd;
+import net.imglib2.histogram.Real1dBinMapper;
+import net.imglib2.type.numeric.RealType;
+import net.imglib2.type.numeric.integer.LongType;
+import net.imglib2.view.IntervalView;
+import net.imglib2.view.Views;
+
+/**
+ * This class provides methods for computing information metrics (entropy, mutual information)
+ * for imglib2.
+ *
+ * @author John Bogovic
+ */
+public class InformationMetrics
+{
+
+	/**
+	 * Returns the normalized mutual information of the inputs
+	 * @param dataA the first dataset
+	 * @param dataB the second dataset
+	 * @return the normalized mutual information
+	 */
+	public static <T extends RealType<T>> double normalizedMutualInformation(
+			IterableInterval< T > dataA,
+			IterableInterval< T > dataB,
+			double histmin, double histmax, int numBins )
+	{
+		HistogramNd< T > jointHist = jointHistogram( dataA, dataB, histmin, histmax, numBins );
+
+		double HA = marginalEntropy( jointHist, 0 );
+		double HB = marginalEntropy( jointHist, 1 );
+		double HAB = entropy( jointHist );
+		return ( HA + HB ) / HAB;
+	}
+
+	/**
+	 * Returns the mutual information of the inputs
+	 * @param dataA the first dataset
+	 * @param dataB the second dataset
+	 * @return the mutual information
+	 */
+	public static <T extends RealType<T>> double mutualInformation(
+			IterableInterval< T > dataA,
+			IterableInterval< T > dataB,
+			double histmin, double histmax, int numBins )
+	{
+		HistogramNd< T > jointHist = jointHistogram( dataA, dataB, histmin, histmax, numBins );
+
+		double HA = marginalEntropy( jointHist, 0 );
+		double HB = marginalEntropy( jointHist, 1 );
+		double HAB = entropy( jointHist );
+
+		return HA + HB - HAB;
+	}
+
+	public static <T extends RealType<T>> HistogramNd<T> jointHistogram(
+			IterableInterval< T > dataA,
+			IterableInterval< T > dataB,
+			double histmin, double histmax, int numBins )
+	{
+		Real1dBinMapper<T> binMapper = new Real1dBinMapper<T>( histmin, histmax, numBins, false );
+		ArrayList<BinMapper1d<T>> binMappers = new ArrayList<BinMapper1d<T>>( 2 );
+		binMappers.add( binMapper );
+		binMappers.add( binMapper );
+
+		List<Iterable<T>> data = new ArrayList<Iterable<T>>( 2 );
+		data.add( dataA );
+		data.add( dataB );
+		return new HistogramNd<T>( data, binMappers );
+	}
+
+	/**
+	 * Returns the joint entropy of the inputs
+	 * @param dataA the first dataset
+	 * @param dataB the second dataset
+	 * @return the joint entropy
+	 */
+	public static <T extends RealType<T>> double jointEntropy(
+			IterableInterval< T > dataA,
+			IterableInterval< T > dataB,
+			double histmin, double histmax, int numBins )
+	{
+		return entropy( jointHistogram( dataA, dataB, histmin, histmax, numBins ));
+	}
+
+	/**
+	 * Returns the entropy of the input.
+	 *
+	 * @param data the data
+	 * @return the entropy
+	 */
+	public static <T extends RealType<T>> double entropy(
+			IterableInterval< T > data,
+			double histmin, double histmax, int numBins )
+	{
+		Real1dBinMapper<T> binMapper = new Real1dBinMapper<T>(
+				histmin, histmax, numBins, false );
+		final Histogram1d<T> hist = new Histogram1d<T>( binMapper );
+		hist.countData( data );
+
+		return entropy( hist );
+	}
+
+	/**
+	 * Computes the entropy of the input 1d histogram.
+	 * @param hist the histogram
+	 * @return the entropy
+	 */
+	public static < T > double entropy( Histogram1d< T > hist )
+	{
+		double entropy = 0.0;
+		for( int i = 0; i < hist.getBinCount(); i++ )
+		{
+			double p = hist.relativeFrequency( i, false );
+			if( p > 0 )
+				entropy -= p * Math.log( p );
+
+		}
+		return entropy;
+	}
+
+	/**
+	 * Computes the entropy of the input nd histogram.
+	 * @param hist the histogram
+	 * @return the entropy
+	 */
+	public static < T > double entropy( HistogramNd< T > hist )
+	{
+		double entropy = 0.0;
+		Cursor< LongType > hc = hist.cursor();
+		long[] pos = new long[ hc.numDimensions() ];
+
+		while( hc.hasNext() )
+		{
+			hc.fwd();
+			hc.localize( pos );
+			double p = hist.relativeFrequency( pos, false );
+			if( p > 0 )
+				entropy -= p * Math.log( p );
+
+		}
+		return entropy;
+	}
+
+	public static < T > double marginalEntropy( HistogramNd< T > hist, int dim )
+	{
+
+		final long ni = hist.dimension( dim );
+		final long total = hist.valueCount();
+		long count = 0;
+		double entropy = 0.0;
+		long ctot = 0;
+		for( int i = 0; i < ni; i++ )
+		{
+			count = subHistCount( hist, dim, i );
+			ctot += count;
+			double p = 1.0 * count / total;
+
+			if( p > 0 )
+				entropy -= p * Math.log( p );
+		}
+		return entropy;
+	}
+	
+	private static < T > long subHistCount( HistogramNd< T > hist, int dim, int pos )
+	{
+		long count = 0;
+		IntervalView< LongType > hs = Views.hyperSlice( hist, dim, pos );
+		Cursor< LongType > c = hs.cursor();
+		while( c.hasNext() )
+		{
+			count += c.next().get();
+		}
+		return count;
+	}
+}
diff --git a/src/test/java/net/imglib2/algorithm/stats/InformationMetricsTests.java b/src/test/java/net/imglib2/algorithm/stats/InformationMetricsTests.java
new file mode 100644
index 000000000..0f5d6ca49
--- /dev/null
+++ b/src/test/java/net/imglib2/algorithm/stats/InformationMetricsTests.java
@@ -0,0 +1,81 @@
+package net.imglib2.algorithm.stats;
+
+import static org.junit.Assert.assertEquals;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import net.imglib2.img.Img;
+import net.imglib2.img.array.ArrayImgs;
+import net.imglib2.type.numeric.integer.IntType;
+
+public class InformationMetricsTests
+{
+
+	private Img< IntType > imgZeros;
+	private Img< IntType > img3;
+	private Img< IntType > img3Shifted;
+	private Img< IntType > imgTwo;
+
+	// mutual info of a set with itself
+	private double MI_id = 1.0986122886681096;
+
+	@Before
+	public void setup()
+	{
+		int[] a = new int[]{ 0, 1, 2, 0, 1, 2, 0, 1, 2};
+		img3 = ArrayImgs.ints( a, a.length );
+
+		int[] b = new int[]{ 0, 1, 2, 1, 2, 0, 2, 0, 1};
+		img3Shifted = ArrayImgs.ints( b, b.length );
+
+		imgZeros = ArrayImgs.ints( new int[ 9 ], 9 );
+
+		int[] c = new int[]{ 0, 1, 0, 1, 0, 1, 0, 1 };
+		imgTwo = ArrayImgs.ints( c, c.length );
+	}
+
+	@Test
+	public void testEntropy()
+	{
+		double entropyZeros = InformationMetrics.entropy( imgZeros, 0, 1, 2 );
+		double entropyCoinFlip = InformationMetrics.entropy( imgTwo, 0, 1, 2 );
+
+		/*
+		 * These tests fail
+		 */
+//		assertEquals( 0.0, entropyZeros, 1e-6 );
+//		assertEquals( 1.0, entropyCoinFlip, 1e-6 );
+
+//		System.out.println( "entropy zeros : " + entropyZeros );
+	}
+
+	@Test
+	public void testMutualInformation()
+	{
+		double miAA = InformationMetrics.mutualInformation( img3, img3, 0, 2, 3 );
+		double nmiAA = InformationMetrics.normalizedMutualInformation( img3, img3, 0, 2, 3 );
+
+		double miAB = InformationMetrics.mutualInformation( img3, img3Shifted, 0, 2, 3 );
+		double nmiAB = InformationMetrics.normalizedMutualInformation( img3, img3Shifted, 0, 2, 3 );
+
+		double miBA = InformationMetrics.mutualInformation( img3Shifted, img3, 0, 2, 3 );
+		double nmiBA = InformationMetrics.normalizedMutualInformation( img3Shifted, img3, 0, 2, 3 );
+
+		double miBB = InformationMetrics.mutualInformation( img3Shifted, img3Shifted, 0, 2, 3 );
+
+		assertEquals( "self MI", MI_id, miAA, 1e-6 );
+		assertEquals( "self MI", MI_id, miBB, 1e-6 );
+
+//		assertEquals( "MI symmetry", miAA, miBA, 1e-6 );
+//
+//		System.out.println( "mi:" );
+//		System.out.println( miAA );
+//		System.out.println( miAB );
+//		System.out.println( "nmi:" );
+//		System.out.println( nmiAA );
+//		System.out.println( nmiAB );
+	}
+
+}
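
Usage sketch (not part of the patch itself): a minimal example of how the API introduced in PATCH 1/2 could be called. The image contents, the histogram range [0, 2], and the bin count of 3 mirror the unit tests above; the example class and variable names are invented for illustration only.

    import net.imglib2.algorithm.stats.InformationMetrics;
    import net.imglib2.img.Img;
    import net.imglib2.img.array.ArrayImgs;
    import net.imglib2.type.numeric.integer.IntType;

    public class InformationMetricsExample
    {
        public static void main( String[] args )
        {
            // two small images with values in {0, 1, 2}
            Img< IntType > imgA = ArrayImgs.ints( new int[]{ 0, 1, 2, 0, 1, 2, 0, 1, 2 }, 9 );
            Img< IntType > imgB = ArrayImgs.ints( new int[]{ 0, 1, 2, 1, 2, 0, 2, 0, 1 }, 9 );

            // histograms over [0, 2] with 3 bins; entropies are in nats (natural log)
            double entropyA = InformationMetrics.entropy( imgA, 0, 2, 3 );
            double joint = InformationMetrics.jointEntropy( imgA, imgB, 0, 2, 3 );
            double mi = InformationMetrics.mutualInformation( imgA, imgB, 0, 2, 3 );

            System.out.println( "H(A) = " + entropyA + ", H(A,B) = " + joint + ", MI = " + mi );
        }
    }
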
From 421880c845e741e8a70b16ea3c1692ed1f7aec22 Mon Sep 17 00:00:00 2001
From: bogovicj
Date: Mon, 7 Dec 2020 15:40:31 -0500
Subject: [PATCH 2/2] fix nmi, fix info measure tests. more doc

---
 .../algorithm/stats/InformationMetrics.java   | 65 ++++++++++++++-----
 .../stats/InformationMetricsTests.java        | 32 ++++-----
 2 files changed, 63 insertions(+), 34 deletions(-)

diff --git a/src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java b/src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java
index cbf001cba..9c209b3b0 100644
--- a/src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java
+++ b/src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java
@@ -30,16 +30,26 @@ public class InformationMetrics
 	 * @return the normalized mutual information
 	 */
 	public static <T extends RealType<T>> double normalizedMutualInformation(
-			IterableInterval< T > dataA,
-			IterableInterval< T > dataB,
+			final IterableInterval< T > dataA,
+			final IterableInterval< T > dataB,
 			double histmin, double histmax, int numBins )
 	{
 		HistogramNd< T > jointHist = jointHistogram( dataA, dataB, histmin, histmax, numBins );
+		return normalizedMutualInformation( jointHist );
+	}
 
+	/**
+	 * Returns the normalized mutual information of the inputs
+	 * @param jointHist the joint histogram
+	 * @return the normalized mutual information
+	 */
+	public static <T extends RealType<T>> double normalizedMutualInformation(
+			final HistogramNd< T > jointHist )
+	{
 		double HA = marginalEntropy( jointHist, 0 );
 		double HB = marginalEntropy( jointHist, 1 );
 		double HAB = entropy( jointHist );
-		return ( HA + HB ) / HAB;
+		return 2 * ( HA + HB - HAB ) / ( HA + HB );
 	}
 
 	/**
@@ -61,21 +71,46 @@ public static <T extends RealType<T>> double mutualInformation(
 
 		return HA + HB - HAB;
 	}
+
+	/**
+	 * Returns the mutual information of the inputs
+	 * @param jointHist the joint histogram
+	 * @return the mutual information
+	 */
+	public static <T extends RealType<T>> double mutualInformation(
+			final HistogramNd< T > jointHist )
+	{
+		double HA = marginalEntropy( jointHist, 0 );
+		double HB = marginalEntropy( jointHist, 1 );
+		double HAB = entropy( jointHist );
+		return HA + HB - HAB;
+	}
 
-	public static <T extends RealType<T>> HistogramNd<T> jointHistogram(
-			IterableInterval< T > dataA,
-			IterableInterval< T > dataB,
-			double histmin, double histmax, int numBins )
+	/**
+	 * Compute the joint histogram.
+	 * @param dataA the first dataset
+	 * @param dataB the second dataset
+	 * @param histmin the minimum value of the histogram range
+	 * @param histmax the maximum value of the histogram range
+	 * @param numBins the number of bins
+	 * @return the joint histogram
+	 */
+	public static < T extends RealType< T > > HistogramNd< T > jointHistogram(
+			final IterableInterval< T > dataA,
+			final IterableInterval< T > dataB,
+			final double histmin,
+			final double histmax,
+			final int numBins )
 	{
-		Real1dBinMapper<T> binMapper = new Real1dBinMapper<T>( histmin, histmax, numBins, false );
-		ArrayList<BinMapper1d<T>> binMappers = new ArrayList<BinMapper1d<T>>( 2 );
+		Real1dBinMapper< T > binMapper = new Real1dBinMapper< T >( histmin, histmax, numBins, false );
+		ArrayList< BinMapper1d< T > > binMappers = new ArrayList< BinMapper1d< T > >( 2 );
 		binMappers.add( binMapper );
 		binMappers.add( binMapper );
 
-		List<Iterable<T>> data = new ArrayList<Iterable<T>>( 2 );
+		List< Iterable< T > > data = new ArrayList< Iterable< T > >( 2 );
 		data.add( dataA );
 		data.add( dataB );
-		return new HistogramNd<T>( data, binMappers );
+		return new HistogramNd< T >( data, binMappers );
 	}
 
 	/**
@@ -102,9 +137,8 @@ public static <T extends RealType<T>> double entropy(
 			IterableInterval< T > data,
 			double histmin, double histmax, int numBins )
 	{
-		Real1dBinMapper<T> binMapper = new Real1dBinMapper<T>(
-				histmin, histmax, numBins, false );
-		final Histogram1d<T> hist = new Histogram1d<T>( binMapper );
+		Real1dBinMapper< T > binMapper = new Real1dBinMapper< T >( histmin, histmax, numBins, false );
+		final Histogram1d< T > hist = new Histogram1d< T >( binMapper );
 		hist.countData( data );
 
 		return entropy( hist );
@@ -153,7 +187,6 @@ public static < T > double entropy( HistogramNd< T > hist )
 
 	public static < T > double marginalEntropy( HistogramNd< T > hist, int dim )
 	{
-
 		final long ni = hist.dimension( dim );
 		final long total = hist.valueCount();
 		long count = 0;
@@ -170,7 +203,7 @@ public static < T > double marginalEntropy( HistogramNd< T > hist, int dim )
 		}
 		return entropy;
 	}
-	
+
 	private static < T > long subHistCount( HistogramNd< T > hist, int dim, int pos )
 	{
 		long count = 0;
diff --git a/src/test/java/net/imglib2/algorithm/stats/InformationMetricsTests.java b/src/test/java/net/imglib2/algorithm/stats/InformationMetricsTests.java
index 0f5d6ca49..ec6d011b5 100644
--- a/src/test/java/net/imglib2/algorithm/stats/InformationMetricsTests.java
+++ b/src/test/java/net/imglib2/algorithm/stats/InformationMetricsTests.java
@@ -16,10 +16,12 @@ public class InformationMetricsTests
 	private Img< IntType > imgZeros;
 	private Img< IntType > img3;
 	private Img< IntType > img3Shifted;
+	private Img< IntType > img3SmallErr;
 	private Img< IntType > imgTwo;
 
 	// mutual info of a set with itself
 	private double MI_id = 1.0986122886681096;
+	private double log2 = Math.log( 2.0 );
 
 	@Before
 	public void setup()
@@ -32,8 +34,11 @@ public void setup()
 
 		imgZeros = ArrayImgs.ints( new int[ 9 ], 9 );
 
-		int[] c = new int[]{ 0, 1, 0, 1, 0, 1, 0, 1 };
-		imgTwo = ArrayImgs.ints( c, c.length );
+		int[] c = new int[]{ 0, 1, 2, 0, 2, 1, 0, 1, 2 };
+		img3SmallErr = ArrayImgs.ints( c, c.length );
+
+		int[] twodat = new int[]{ 0, 1, 0, 1, 0, 1, 0, 1 };
+		imgTwo = ArrayImgs.ints( twodat, twodat.length );
 	}
 
 	@Test
@@ -42,13 +47,8 @@ public void testEntropy()
 		double entropyZeros = InformationMetrics.entropy( imgZeros, 0, 1, 2 );
 		double entropyCoinFlip = InformationMetrics.entropy( imgTwo, 0, 1, 2 );
 
-		/*
-		 * These tests fail
-		 */
-//		assertEquals( 0.0, entropyZeros, 1e-6 );
-//		assertEquals( 1.0, entropyCoinFlip, 1e-6 );
-
-//		System.out.println( "entropy zeros : " + entropyZeros );
+		assertEquals( "entropy zeros", 0.0, entropyZeros, 1e-6 );
+		assertEquals( "entropy fair coin", log2, entropyCoinFlip, 1e-6 );
 	}
 
 	@Test
@@ -56,6 +56,7 @@ public void testMutualInformation()
 	{
 		double miAA = InformationMetrics.mutualInformation( img3, img3, 0, 2, 3 );
 		double nmiAA = InformationMetrics.normalizedMutualInformation( img3, img3, 0, 2, 3 );
+		double nmiBB = InformationMetrics.normalizedMutualInformation( img3Shifted, img3Shifted, 0, 2, 3 );
 
 		double miAB = InformationMetrics.mutualInformation( img3, img3Shifted, 0, 2, 3 );
 		double nmiAB = InformationMetrics.normalizedMutualInformation( img3, img3Shifted, 0, 2, 3 );
@@ -66,16 +67,11 @@ public void testMutualInformation()
 		double miBB = InformationMetrics.mutualInformation( img3Shifted, img3Shifted, 0, 2, 3 );
 
 		assertEquals( "self MI", MI_id, miAA, 1e-6 );
-		assertEquals( "self MI", MI_id, miBB, 1e-6 );
+		assertEquals( "self NMI", 1.0, nmiAA, 1e-6 );
 
-//		assertEquals( "MI symmetry", miAA, miBA, 1e-6 );
-//
-//		System.out.println( "mi:" );
-//		System.out.println( miAA );
-//		System.out.println( miAB );
-//		System.out.println( "nmi:" );
-//		System.out.println( nmiAA );
-//		System.out.println( nmiAB );
+		assertEquals( "MI permutation", miAB, 0.0, 1e-6 );
+		assertEquals( "NMI permutation", nmiAB, 0.0, 1e-6 );
+
 	}
 
 }
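
Usage sketch (not part of the patch series): the histogram-based overloads added in PATCH 2/2 let the joint histogram be built once and reused for several measures. For the test data, img3 against itself gives HA = HB = HAB = ln 3, so the corrected normalization 2 * ( HA + HB - HAB ) / ( HA + HB ) evaluates to 1; img3 against its permuted copy gives a joint histogram that is uniform over all 9 bins, so HAB = ln 9 = 2 ln 3 and both MI and NMI are 0, matching the updated assertions. The example class and variable names below are invented for illustration only.

    import net.imglib2.algorithm.stats.InformationMetrics;
    import net.imglib2.histogram.HistogramNd;
    import net.imglib2.img.Img;
    import net.imglib2.img.array.ArrayImgs;
    import net.imglib2.type.numeric.integer.IntType;

    public class JointHistogramExample
    {
        public static void main( String[] args )
        {
            Img< IntType > a = ArrayImgs.ints( new int[]{ 0, 1, 2, 0, 1, 2, 0, 1, 2 }, 9 );
            Img< IntType > b = ArrayImgs.ints( new int[]{ 0, 1, 2, 1, 2, 0, 2, 0, 1 }, 9 );

            // build the joint histogram once, then derive both measures from it
            HistogramNd< IntType > joint = InformationMetrics.jointHistogram( a, b, 0, 2, 3 );

            double mi = InformationMetrics.mutualInformation( joint );
            double nmi = InformationMetrics.normalizedMutualInformation( joint );

            // both values are approximately 0 for this permuted pair
            System.out.println( "MI = " + mi + ", NMI = " + nmi );
        }
    }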