Skip to content

Commit

Permalink
keep iterating
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Jan 15, 2025
1 parent 6b0d46f commit 1c92bfa
Showing 1 changed file with 96 additions and 21 deletions.
117 changes: 96 additions & 21 deletions src/main/java/io/bioimage/modelrunner/model/Stardist2D.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,14 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.apache.commons.compress.archivers.ArchiveException;

import ai.nets.samj.install.EfficientSamEnvManager;
import ai.nets.samj.models.PythonMethods;
import io.bioimage.modelrunner.apposed.appose.Environment;
import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
Expand All @@ -58,6 +57,7 @@
import io.bioimage.modelrunner.model.stardist_java_deprecate.StardistConfig;
import io.bioimage.modelrunner.runmode.RunMode;
import io.bioimage.modelrunner.runmode.ops.GenericOp;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
Expand All @@ -71,6 +71,7 @@
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;
import net.imglib2.view.Views;

/**
Expand All @@ -97,17 +98,26 @@ public class Stardist2D {

private final int channels;

private Float nms_threshold;

private Float prob_threshold;

private Environment env;

private Service python;

private static final List<String> STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"});

private static final List<String> STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"});


private static final String COORDS_DTYPE_KEY = "coords_dtype";

private static final String COORDS_SHAPE_KEY = "coords_shape";

private static final String POINTS_DTYPE_KEY = "points_dtype";

private static final String POINTS_SHAPE_KEY = "points_shape";

private static final String POINTS_KEY = "points";

private static final String COORDS_KEY = "coords";

private static final String LOAD_MODEL_CODE = ""
+ "if 'StarDist2D' not in globals().keys():" + System.lineSeparator()
Expand All @@ -116,6 +126,9 @@ public class Stardist2D {
+ "if 'np' not in globals().keys():" + System.lineSeparator()
+ " import numpy as np" + System.lineSeparator()
+ " globals()['np'] = np" + System.lineSeparator()
+ "if 'os' not in globals().keys():" + System.lineSeparator()
+ " import os" + System.lineSeparator()
+ " globals()['os'] = os" + System.lineSeparator()
+ "if 'shared_memory' not in globals().keys():" + System.lineSeparator()
+ " from multiprocessing import shared_memory" + System.lineSeparator()
+ " globals()['shared_memory'] = shared_memory" + System.lineSeparator()
Expand All @@ -131,7 +144,24 @@ public class Stardist2D {
+ "task.outputs['coords_dtype'] = output[1]['coords'].dtype" + System.lineSeparator()
+ "task.outputs['points_shape'] = output[1]['points'].shape" + System.lineSeparator()
+ "task.outputs['points_dtype'] = output[1]['points'].dtype" + System.lineSeparator()
+ "";
+ "coords_shm = "
+ "shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['coords'].nbytes)" + System.lineSeparator()
+ "shared_coords = np.ndarray(output[1]['coords'].shape, dtype=output[1]['coords'].dtype, buffer=coords_shm.buf)" + System.lineSeparator()
+ "points_shm = "
+ "shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['points'].nbytes)" + System.lineSeparator()
+ "shared_points = np.ndarray(output[1]['points'].shape, dtype=output[1]['points'].dtype, buffer=points_shm.buf)" + System.lineSeparator()
+ "globals()['shared_points'] = shared_points" + System.lineSeparator()
+ "globals()['shared_coords'] = shared_coords" + System.lineSeparator()
+ "globals()['shared_coords'] = shared_coords" + System.lineSeparator()
+ "if os.name == 'nt':" + System.lineSeparator()
+ " im_shm.close()" + System.lineSeparator()
+ " im_shm.unlink()" + System.lineSeparator();

private static final String CLOSE_SHM_CODE = ""
+ "points_shm.close()" + System.lineSeparator()
+ "points_shm.unlink()" + System.lineSeparator()
+ "coords_shm.close()" + System.lineSeparator()
+ "coords_shm.unlink()" + System.lineSeparator();

private Stardist2D(String modelName, String baseDir) {
this.name = modelName;
Expand All @@ -144,9 +174,7 @@ else if (new File(modelDir, "config.json").isFile() == false)
Map<String, Object> stardistMap = (Map<String, Object>) descriptor.getConfig().getSpecMap().get("stardist");
Map<String, Object> stardistConfig = (Map<String, Object>) stardistMap.get("config");
Map<String, Object> stardistThres = (Map<String, Object>) stardistMap.get("thresholds");
this.channels = (int) stardistConfig.get("n_channel_in");;
this.nms_threshold = new Double((double) stardistThres.get("nms")).floatValue();
this.prob_threshold = new Double((double) stardistThres.get("prob")).floatValue();
this.channels = (int) stardistConfig.get("n_channel_in");

}

Expand Down Expand Up @@ -192,7 +220,8 @@ public void close() {
python.close();
}

public <T extends RealType<T> & NativeType<T>> void run(RandomAccessibleInterval<T> img) throws IOException, InterruptedException {
public <T extends RealType<T> & NativeType<T>>
Map<String, RandomAccessibleInterval<T>> predict(RandomAccessibleInterval<T> img) throws IOException, InterruptedException {

shma = SharedMemoryArray.createSHMAFromRAI(img);
String code = "";
Expand All @@ -203,17 +232,71 @@ public <T extends RealType<T> & NativeType<T>> void run(RandomAccessibleInterval
code += createEncodeImageScript() + System.lineSeparator();
code += RUN_MODEL_CODE + System.lineSeparator();

Task task = python.task(code);

Map<String, Object> inputs = new HashMap<String, Object>();
String shm_coords_id = SharedMemoryArray.createShmName();
String shm_points_id = SharedMemoryArray.createShmName();
inputs.put("shm_coords_id", shm_coords_id);
inputs.put("shm_points_id", shm_points_id);

Task task = python.task(code, inputs);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
throw new RuntimeException("Task canceled");
else if (task.status == TaskStatus.FAILED)
throw new RuntimeException(task.error);
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException(task.error);
task.outputs.get("");
loaded = true;


return reconstructOutputs(task, shm_coords_id, shm_points_id);
}

private <T extends RealType<T> & NativeType<T>>
Map<String, RandomAccessibleInterval<T>> reconstructOutputs(Task task, String shm_coords_id, String shm_points_id)
throws IOException, InterruptedException {

String coords_dtype = (String) task.outputs.get("coords_dtype");
List<Number> coords_shape = (List<Number>) task.outputs.get("coords_shape");
String points_dtype = (String) task.outputs.get("points_dtype");
List<Number> points_shape = (List<Number>) task.outputs.get("points_shape");

long[] coordsSh = new long[coords_shape.size()];
for (int i = 0; i < coordsSh.length; i ++)
coordsSh[i] = coords_shape.get(i).longValue();
SharedMemoryArray shmCoords = SharedMemoryArray.readOrCreate(shm_coords_id, coordsSh,
Cast.unchecked(CommonUtils.getImgLib2DataType(coords_dtype)), false, false);

long[] pointsSh = new long[points_shape.size()];
for (int i = 0; i < pointsSh.length; i ++)
pointsSh[i] = points_shape.get(i).longValue();
SharedMemoryArray shmPoints = SharedMemoryArray.readOrCreate(shm_points_id, pointsSh,
Cast.unchecked(CommonUtils.getImgLib2DataType(points_dtype)), false, false);

Map<String, RandomAccessibleInterval<T>> outs = new HashMap<String, RandomAccessibleInterval<T>>();
// TODO I do not understand why is complaining when the types align perfectly
RandomAccessibleInterval<T> maskCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(shma.getSharedRAI()),
Util.getTypeFromInterval(Cast.unchecked(shma.getSharedRAI())));
outs.put("mask", maskCopy);
RandomAccessibleInterval<T> pointsRAI = shmPoints.getSharedRAI();
RandomAccessibleInterval<T> pointsCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(pointsRAI),
Util.getTypeFromInterval(Cast.unchecked(pointsRAI)));
outs.put("points", pointsCopy);
RandomAccessibleInterval<T> coordsRAI = shmCoords.getSharedRAI();
RandomAccessibleInterval<T> coordsCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(shmCoords),
Util.getTypeFromInterval(Cast.unchecked(shmCoords)));
outs.put("coords", coordsCopy);

shma.close();
shmCoords.close();
shmPoints.close();

if (PlatformDetection.isWindows()) {
Task closeSHMTask = python.task(CLOSE_SHM_CODE);
closeSHMTask.waitFor();
}
return outs;
}

/**
Expand Down Expand Up @@ -341,14 +424,6 @@ public static void installRequirements() throws IOException, InterruptedExceptio
// TODO add logging for environment installation
mamba.create("stardist", true, STARDIST_CHANNELS, STARDIST_DEPS);
};
String envPath = mamba.getEnvsDir() + File.separator + "stardist";
String scriptPath = envPath + File.separator + STARDIST2D_SCRIPT_NAME;
if (!Paths.get(scriptPath).toFile().isFile()) {
try (InputStream scriptStream = Stardist2D.class.getClassLoader()
.getResourceAsStream(STARDIST2D_PATH_IN_RESOURCES + STARDIST2D_SCRIPT_NAME)){
Files.copy(scriptStream, Paths.get(scriptPath), StandardCopyOption.REPLACE_EXISTING);
}
}
}

/**
Expand Down

0 comments on commit 1c92bfa

Please sign in to comment.