Skip to content

Commit

Permalink
Use new imglib2 parallelization approach
Browse files Browse the repository at this point in the history
  • Loading branch information
maarzt committed Sep 19, 2019
1 parent 3aad06c commit 7626dbd
Show file tree
Hide file tree
Showing 23 changed files with 856 additions and 1,435 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ Jean-Yves Tinevez and Michael Zinsmaier.</license.copyrightOwners>

<!-- NB: Deploy releases to the SciJava Maven repository. -->
<releaseProfiles>deploy-to-scijava</releaseProfiles>

<imglib2.version>5.8.1-SNAPSHOT</imglib2.version>
</properties>

<repositories>
Expand Down
128 changes: 34 additions & 94 deletions src/main/java/net/imglib2/algorithm/binary/Thresholder.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,12 @@
*/
package net.imglib2.algorithm.binary;

import java.util.Vector;
import java.util.function.BiConsumer;

import net.imglib2.Cursor;
import net.imglib2.RandomAccess;
import net.imglib2.converter.Converter;
import net.imglib2.exception.IncompatibleTypeException;
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.multithreading.Chunk;
import net.imglib2.multithreading.SimpleMultiThreading;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.parallel.Parallelization;
import net.imglib2.type.Type;
import net.imglib2.type.logic.BitType;

Expand Down Expand Up @@ -73,95 +69,39 @@ public class Thresholder
*/
public static final < T extends Type< T > & Comparable< T >> Img< BitType > threshold( final Img< T > source, final T threshold, final boolean above, final int numThreads )
{
final ImgFactory< T > factory = source.factory();
try
{
final ImgFactory< BitType > bitFactory = factory.imgFactory( new BitType() );
final Img< BitType > target = bitFactory.create( source );

final Converter< T, BitType > converter;
if ( above )
{
converter = new Converter< T, BitType >()
{
@Override
public void convert( final T input, final BitType output )
{
output.set( input.compareTo( threshold ) > 0 );
}
};
}
else
{
converter = new Converter< T, BitType >()
{
@Override
public void convert( final T input, final BitType output )
{
output.set( input.compareTo( threshold ) < 0 );
}
};
}

final Vector< Chunk > chunks = SimpleMultiThreading.divideIntoChunks( target.size(), numThreads );
final Thread[] threads = SimpleMultiThreading.newThreads( numThreads );
return Parallelization.runWithNumThreads( numThreads,
() -> threshold( source, threshold, above ) );
}

if ( target.iterationOrder().equals( source.iterationOrder() ) )
{
for ( int i = 0; i < threads.length; i++ )
{
final Chunk chunk = chunks.get( i );
threads[ i ] = new Thread( "Thresholder thread " + i )
{
@Override
public void run()
{
final Cursor< BitType > cursorTarget = target.cursor();
cursorTarget.jumpFwd( chunk.getStartPosition() );
final Cursor< T > cursorSource = source.cursor();
cursorSource.jumpFwd( chunk.getStartPosition() );
for ( long steps = 0; steps < chunk.getLoopSize(); steps++ )
{
cursorTarget.fwd();
cursorSource.fwd();
converter.convert( cursorSource.get(), cursorTarget.get() );
}
}
};
}
}
else
{
for ( int i = 0; i < threads.length; i++ )
{
final Chunk chunk = chunks.get( i );
threads[ i ] = new Thread( "Thresholder thread " + i )
{
@Override
public void run()
{
final Cursor< BitType > cursorTarget = target.cursor();
cursorTarget.jumpFwd( chunk.getStartPosition() );
final RandomAccess< T > ra = source.randomAccess( target );
for ( long steps = 0; steps < chunk.getLoopSize(); steps++ )
{
cursorTarget.fwd();
ra.setPosition( cursorTarget );
converter.convert( ra.get(), cursorTarget.get() );
}
}
};
}
}
/**
* Returns a new boolean {@link Img} generated by thresholding the values of
* the source image.
*
* @param source
* the image to threshold.
* @param threshold
* the threshold.
* @param above
* if {@code true}, the target value will be true for source
* values above the threshold, {@code false} otherwise.
* @return a new {@link Img} of type {@link BitType} and of same dimension
* that the source image.
*/
public static < T extends Type< T > & Comparable< T > > Img< BitType > threshold( Img< T > source, T threshold, boolean above )
{
final ImgFactory< BitType > factory = source.factory().imgFactory( new BitType() );
final Img< BitType > target = factory.create( source );
final BiConsumer< T, BitType > converter = getThresholdConverter( threshold, above );
LoopBuilder.setImages( source, target ).multiThreaded().forEachPixel( converter );
return target;
}

SimpleMultiThreading.startAndJoin( threads );
return target;
}
catch ( final IncompatibleTypeException e )
{
e.printStackTrace();
return null;
}
public static < T extends Type< T > & Comparable< T > > BiConsumer< T, BitType > getThresholdConverter( T threshold, boolean above )
{
if ( above )
return ( input, output ) -> output.set( input.compareTo( threshold ) > 0 );
else
return ( input, output ) -> output.set( input.compareTo( threshold ) < 0 );
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*
* @author Matthias Arzt
*/
@Deprecated
public abstract class AbstractMultiThreadedConvolution< T > implements Convolution< T >
{

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ class Concatenation< T > implements Convolution< T >
this.steps = new ArrayList<>( steps );
}

@Override
public void setExecutor( final ExecutorService executor )
{
steps.forEach( step -> step.setExecutor( executor ) );
}

@Override
public Interval requiredSourceInterval( final Interval targetInterval )
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public interface Convolution< T >
/**
* Set the {@link ExecutorService} to be used for convolution.
*/
@Deprecated
default void setExecutor( final ExecutorService executor )
{}

Expand Down
114 changes: 19 additions & 95 deletions src/main/java/net/imglib2/algorithm/convolution/LineConvolution.java
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
package net.imglib2.algorithm.convolution;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.function.Supplier;

import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.Localizable;
import net.imglib2.Point;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.util.IntervalIndexer;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.util.Intervals;
import net.imglib2.util.Localizables;
import net.imglib2.view.Views;

/**
Expand All @@ -26,7 +17,7 @@
*
* @author Matthias Arzt
*/
public class LineConvolution< T > extends AbstractMultiThreadedConvolution< T >
public class LineConvolution< T > implements Convolution<T>
{
private final LineConvolverFactory< ? super T > factory;

Expand Down Expand Up @@ -55,100 +46,33 @@ public T preferredSourceType( final T targetType )
}

@Override
protected void process( final RandomAccessible< ? extends T > source, final RandomAccessibleInterval< ? extends T > target, final ExecutorService executorService, final int numThreads )
public void process( RandomAccessible< ? extends T > source, RandomAccessibleInterval< ? extends T > target )
{
final RandomAccessibleInterval< ? extends T > sourceInterval = Views.interval( source, requiredSourceInterval( target ) );
final long[] sourceMin = Intervals.minAsLongArray( sourceInterval );
final long[] targetMin = Intervals.minAsLongArray( target );

final Supplier< Consumer< Localizable > > actionFactory = () -> {

final RandomAccess< ? extends T > in = sourceInterval.randomAccess();
final RandomAccess< ? extends T > out = target.randomAccess();
final Runnable convolver = factory.getConvolver( in, out, direction, target.dimension( direction ) );

return position -> {
in.setPosition( sourceMin );
out.setPosition( targetMin );
in.move( position );
out.move( position );
convolver.run();
};
};

final long[] dim = Intervals.dimensionsAsLongArray( target );
dim[ direction ] = 1;

final int numTasks = numThreads > 1 ? timesFourAvoidOverflow(numThreads) : 1;
LineConvolution.forEachIntervalElementInParallel( executorService, numTasks, new FinalInterval( dim ), actionFactory );
}
RandomAccessibleInterval< Localizable > positions = Localizables.randomAccessibleInterval( new FinalInterval( dim ) );
LoopBuilder.setImages( positions ).multiThreaded().forEachChunk(
chunk -> {

private int timesFourAvoidOverflow( int x )
{
return (int) Math.min((long) x * 4, Integer.MAX_VALUE);
}
final RandomAccess< ? extends T > in = sourceInterval.randomAccess();
final RandomAccess< ? extends T > out = target.randomAccess();
final Runnable convolver = factory.getConvolver( in, out, direction, target.dimension( direction ) );

/**
* {@link #forEachIntervalElementInParallel(ExecutorService, int, Interval, Supplier)}
* executes a given action for each position in a given interval. Therefor
* it starts the specified number of tasks. Each tasks calls the action
* factory once, to get an instance of the action that should be executed.
* The action is then called multiple times by the task.
*
* @param service
* {@link ExecutorService} used to create the tasks.
* @param numTasks
* number of tasks to use.
* @param interval
* interval to iterate over.
* @param actionFactory
* factory that returns the action to be executed.
*/
// TODO: move to a better place
public static void forEachIntervalElementInParallel( final ExecutorService service, final int numTasks, final Interval interval,
final Supplier< Consumer< Localizable > > actionFactory )
{
final long[] min = Intervals.minAsLongArray( interval );
final long[] dim = Intervals.dimensionsAsLongArray( interval );
final long size = Intervals.numElements( dim );
final int boundedNumTasks = (int) Math.max( 1, Math.min(size, numTasks ));
final long taskSize = ( size - 1 ) / boundedNumTasks + 1; // taskSize = roundUp(size / boundedNumTasks);
final ArrayList< Callable< Void > > callables = new ArrayList<>();
chunk.forEachPixel( position -> {
in.setPosition( sourceMin );
out.setPosition( targetMin );
in.move( position );
out.move( position );
convolver.run();
} );

for ( int taskNum = 0; taskNum < boundedNumTasks; ++taskNum )
{
final long myStartIndex = taskNum * taskSize;
final long myEndIndex = Math.min( size, myStartIndex + taskSize );
final Callable< Void > r = () -> {
final Consumer< Localizable > action = actionFactory.get();
final long[] position = new long[ dim.length ];
final Localizable localizable = Point.wrap( position );
for ( long index = myStartIndex; index < myEndIndex; ++index )
{
IntervalIndexer.indexToPositionWithOffset( index, dim, min, position );
action.accept( localizable );
return null;
}
return null;
};
callables.add( r );
}
execute( service, callables );
}

private static void execute( final ExecutorService service, final ArrayList< Callable< Void > > callables )
{
try
{
final List< Future< Void > > futures = service.invokeAll( callables );
for ( final Future< Void > future : futures )
future.get();
}
catch ( final InterruptedException | ExecutionException e )
{
final Throwable cause = e.getCause();
if ( cause instanceof RuntimeException )
throw ( RuntimeException ) cause;
throw new RuntimeException( e );
}
);
}
}
Loading

0 comments on commit 7626dbd

Please sign in to comment.