Skip to content

Commit

Permalink
working prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Jan 16, 2025
1 parent 1c92bfa commit 042f67a
Showing 1 changed file with 124 additions and 93 deletions.
217 changes: 124 additions & 93 deletions src/main/java/io/bioimage/modelrunner/model/Stardist2D.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,11 @@
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.apache.commons.compress.archivers.ArchiveException;

Expand All @@ -53,13 +44,8 @@
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.processing.Processing;
import io.bioimage.modelrunner.model.stardist_java_deprecate.StardistConfig;
import io.bioimage.modelrunner.runmode.RunMode;
import io.bioimage.modelrunner.runmode.ops.GenericOp;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;
import io.bioimage.modelrunner.utils.Constants;
Expand All @@ -72,34 +58,30 @@
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;
import net.imglib2.view.Views;

/**
* Implementation of an API to run Stardist 2D models out of the box with little configuration.
*
*TODO add fine tuning
*TODO add support for Mac arm
*
*@author Carlos Garcia
*/
public class Stardist2D {

private String modelDir;
private final String modelDir;

private final String name;

private final String basedir;

private final int nChannels;

private boolean loaded = false;

private SharedMemoryArray shma;

private ModelDescriptor descriptor;

private final int channels;

private Environment env;


private Service python;

private static final List<String> STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"});
Expand Down Expand Up @@ -136,67 +118,90 @@ public class Stardist2D {
+ "globals()['model'] = model";

private static final String RUN_MODEL_CODE = ""
+ "shm_coords_id = task.inputs['shm_coords_id']" + System.lineSeparator()
+ "shm_points_id = task.inputs['shm_points_id']" + System.lineSeparator()
+ "output = model.predict_instances(im, returnPredcit=False)" + System.lineSeparator()
+ "output = model.predict_instances(im, return_predict=False)" + System.lineSeparator()
+ "im[:] = output[0]" + System.lineSeparator()
+ "task.outputs['coords_shape'] = output[1]['coords'].shape" + System.lineSeparator()
+ "task.outputs['coords_dtype'] = output[1]['coords'].dtype" + System.lineSeparator()
+ "task.outputs['points_shape'] = output[1]['points'].shape" + System.lineSeparator()
+ "task.outputs['points_dtype'] = output[1]['points'].dtype" + System.lineSeparator()
+ "coords_shm = "
+ "shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['coords'].nbytes)" + System.lineSeparator()
+ "shared_coords = np.ndarray(output[1]['coords'].shape, dtype=output[1]['coords'].dtype, buffer=coords_shm.buf)" + System.lineSeparator()
+ "points_shm = "
+ "shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['points'].nbytes)" + System.lineSeparator()
+ "shared_points = np.ndarray(output[1]['points'].shape, dtype=output[1]['points'].dtype, buffer=points_shm.buf)" + System.lineSeparator()
+ "globals()['shared_points'] = shared_points" + System.lineSeparator()
+ "globals()['shared_coords'] = shared_coords" + System.lineSeparator()
+ "globals()['shared_coords'] = shared_coords" + System.lineSeparator()
+ "if output[1]['points'].nbytes == 0:" + System.lineSeparator()
+ " task.outputs['points_shape'] = None" + System.lineSeparator()
+ "else:" + System.lineSeparator()
+ " task.outputs['points_shape'] = output[1]['points'].shape" + System.lineSeparator()
+ " task.outputs['points_dtype'] = output[1]['points'].dtype" + System.lineSeparator()
+ " points_shm = "
+ " shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['points'].nbytes)" + System.lineSeparator()
+ " shared_points = np.ndarray(output[1]['points'].shape, dtype=output[1]['points'].dtype, buffer=points_shm.buf)" + System.lineSeparator()
+ " globals()['shared_points'] = shared_points" + System.lineSeparator()
+ "if output[1]['coord'].nbytes == 0:" + System.lineSeparator()
+ " task.outputs['coords_shape'] = None" + System.lineSeparator()
+ "else:" + System.lineSeparator()
+ " task.outputs['coords_shape'] = output[1]['coord'].shape" + System.lineSeparator()
+ " task.outputs['coords_dtype'] = output[1]['coord'].dtype" + System.lineSeparator()
+ " coords_shm = "
+ " shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['coord'].nbytes)" + System.lineSeparator()
+ " shared_coords = np.ndarray(output[1]['coord'].shape, dtype=output[1]['coord'].dtype, buffer=coords_shm.buf)" + System.lineSeparator()
+ " globals()['shared_coords'] = shared_coords" + System.lineSeparator()
+ "if os.name == 'nt':" + System.lineSeparator()
+ " im_shm.close()" + System.lineSeparator()
+ " im_shm.unlink()" + System.lineSeparator();

private static final String CLOSE_SHM_CODE = ""
+ "points_shm.close()" + System.lineSeparator()
+ "points_shm.unlink()" + System.lineSeparator()
+ "coords_shm.close()" + System.lineSeparator()
+ "coords_shm.unlink()" + System.lineSeparator();

private Stardist2D(String modelName, String baseDir) {
+ "if 'points_shm' in globals().keys():" + System.lineSeparator()
+ " points_shm.close()" + System.lineSeparator()
+ " points_shm.unlink()" + System.lineSeparator()
+ "if 'coords_shm' in globals().keys():" + System.lineSeparator()
+ " coords_shm.close()" + System.lineSeparator()
+ " coords_shm.unlink()" + System.lineSeparator();

private Stardist2D(String modelName, String baseDir) throws IOException, ModelSpecsException {
this.name = modelName;
this.basedir = baseDir;
modelDir = new File(baseDir, modelName).getAbsolutePath();
if (new File(modelDir, "config.json").isFile() == false && new File(modelDir, Constants.RDF_FNAME).isFile() == false)
throw new IllegalArgumentException("No 'config.json' file found in the model directory");
else if (new File(modelDir, "config.json").isFile() == false)
createConfigFromBioimageio();
Map<String, Object> stardistMap = (Map<String, Object>) descriptor.getConfig().getSpecMap().get("stardist");
Map<String, Object> stardistConfig = (Map<String, Object>) stardistMap.get("config");
Map<String, Object> stardistThres = (Map<String, Object>) stardistMap.get("thresholds");
this.channels = (int) stardistConfig.get("n_channel_in");

if (new File(modelDir, "thresholds.json").isFile() == false && new File(modelDir, Constants.RDF_FNAME).isFile() == false)
throw new IllegalArgumentException("No 'thresholds.json' file found in the model directory");
else if (new File(modelDir, "thresholds.json").isFile() == false)
createThresholdsFromBioimageio();
this.nChannels = ((Number) JSONUtils.load(new File(modelDir, "config.json").getAbsolutePath()).get("n_channel_in")).intValue();
createPythonService();
}

private void createConfigFromBioimageio() {

private Stardist2D(ModelDescriptor descriptor) throws IOException, ModelSpecsException {
this.descriptor = descriptor;
this.name = new File(descriptor.getModelPath()).getName();
this.basedir = new File(descriptor.getModelPath()).getParentFile().getAbsolutePath();
modelDir = descriptor.getModelPath();
if (new File(modelDir, "config.json").isFile() == false)
createConfigFromBioimageio();
if (new File(modelDir, "thresholds.json").isFile() == false)
createThresholdsFromBioimageio();
this.nChannels = ((Number) JSONUtils.load(new File(modelDir, "config.json").getAbsolutePath()).get("n_channel_in")).intValue();
createPythonService();
}

private void loadModel() throws IOException, InterruptedException {
if (loaded)
return;
String code = String.format(LOAD_MODEL_CODE, this.name, this.basedir);
Task task = python.task(code);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
throw new RuntimeException("Task canceled");
else if (task.status == TaskStatus.FAILED)
throw new RuntimeException(task.error);
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException(task.error);
loaded = true;
}
private void createConfigFromBioimageio() throws IOException, ModelSpecsException {
if (descriptor == null)
descriptor = ModelDescriptorFactory.readFromLocalFile(modelDir + File.separator + Constants.RDF_FNAME);
Map<String, Object> stardistMap = (Map<String, Object>) descriptor.getConfig().getSpecMap().get("stardist");
Map<String, Object> stardistConfig = (Map<String, Object>) stardistMap.get("config");
JSONUtils.writeJSONFile(new File(modelDir, "config.json").getAbsolutePath(), stardistConfig);
}

private void createThresholdsFromBioimageio() throws IOException, ModelSpecsException {
if (descriptor == null)
descriptor = ModelDescriptorFactory.readFromLocalFile(modelDir + File.separator + Constants.RDF_FNAME);
Map<String, Object> stardistMap = (Map<String, Object>) descriptor.getConfig().getSpecMap().get("stardist");
Map<String, Object> stardistThres = (Map<String, Object>) stardistMap.get("thresholds");
JSONUtils.writeJSONFile(new File(modelDir, "thresholds.json").getAbsolutePath(), stardistThres);
}

private void createPythonService() throws IOException {
Environment env = new Environment() {
@Override public String base() { return new Mamba().getEnvsDir() + File.separator + "stardist"; }
};
python = env.python();
python.debug(System.err::println);
}

protected String createEncodeImageScript() {
String code = "";
Expand All @@ -206,7 +211,9 @@ protected String createEncodeImageScript() {
code += "im_shm = shared_memory.SharedMemory(name='"
+ shma.getNameForPython() + "', size=" + shma.getSize()
+ ")" + System.lineSeparator();
code += "im = np.ndarray(" + shma.getSize() + ", dtype='" + CommonUtils.getDataTypeFromRAI(Cast.unchecked(shma.getSharedRAI()))
long nElems = 1;
for (long elem : shma.getOriginalShape()) nElems *= elem;
code += "im = np.ndarray(" + nElems + ", dtype='" + CommonUtils.getDataTypeFromRAI(Cast.unchecked(shma.getSharedRAI()))
+ "', buffer=im_shm.buf).reshape([";
for (int i = 0; i < shma.getOriginalShape().length; i ++)
code += shma.getOriginalShape()[i] + ", ";
Expand All @@ -223,7 +230,7 @@ public void close() {
public <T extends RealType<T> & NativeType<T>>
Map<String, RandomAccessibleInterval<T>> predict(RandomAccessibleInterval<T> img) throws IOException, InterruptedException {

shma = SharedMemoryArray.createSHMAFromRAI(img);
shma = SharedMemoryArray.createSHMAFromRAI(img, false, false);
String code = "";
if (!loaded) {
code += String.format(LOAD_MODEL_CODE, this.name, this.basedir) + System.lineSeparator();
Expand Down Expand Up @@ -257,46 +264,70 @@ else if (task.status == TaskStatus.CRASHED)
Map<String, RandomAccessibleInterval<T>> reconstructOutputs(Task task, String shm_coords_id, String shm_points_id)
throws IOException, InterruptedException {

Map<String, RandomAccessibleInterval<T>> outs = new HashMap<String, RandomAccessibleInterval<T>>();
// TODO I do not understand why is complaining when the types align perfectly
RandomAccessibleInterval<T> maskCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(shma.getSharedRAI()),
Util.getTypeFromInterval(Cast.unchecked(shma.getSharedRAI())));
outs.put("mask", maskCopy);
outs.put("points", reconstructPoints(task, shm_points_id));
outs.put("coord", reconstructCoord(task, shm_coords_id));

shma.close();

if (PlatformDetection.isWindows()) {
Task closeSHMTask = python.task(CLOSE_SHM_CODE);
closeSHMTask.waitFor();
}
return outs;
}

private <T extends RealType<T> & NativeType<T>>
RandomAccessibleInterval<T> reconstructCoord(Task task, String shm_coords_id) throws IOException {

String coords_dtype = (String) task.outputs.get("coords_dtype");
List<Number> coords_shape = (List<Number>) task.outputs.get("coords_shape");
String points_dtype = (String) task.outputs.get("points_dtype");
List<Number> points_shape = (List<Number>) task.outputs.get("points_shape");
if (coords_shape == null)
return null;

long[] coordsSh = new long[coords_shape.size()];
for (int i = 0; i < coordsSh.length; i ++)
coordsSh[i] = coords_shape.get(i).longValue();
SharedMemoryArray shmCoords = SharedMemoryArray.readOrCreate(shm_coords_id, coordsSh,
Cast.unchecked(CommonUtils.getImgLib2DataType(coords_dtype)), false, false);

Map<String, RandomAccessibleInterval<T>> outs = new HashMap<String, RandomAccessibleInterval<T>>();
// TODO I do not understand why is complaining when the types align perfectly
RandomAccessibleInterval<T> coordsRAI = shmCoords.getSharedRAI();
RandomAccessibleInterval<T> coordsCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(coordsRAI),
Util.getTypeFromInterval(Cast.unchecked(shmCoords)));
outs.put("coords", coordsCopy);

shmCoords.close();

return coordsCopy;
}

private <T extends RealType<T> & NativeType<T>>
RandomAccessibleInterval<T> reconstructPoints(Task task, String shm_points_id) throws IOException {

String points_dtype = (String) task.outputs.get("points_dtype");
List<Number> points_shape = (List<Number>) task.outputs.get("points_shape");
if (points_shape == null)
return null;


long[] pointsSh = new long[points_shape.size()];
for (int i = 0; i < pointsSh.length; i ++)
pointsSh[i] = points_shape.get(i).longValue();
SharedMemoryArray shmPoints = SharedMemoryArray.readOrCreate(shm_points_id, pointsSh,
Cast.unchecked(CommonUtils.getImgLib2DataType(points_dtype)), false, false);

Map<String, RandomAccessibleInterval<T>> outs = new HashMap<String, RandomAccessibleInterval<T>>();
// TODO I do not understand why is complaining when the types align perfectly
RandomAccessibleInterval<T> maskCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(shma.getSharedRAI()),
Util.getTypeFromInterval(Cast.unchecked(shma.getSharedRAI())));
outs.put("mask", maskCopy);
RandomAccessibleInterval<T> pointsRAI = shmPoints.getSharedRAI();
RandomAccessibleInterval<T> pointsCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(pointsRAI),
Util.getTypeFromInterval(Cast.unchecked(pointsRAI)));
outs.put("points", pointsCopy);
RandomAccessibleInterval<T> coordsRAI = shmCoords.getSharedRAI();
RandomAccessibleInterval<T> coordsCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(shmCoords),
Util.getTypeFromInterval(Cast.unchecked(shmCoords)));
outs.put("coords", coordsCopy);

shma.close();
shmCoords.close();
shmPoints.close();

if (PlatformDetection.isWindows()) {
Task closeSHMTask = python.task(CLOSE_SHM_CODE);
closeSHMTask.waitFor();
}
return outs;
return pointsCopy;
}

/**
Expand All @@ -310,7 +341,7 @@ Map<String, RandomAccessibleInterval<T>> reconstructOutputs(Task task, String sh
*/
public static Stardist2D fromBioimageioModel(String modelPath) throws ModelSpecsException, FileNotFoundException, IOException {
ModelDescriptor descriptor = ModelDescriptorFactory.readFromLocalFile(modelPath + File.separator + Constants.RDF_FNAME);
return new Stardist2D(modelPath);
return new Stardist2D(descriptor);
}

/**
Expand Down Expand Up @@ -374,12 +405,12 @@ public static Stardist2D fromPretained(String pretrainedModel, String installDir
}

private <T extends RealType<T> & NativeType<T>> void checkInput(RandomAccessibleInterval<T> image) {
if (image.dimensionsAsLongArray().length == 2 && this.channels != 1)
if (image.dimensionsAsLongArray().length == 2 && this.nChannels != 1)
throw new IllegalArgumentException("Stardist2D needs an image with three dimensions: XYC");
else if (image.dimensionsAsLongArray().length != 3 && this.channels != 1)
else if (image.dimensionsAsLongArray().length != 3 && this.nChannels != 1)
throw new IllegalArgumentException("Stardist2D needs an image with three dimensions: XYC");
else if (image.dimensionsAsLongArray().length != 2 && image.dimensionsAsLongArray()[2] != channels)
throw new IllegalArgumentException("This Stardist2D model requires " + channels + " channels.");
else if (image.dimensionsAsLongArray().length != 2 && image.dimensionsAsLongArray()[2] != nChannels)
throw new IllegalArgumentException("This Stardist2D model requires " + nChannels + " channels.");
else if (image.dimensionsAsLongArray().length > 3 || image.dimensionsAsLongArray().length < 2)
throw new IllegalArgumentException("Stardist2D model requires an image with dimensions XYC.");
}
Expand Down Expand Up @@ -451,7 +482,7 @@ public static void main(String[] args) throws IOException, InterruptedException,

RandomAccessibleInterval<FloatType> img = ArrayImgs.floats(new long[] {512, 512});

RandomAccessibleInterval<FloatType> res = model.predict(img);
Map<String, RandomAccessibleInterval<FloatType>> res = model.predict(img);
System.out.println(true);
}
}

0 comments on commit 042f67a

Please sign in to comment.