Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mutual information and related methods #89

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 218 additions & 0 deletions src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
package net.imglib2.algorithm.stats;

import java.util.ArrayList;
import java.util.List;

import net.imglib2.Cursor;
import net.imglib2.IterableInterval;
import net.imglib2.histogram.BinMapper1d;
import net.imglib2.histogram.Histogram1d;
import net.imglib2.histogram.HistogramNd;
import net.imglib2.histogram.Real1dBinMapper;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

/**
* This class provides methods for computing information metrics (entropy, mutual information)
* for imglib2.
*
* @author John Bogovic
*/
public class InformationMetrics
{

/**
* Returns the normalized mutual information of the inputs
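* @param dataA the first data set
* @param dataB the second data set
* @param histmin the lower bound of the histogram range
* @param histmax the upper bound of the histogram range
* @param numBins the number of histogram bins per dimension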
Comment on lines +28 to +29
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent with actual types of parameters in method signature

* @return the normalized mutual information
*/
public static <T extends RealType< T >> double normalizedMutualInformation(
final IterableInterval< T > dataA,
final IterableInterval< T > dataB,
Comment on lines +33 to +34
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a need to pass IterableInterval? Could it just be Iterable?

double histmin, double histmax, int numBins )
{
HistogramNd< T > jointHist = jointHistogram( dataA, dataB, histmin, histmax, numBins );
return normalizedMutualInformation( jointHist );
}

/**
 * Computes the normalized mutual information from a joint histogram.
 * <p>
 * The value returned is {@code 2 * ( H(A) + H(B) - H(A,B) ) / ( H(A) + H(B) )},
 * where {@code H(A)} and {@code H(B)} are the marginal entropies along
 * dimensions 0 and 1 of the histogram and {@code H(A,B)} is the entropy of
 * the histogram itself.
 *
 * @param jointHist the joint histogram
 * @return the normalized mutual information
 */
public static < T extends RealType< T > > double normalizedMutualInformation(
		final HistogramNd< T > jointHist )
{
	final double entropyA = marginalEntropy( jointHist, 0 );
	final double entropyB = marginalEntropy( jointHist, 1 );
	final double marginalSum = entropyA + entropyB;
	final double entropyAB = entropy( jointHist );
	return 2 * ( marginalSum - entropyAB ) / marginalSum;
}

/**
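* Returns the mutual information of the inputs
* @param dataA the first data set
* @param dataB the second data set
* @param histmin the lower bound of the histogram range
* @param histmax the upper bound of the histogram range
* @param numBins the number of histogram bins per dimension
* @return the mutual information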
*/
public static <T extends RealType< T >> double mutualInformation(
IterableInterval< T > dataA,
IterableInterval< T > dataB,
double histmin, double histmax, int numBins )
{
HistogramNd< T > jointHist = jointHistogram( dataA, dataB, histmin, histmax, numBins );

double HA = marginalEntropy( jointHist, 0 );
double HB = marginalEntropy( jointHist, 1 );
double HAB = entropy( jointHist );

return HA + HB - HAB;
Comment on lines +68 to +72
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should call mutualInformation( jointHist ) instead of duplicating the code.

}

/**
 * Computes the mutual information from a joint histogram.
 * <p>
 * The value returned is {@code H(A) + H(B) - H(A,B)}, where {@code H(A)} and
 * {@code H(B)} are the marginal entropies along dimensions 0 and 1 of the
 * histogram and {@code H(A,B)} is the entropy of the histogram itself.
 *
 * @param jointHist the joint histogram
 * @return the mutual information
 */
public static < T extends RealType< T > > double mutualInformation(
		final HistogramNd< T > jointHist )
{
	final double entropyA = marginalEntropy( jointHist, 0 );
	final double entropyB = marginalEntropy( jointHist, 1 );
	final double marginalSum = entropyA + entropyB;
	return marginalSum - entropy( jointHist );
}

/**
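* Computes the joint histogram of two data sets, using the same bin mapping for both.
* @param dataA the first data set
* @param dataB the second data set
* @param histmin the lower bound of the histogram range
* @param histmax the upper bound of the histogram range
* @param numBins the number of histogram bins per dimension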
* @return the joint histogram
*/
public static < T extends RealType< T > > HistogramNd< T > jointHistogram(
final IterableInterval< T > dataA,
final IterableInterval< T > dataB,
final double histmin,
final double histmax,
final int numBins )
{
Real1dBinMapper< T > binMapper = new Real1dBinMapper< T >( histmin, histmax, numBins, false );
ArrayList< BinMapper1d< T > > binMappers = new ArrayList< BinMapper1d< T > >( 2 );
binMappers.add( binMapper );
binMappers.add( binMapper );
Comment on lines +106 to +108
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not tested but something like this should be possible:

Suggested change
ArrayList< BinMapper1d< T > > binMappers = new ArrayList< BinMapper1d< T > >( 2 );
binMappers.add( binMapper );
binMappers.add( binMapper );
List< BinMapper1d< T > > binMappers = Arrays.asList( binMapper, binMapper );


List< Iterable< T > > data = new ArrayList< Iterable< T > >( 2 );
data.add( dataA );
data.add( dataB );
Comment on lines +110 to +112
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not tested but something like this should be possible:

Suggested change
List< Iterable< T > > data = new ArrayList< Iterable< T > >( 2 );
data.add( dataA );
data.add( dataB );
List< Iterable< T > > data = Arrays.asList( dataA, dataB );

return new HistogramNd< T >( data, binMappers );
}

/**
 * Computes the joint entropy of two data sets.
 * <p>
 * Builds the joint histogram of {@code dataA} and {@code dataB} and returns
 * the entropy of that histogram.
 *
 * @param dataA the first data set
 * @param dataB the second data set
 * @param histmin the lower bound of the histogram range
 * @param histmax the upper bound of the histogram range
 * @param numBins the number of histogram bins per dimension
 * @return the joint entropy
 */
public static < T extends RealType< T > > double jointEntropy(
		IterableInterval< T > dataA,
		IterableInterval< T > dataB,
		double histmin, double histmax, int numBins )
{
	final HistogramNd< T > jointHist = jointHistogram( dataA, dataB, histmin, histmax, numBins );
	return entropy( jointHist );
}

/**
 * Computes the entropy of a data set.
 * <p>
 * The data are counted into a 1d histogram and the entropy of that
 * histogram is returned.
 *
 * @param data the data
 * @param histmin the lower bound of the histogram range
 * @param histmax the upper bound of the histogram range
 * @param numBins the number of histogram bins
 * @return the entropy
 */
public static < T extends RealType< T > > double entropy(
		IterableInterval< T > data,
		double histmin, double histmax, int numBins )
{
	final Histogram1d< T > histogram = new Histogram1d< T >(
			new Real1dBinMapper< T >( histmin, histmax, numBins, false ) );
	histogram.countData( data );
	return entropy( histogram );
}

/**
 * Computes the entropy of a 1d histogram.
 * <p>
 * Sums {@code -p * ln( p )} over every bin whose relative frequency
 * {@code p} is positive; the natural logarithm is used.
 *
 * @param hist the histogram
 * @return the entropy
 */
public static < T > double entropy( Histogram1d< T > hist )
{
	double sum = 0.0;
	int bin = 0;
	while ( bin < hist.getBinCount() )
	{
		final double p = hist.relativeFrequency( bin, false );
		// Bins that are not strictly positive contribute nothing.
		sum -= ( p > 0 ) ? p * Math.log( p ) : 0.0;
		bin++;
	}
	return sum;
}

/**
 * Computes the entropy of an nd histogram.
 * <p>
 * Visits every bin of the histogram and sums {@code -p * ln( p )} over the
 * bins whose relative frequency {@code p} is positive; the natural
 * logarithm is used.
 *
 * @param hist the histogram
 * @return the entropy
 */
public static < T > double entropy( HistogramNd< T > hist )
{
	final Cursor< LongType > binCursor = hist.cursor();
	final long[] binPosition = new long[ binCursor.numDimensions() ];

	double sum = 0.0;
	while ( binCursor.hasNext() )
	{
		binCursor.fwd();
		binCursor.localize( binPosition );

		final double p = hist.relativeFrequency( binPosition, false );
		// Skip bins that are not strictly positive.
		if ( !( p > 0 ) )
			continue;

		sum -= p * Math.log( p );
	}
	return sum;
}

public static < T > double marginalEntropy( HistogramNd< T > hist, int dim )
{
final long ni = hist.dimension( dim );
final long total = hist.valueCount();
long count = 0;
double entropy = 0.0;
long ctot = 0;
for( int i = 0; i < ni; i++ )
{
count = subHistCount( hist, dim, i );
ctot += count;
double p = 1.0 * count / total;

if( p > 0 )
entropy -= p * Math.log( p );
}
Comment on lines +195 to +203
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something we can run multi-threaded using LoopBuilder? See also #83.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I'll check 👍 thanks @imagejan

return entropy;
}

/**
 * Sums the bin counts of the hyperslice of a histogram taken at a fixed
 * position along one dimension.
 *
 * @param hist the histogram
 * @param dim the dimension that is held fixed
 * @param pos the position along {@code dim}
 * @return the sum of all bin counts in the slice
 */
private static < T > long subHistCount( HistogramNd< T > hist, int dim, int pos )
{
	final IntervalView< LongType > slice = Views.hyperSlice( hist, dim, pos );
	long total = 0;
	for ( final LongType binCount : slice )
		total += binCount.get();

	return total;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package net.imglib2.algorithm.stats;

import static org.junit.Assert.assertEquals;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.numeric.integer.IntType;

/**
 * Tests for {@link InformationMetrics}: entropy and (normalized) mutual
 * information computed on small one-dimensional integer images.
 */
public class InformationMetricsTests
{

	private Img< IntType > imgZeros;
	private Img< IntType > img3;
	private Img< IntType > img3Shifted;
	private Img< IntType > img3SmallErr;
	private Img< IntType > imgTwo;

	// mutual info of a set with itself ( = ln 3 for three equally likely values )
	private double MI_id = 1.0986122886681096;
	private double log2 = Math.log( 2.0 );

	@Before
	public void setup()
	{
		// three values, each occurring three times
		int[] a = new int[]{ 0, 1, 2, 0, 1, 2, 0, 1, 2};
		img3 = ArrayImgs.ints( a, a.length );

		// arranged so that every (a, b) value pair occurs exactly once
		int[] b = new int[]{ 0, 1, 2, 1, 2, 0, 2, 0, 1};
		img3Shifted = ArrayImgs.ints( b, b.length );

		imgZeros = ArrayImgs.ints( new int[ 9 ], 9 );

		// identical to a except for two swapped entries
		int[] c = new int[]{ 0, 1, 2, 0, 2, 1, 0, 1, 2 };
		img3SmallErr = ArrayImgs.ints( c, c.length );

		int[] twodat = new int[]{ 0, 1, 0, 1, 0, 1, 0, 1 };
		imgTwo = ArrayImgs.ints( twodat, twodat.length );
	}

	@Test
	public void testEntropy()
	{
		double entropyZeros = InformationMetrics.entropy( imgZeros, 0, 1, 2 );
		double entropyCoinFlip = InformationMetrics.entropy( imgTwo, 0, 1, 2 );

		assertEquals( "entropy zeros", 0.0, entropyZeros, 1e-6 );
		assertEquals( "entropy fair coin", log2, entropyCoinFlip, 1e-6 );
	}

	@Test
	public void testMutualInformation()
	{
		double miAA = InformationMetrics.mutualInformation( img3, img3, 0, 2, 3 );
		double nmiAA = InformationMetrics.normalizedMutualInformation( img3, img3, 0, 2, 3 );

		double miBB = InformationMetrics.mutualInformation( img3Shifted, img3Shifted, 0, 2, 3 );
		double nmiBB = InformationMetrics.normalizedMutualInformation( img3Shifted, img3Shifted, 0, 2, 3 );

		double miAB = InformationMetrics.mutualInformation( img3, img3Shifted, 0, 2, 3 );
		double nmiAB = InformationMetrics.normalizedMutualInformation( img3, img3Shifted, 0, 2, 3 );

		double miBA = InformationMetrics.mutualInformation( img3Shifted, img3, 0, 2, 3 );
		double nmiBA = InformationMetrics.normalizedMutualInformation( img3Shifted, img3, 0, 2, 3 );

		// a data set compared with itself
		assertEquals( "self MI", MI_id, miAA, 1e-6 );
		assertEquals( "self NMI", 1.0, nmiAA, 1e-6 );
		assertEquals( "self MI (shifted)", MI_id, miBB, 1e-6 );
		assertEquals( "self NMI (shifted)", 1.0, nmiBB, 1e-6 );

		// every value pair occurs once, so the two sets are independent
		assertEquals( "MI permutation", 0.0, miAB, 1e-6 );
		assertEquals( "NMI permutation", 0.0, nmiAB, 1e-6 );

		// swapping the arguments must not change the result
		assertEquals( "MI symmetry", miAB, miBA, 1e-6 );
		assertEquals( "NMI symmetry", nmiAB, nmiBA, 1e-6 );
	}

	@Test
	public void testMutualInformationSmallError()
	{
		double miAA = InformationMetrics.mutualInformation( img3, img3, 0, 2, 3 );
		double miAC = InformationMetrics.mutualInformation( img3, img3SmallErr, 0, 2, 3 );

		// partially agreeing data: more informative than independent data,
		// less informative than an exact copy
		Assert.assertTrue( "MI small error is positive", miAC > 1e-6 );
		Assert.assertTrue( "MI small error is below self MI", miAC < miAA - 1e-6 );
	}

}