Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
maarzt committed Nov 30, 2022
1 parent d76e8ee commit 97a461f
Show file tree
Hide file tree
Showing 10 changed files with 424 additions and 995 deletions.
95 changes: 36 additions & 59 deletions src/main/java/net/imglib2/algorithm/gradient/PartialDerivative.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,18 @@

package net.imglib2.algorithm.gradient;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import net.imglib2.Cursor;
import net.imglib2.FinalInterval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.parallel.TaskExecutor;
import net.imglib2.parallel.Parallelization;
import net.imglib2.parallel.TaskExecutors;
import net.imglib2.type.numeric.NumericType;
import net.imglib2.util.Intervals;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

/**
Expand Down Expand Up @@ -98,7 +95,7 @@ public static < T extends NumericType< T > > void gradientCentralDifference2( fi
* @param source
* source image, has to provide valid data in the interval of the
* gradient image plus a one pixel border in dimension.
* @param gradient
* @param result
* output image
* @param dimension
* along which dimension the partial derivatives are computed
Expand All @@ -110,57 +107,42 @@ public static < T extends NumericType< T > > void gradientCentralDifference2( fi
*/
public static < T extends NumericType< T > > void gradientCentralDifferenceParallel(
final RandomAccessible< T > source,
final RandomAccessibleInterval< T > gradient,
final RandomAccessibleInterval< T > result,
final int dimension,
final int nTasks,
final ExecutorService es ) throws InterruptedException, ExecutionException
{
final int nDim = source.numDimensions();
if ( nDim < 2 )
{
gradientCentralDifference( source, gradient, dimension );
return;
}

long dimensionMax = Long.MIN_VALUE;
int dimensionArgMax = -1;

for ( int d = 0; d < nDim; ++d )
{
final long size = gradient.dimension( d );
if ( d != dimension && size > dimensionMax )
{
dimensionMax = size;
dimensionArgMax = d;
}
}

final long stepSize = Math.max( dimensionMax / nTasks, 1 );
final long stepSizeMinusOne = stepSize - 1;
final long min = gradient.min( dimensionArgMax );
final long max = gradient.max( dimensionArgMax );

final ArrayList< Callable< Void > > tasks = new ArrayList<>();
for ( long currentMin = min, minZeroBase = 0; minZeroBase < dimensionMax; currentMin += stepSize, minZeroBase += stepSize )
{
final long currentMax = Math.min( currentMin + stepSizeMinusOne, max );
final long[] mins = new long[ nDim ];
final long[] maxs = new long[ nDim ];
gradient.min( mins );
gradient.max( maxs );
mins[ dimensionArgMax ] = currentMin;
maxs[ dimensionArgMax ] = currentMax;
final IntervalView< T > currentInterval = Views.interval( gradient, new FinalInterval( mins, maxs ) );
tasks.add( () -> {
gradientCentralDifference( source, currentInterval, dimension );
return null;
} );
}
TaskExecutor taskExecutor = TaskExecutors.forExecutorServiceAndNumTasks( es, nTasks );
Parallelization.runWithExecutor( taskExecutor, () -> {
gradientCentralDerivativeParallel( source, result, dimension );
} );
}

final List< Future< Void > > futures = es.invokeAll( tasks );
/**
* Compute the partial derivative (central difference approximation) of source
* in a particular dimension:
* {@code d_f( x ) = ( f( x + e ) - f( x - e ) ) / 2},
* where {@code e} is the unit vector along that dimension.
*
* @param source
* source image, has to provide valid data in the interval of the
* gradient image plus a one pixel border in dimension.
* @param result
* output image
* @param dimension
* along which dimension the partial derivatives are computed
*/
private static <T extends NumericType< T >> void gradientCentralDerivativeParallel( RandomAccessible<T> source,
RandomAccessibleInterval<T> result, int dimension )
{
final RandomAccessibleInterval<T> back = Views.interval( source, Intervals.translate( result, -1, dimension ) );
final RandomAccessibleInterval<T> front = Views.interval( source, Intervals.translate( result, 1, dimension ) );

for ( final Future< Void > f : futures )
f.get();
LoopBuilder.setImages( result, back, front ).multiThreaded().forEachPixel( ( r, b, f ) -> {
r.set( f );
r.sub( b );
r.mul( 0.5 );
} );
}

// fast version
Expand All @@ -181,13 +163,8 @@ public static < T extends NumericType< T > > void gradientCentralDifferenceParal
public static < T extends NumericType< T > > void gradientCentralDifference( final RandomAccessible< T > source,
final RandomAccessibleInterval< T > result, final int dimension )
{
final RandomAccessibleInterval< T > back = Views.interval( source, Intervals.translate( result, -1, dimension ) );
final RandomAccessibleInterval< T > front = Views.interval( source, Intervals.translate( result, 1, dimension ) );

LoopBuilder.setImages( result, back, front ).forEachPixel( ( r, b, f ) -> {
r.set( f );
r.sub( b );
r.mul( 0.5 );
Parallelization.runSingleThreaded( () -> {
gradientCentralDerivativeParallel( source, result, dimension );
} );
}

Expand Down
98 changes: 38 additions & 60 deletions src/main/java/net/imglib2/algorithm/localextrema/LocalExtrema.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,33 +33,35 @@
*/
package net.imglib2.algorithm.localextrema;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

import net.imglib2.Cursor;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.Localizable;
import net.imglib2.Point;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.Sampler;
import net.imglib2.algorithm.neighborhood.Neighborhood;
import net.imglib2.algorithm.neighborhood.RectangleShape;
import net.imglib2.algorithm.neighborhood.Shape;
import net.imglib2.converter.readwrite.WriteConvertedRandomAccessible;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.parallel.Parallelization;
import net.imglib2.parallel.TaskExecutor;
import net.imglib2.parallel.TaskExecutors;
import net.imglib2.util.ConstantUtils;
import net.imglib2.util.Intervals;
import net.imglib2.util.ValuePair;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

/**
* Provides {@link #findLocalExtrema} to find pixels that are extrema in their
* local neighborhood.
Expand Down Expand Up @@ -320,38 +322,8 @@ public static < P, T > List< P > findLocalExtrema(
final int numTasks,
final int splitDim ) throws InterruptedException, ExecutionException
{

final long[] min = Intervals.minAsLongArray( interval );
final long[] max = Intervals.maxAsLongArray( interval );

final long splitDimSize = interval.dimension( splitDim );
final long splitDimMax = max[ splitDim ];
final long splitDimMin = min[ splitDim ];
final long taskSize = Math.max( splitDimSize / numTasks, 1 );

final ArrayList< Callable< List< P > > > tasks = new ArrayList<>();

for ( long start = splitDimMin, stop = splitDimMin + taskSize - 1; start <= splitDimMax; start += taskSize, stop += taskSize )
{
final long s = start;
// need max here instead of dimension for constructor of
// FinalInterval
final long S = Math.min( stop, splitDimMax );
tasks.add( () -> {
final long[] localMin = min.clone();
final long[] localMax = max.clone();
localMin[ splitDim ] = s;
localMax[ splitDim ] = S;
return findLocalExtrema( source, new FinalInterval( localMin, localMax ), localNeighborhoodCheck, shape );
} );
}

final ArrayList< P > extrema = new ArrayList<>();
final List< Future< List< P > > > futures = service.invokeAll( tasks );
for ( final Future< List< P > > f : futures )
extrema.addAll( f.get() );
return extrema;

TaskExecutor taskExecutor = TaskExecutors.forExecutorServiceAndNumTasks( service, numTasks );
return Parallelization.runWithExecutor( taskExecutor, () -> findLocalExtrema( source, interval, localNeighborhoodCheck, shape ) );
}

/**
Expand Down Expand Up @@ -470,22 +442,28 @@ public static < P, T > List< P > findLocalExtrema(
final LocalNeighborhoodCheck< P, T > localNeighborhoodCheck,
final Shape shape )
{
WriteConvertedRandomAccessible< T, RandomAccess< T > > randomAccessible = new WriteConvertedRandomAccessible<>( source, sampler -> (RandomAccess< T >) sampler );
RandomAccessibleInterval< RandomAccess< T > > centers = Views.interval( randomAccessible, interval);
RandomAccessibleInterval< Neighborhood< T > > neighborhoods = Views.interval( shape.neighborhoodsRandomAccessible( source ), interval );
List< List< P > > extremas = LoopBuilder.setImages( centers, neighborhoods ).multiThreaded().forEachChunk( chunk -> {
List< P > extrema = new ArrayList<>();
chunk.forEachPixel( ( center, neighborhood ) -> {
P p = localNeighborhoodCheck.check( center, neighborhood );
if ( p != null )
extrema.add( p );
} );
return extrema;
} );
return concatenate( extremas );
}

final IntervalView< T > sourceInterval = Views.interval( source, interval );

final ArrayList< P > extrema = new ArrayList<>();

final Cursor< T > center = Views.flatIterable( sourceInterval ).cursor();
for ( final Neighborhood< T > neighborhood : shape.neighborhoods( sourceInterval ) )
{
center.fwd();
final P p = localNeighborhoodCheck.check( center, neighborhood );
if ( p != null )
extrema.add( p );
}

return extrema;

private static < P > List<P> concatenate( Collection<List<P>> lists )
{
int size = lists.stream().mapToInt( List::size ).sum();
List< P > result = new ArrayList<>( size );
for ( List< P > list : lists )
result.addAll( list );
return result;
}

/**
Expand Down
Loading

0 comments on commit 97a461f

Please sign in to comment.