Skip to content

Commit

Permalink
WIP: Parallelize computation
Browse files Browse the repository at this point in the history
  • Loading branch information
gselzer committed Apr 1, 2024
1 parent 62c551c commit b8d1226
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 35 deletions.
6 changes: 6 additions & 0 deletions scijava-ops-flim/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@
</dependency>

<!-- SciJava dependencies -->
<dependency>
<groupId>org.scijava</groupId>
<artifactId>scijava-concurrent</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.scijava</groupId>
<artifactId>scijava-function</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@
package org.scijava.ops.flim.impl;

import net.imglib2.type.numeric.RealType;
import org.scijava.concurrent.Parallelization;
import org.scijava.ops.flim.FitParams;
import org.scijava.ops.flim.FitResults;
import org.scijava.ops.flim.util.RAHelper;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public abstract class AbstractSingleFitWorker<I extends RealType<I>> extends
AbstractFitWorker<I>
Expand Down Expand Up @@ -117,48 +122,77 @@ protected void onThreadInit() {}
@Override
public void fitBatch(List<int[]> pos, FitEventHandler<I> handler) {
final AbstractSingleFitWorker<I> thisWorker = this;
// TODO: Re-implement parallel behavior

// thread-local reusable read/write buffers
final FitParams<I> lParams;
final FitResults lResults;
final AbstractSingleFitWorker<I> fitWorker;
// don't make copy in single thread mode
if (!params.multithread || pos.size() == 1) {
lParams = params;
lResults = results;
fitWorker = thisWorker;
}
else {
lParams = params.copy();
lResults = results.copy();
// grab your own buffer
lParams.param = lParams.trans = lResults.param = lResults.fitted =
lResults.residuals = null;
fitWorker = duplicate(lParams, lResults);
}
fitWorker.onThreadInit();

final RAHelper<I> helper = new RAHelper<>(params, results);

for (int[] xytPos : pos) {
if (!helper.loadData(fitWorker.transBuffer, fitWorker.paramBuffer, params,
xytPos)) lResults.retCode = FitResults.RET_INTENSITY_BELOW_THRESH;
else {
fitWorker.fitSingle();
Consumer<int[]> worker = (data) -> {
int start = data[0];
int size = data[1];
if (!params.multithread) {
// let the first fitting thread do all the work
if (start != 0) {
return;
}
size = pos.size();
}

// invalidate fit if chisq is insane
final float chisq = lResults.chisq;
if (params.dropBad && lResults.retCode == FitResults.RET_OK &&
(chisq < 0 || chisq > 1E5 || Float.isNaN(chisq))) lResults.retCode =
FitResults.RET_BAD_FIT_CHISQ_OUT_OF_RANGE;
// thread-local reusable read/write buffers
final FitParams<I> lParams;
final FitResults lResults;
final AbstractSingleFitWorker<I> fitWorker;
// don't make copy in single thread mode
if (!params.multithread || pos.size() == 1) {
lParams = params;
lResults = results;
fitWorker = thisWorker;
} else {
lParams = params.copy();
lResults = results.copy();
// grab your own buffer
lParams.param = lParams.trans =
lResults.param = lResults.fitted = lResults.residuals = null;
fitWorker = duplicate(lParams, lResults);
}
fitWorker.onThreadInit();

final RAHelper<I> helper = new RAHelper<>(params, results);

helper.commitRslts(lParams, lResults, xytPos);
for (int i = start; i < start + size; i++) {
final int[] xytPos = pos.get(i);

if (handler != null) handler.onSingleComplete(xytPos, params, results);
if (!helper.loadData(fitWorker.transBuffer, fitWorker.paramBuffer, params, xytPos))
lResults.retCode = FitResults.RET_INTENSITY_BELOW_THRESH;
else {
fitWorker.fitSingle();

// invalidate fit if chisq is insane
final float chisq = lResults.chisq;
if (params.dropBad && lResults.retCode == FitResults.RET_OK
&& (chisq < 0 || chisq > 1E5 || Float.isNaN(chisq)))
lResults.retCode = FitResults.RET_BAD_FIT_CHISQ_OUT_OF_RANGE;
}

helper.commitRslts(lParams, lResults, xytPos);

if (handler != null)
handler.onSingleComplete(xytPos, params, results);
}
};

int n = Parallelization.getTaskExecutor().suggestNumberOfTasks();
int s = pos.size() / n;
int r = pos.size() % n;

List<int[]> list = new ArrayList<>(n);
int start = 0;
int size = s + 1;
for(int i = 0; i < n; i++) {
list.add(new int[] {start, size});
start += size;
if (i == r - 1) {
size--;
}
}

Parallelization.getTaskExecutor().forEach(list, worker);
if (handler != null) handler.onComplete(params, results);
}
}

0 comments on commit b8d1226

Please sign in to comment.