diff --git a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java index 65f7ec51..53845cea 100644 --- a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java +++ b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java @@ -31,6 +31,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Calendar; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -38,8 +39,6 @@ import org.apache.commons.compress.archivers.ArchiveException; -import ai.nets.samj.install.EfficientSamEnvManager; -import ai.nets.samj.models.PythonMethods; import io.bioimage.modelrunner.apposed.appose.Environment; import io.bioimage.modelrunner.apposed.appose.Mamba; import io.bioimage.modelrunner.apposed.appose.MambaInstallException; @@ -58,6 +57,7 @@ import io.bioimage.modelrunner.model.stardist_java_deprecate.StardistConfig; import io.bioimage.modelrunner.runmode.RunMode; import io.bioimage.modelrunner.runmode.ops.GenericOp; +import io.bioimage.modelrunner.system.PlatformDetection; import io.bioimage.modelrunner.tensor.Tensor; import io.bioimage.modelrunner.tensor.Utils; import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; @@ -71,6 +71,7 @@ import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.real.FloatType; import net.imglib2.util.Cast; +import net.imglib2.util.Util; import net.imglib2.view.Views; /** @@ -97,10 +98,6 @@ public class Stardist2D { private final int channels; - private Float nms_threshold; - - private Float prob_threshold; - private Environment env; private Service python; @@ -108,6 +105,19 @@ public class Stardist2D { private static final List STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"}); private static final List STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"}); + + + private static final String COORDS_DTYPE_KEY = "coords_dtype"; + + private static final String COORDS_SHAPE_KEY = "coords_shape"; + + private static final String POINTS_DTYPE_KEY = "points_dtype"; + + private static final String POINTS_SHAPE_KEY = "points_shape"; + + private static final String POINTS_KEY = "points"; + + private static final String COORDS_KEY = "coords"; private static final String LOAD_MODEL_CODE = "" + "if 'StarDist2D' not in globals().keys():" + System.lineSeparator() @@ -116,6 +126,9 @@ public class Stardist2D { + "if 'np' not in globals().keys():" + System.lineSeparator() + " import numpy as np" + System.lineSeparator() + " globals()['np'] = np" + System.lineSeparator() + + "if 'os' not in globals().keys():" + System.lineSeparator() + + " import os" + System.lineSeparator() + + " globals()['os'] = os" + System.lineSeparator() + "if 'shared_memory' not in globals().keys():" + System.lineSeparator() + " from multiprocessing import shared_memory" + System.lineSeparator() + " globals()['shared_memory'] = shared_memory" + System.lineSeparator() @@ -131,7 +144,24 @@ public class Stardist2D { + "task.outputs['coords_dtype'] = output[1]['coords'].dtype" + System.lineSeparator() + "task.outputs['points_shape'] = output[1]['points'].shape" + System.lineSeparator() + "task.outputs['points_dtype'] = output[1]['points'].dtype" + System.lineSeparator() - + ""; + + "coords_shm = " + + "shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['coords'].nbytes)" + System.lineSeparator() + + "shared_coords = np.ndarray(output[1]['coords'].shape, dtype=output[1]['coords'].dtype, buffer=coords_shm.buf)" + System.lineSeparator() + + "points_shm = " + + "shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['points'].nbytes)" + System.lineSeparator() + + "shared_points = np.ndarray(output[1]['points'].shape, dtype=output[1]['points'].dtype, buffer=points_shm.buf)" + System.lineSeparator() + + "globals()['shared_points'] = shared_points" + System.lineSeparator() + + "globals()['shared_coords'] = shared_coords" + System.lineSeparator() + + "globals()['shared_coords'] = shared_coords" + System.lineSeparator() + + "if os.name == 'nt':" + System.lineSeparator() + + " im_shm.close()" + System.lineSeparator() + + " im_shm.unlink()" + System.lineSeparator(); + + private static final String CLOSE_SHM_CODE = "" + + "points_shm.close()" + System.lineSeparator() + + "points_shm.unlink()" + System.lineSeparator() + + "coords_shm.close()" + System.lineSeparator() + + "coords_shm.unlink()" + System.lineSeparator(); private Stardist2D(String modelName, String baseDir) { this.name = modelName; @@ -144,9 +174,7 @@ else if (new File(modelDir, "config.json").isFile() == false) Map stardistMap = (Map) descriptor.getConfig().getSpecMap().get("stardist"); Map stardistConfig = (Map) stardistMap.get("config"); Map stardistThres = (Map) stardistMap.get("thresholds"); - this.channels = (int) stardistConfig.get("n_channel_in");; - this.nms_threshold = new Double((double) stardistThres.get("nms")).floatValue(); - this.prob_threshold = new Double((double) stardistThres.get("prob")).floatValue(); + this.channels = (int) stardistConfig.get("n_channel_in"); } @@ -192,7 +220,8 @@ public void close() { python.close(); } - public & NativeType> void run(RandomAccessibleInterval img) throws IOException, InterruptedException { + public & NativeType> + Map> predict(RandomAccessibleInterval img) throws IOException, InterruptedException { shma = SharedMemoryArray.createSHMAFromRAI(img); String code = ""; @@ -203,7 +232,14 @@ public & NativeType> void run(RandomAccessibleInterval code += createEncodeImageScript() + System.lineSeparator(); code += RUN_MODEL_CODE + System.lineSeparator(); - Task task = python.task(code); + + Map inputs = new HashMap(); + String shm_coords_id = SharedMemoryArray.createShmName(); + String shm_points_id = SharedMemoryArray.createShmName(); + inputs.put("shm_coords_id", shm_coords_id); + inputs.put("shm_points_id", shm_points_id); + + Task task = python.task(code, inputs); task.waitFor(); if (task.status == TaskStatus.CANCELED) throw new RuntimeException("Task canceled"); @@ -211,9 +247,56 @@ else if (task.status == TaskStatus.FAILED) throw new RuntimeException(task.error); else if (task.status == TaskStatus.CRASHED) throw new RuntimeException(task.error); - task.outputs.get(""); + loaded = true; + + + return reconstructOutputs(task, shm_coords_id, shm_points_id); + } + + private & NativeType> + Map> reconstructOutputs(Task task, String shm_coords_id, String shm_points_id) + throws IOException, InterruptedException { + String coords_dtype = (String) task.outputs.get("coords_dtype"); + List coords_shape = (List) task.outputs.get("coords_shape"); + String points_dtype = (String) task.outputs.get("points_dtype"); + List points_shape = (List) task.outputs.get("points_shape"); + long[] coordsSh = new long[coords_shape.size()]; + for (int i = 0; i < coordsSh.length; i ++) + coordsSh[i] = coords_shape.get(i).longValue(); + SharedMemoryArray shmCoords = SharedMemoryArray.readOrCreate(shm_coords_id, coordsSh, + Cast.unchecked(CommonUtils.getImgLib2DataType(coords_dtype)), false, false); + + long[] pointsSh = new long[points_shape.size()]; + for (int i = 0; i < pointsSh.length; i ++) + pointsSh[i] = points_shape.get(i).longValue(); + SharedMemoryArray shmPoints = SharedMemoryArray.readOrCreate(shm_points_id, pointsSh, + Cast.unchecked(CommonUtils.getImgLib2DataType(points_dtype)), false, false); + + Map> outs = new HashMap>(); + // TODO I do not understand why is complaining when the types align perfectly + RandomAccessibleInterval maskCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(shma.getSharedRAI()), + Util.getTypeFromInterval(Cast.unchecked(shma.getSharedRAI()))); + outs.put("mask", maskCopy); + RandomAccessibleInterval pointsRAI = shmPoints.getSharedRAI(); + RandomAccessibleInterval pointsCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(pointsRAI), + Util.getTypeFromInterval(Cast.unchecked(pointsRAI))); + outs.put("points", pointsCopy); + RandomAccessibleInterval coordsRAI = shmCoords.getSharedRAI(); + RandomAccessibleInterval coordsCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(shmCoords), + Util.getTypeFromInterval(Cast.unchecked(shmCoords))); + outs.put("coords", coordsCopy); + + shma.close(); + shmCoords.close(); + shmPoints.close(); + + if (PlatformDetection.isWindows()) { + Task closeSHMTask = python.task(CLOSE_SHM_CODE); + closeSHMTask.waitFor(); + } + return outs; } /** @@ -341,14 +424,6 @@ public static void installRequirements() throws IOException, InterruptedExceptio // TODO add logging for environment installation mamba.create("stardist", true, STARDIST_CHANNELS, STARDIST_DEPS); }; - String envPath = mamba.getEnvsDir() + File.separator + "stardist"; - String scriptPath = envPath + File.separator + STARDIST2D_SCRIPT_NAME; - if (!Paths.get(scriptPath).toFile().isFile()) { - try (InputStream scriptStream = Stardist2D.class.getClassLoader() - .getResourceAsStream(STARDIST2D_PATH_IN_RESOURCES + STARDIST2D_SCRIPT_NAME)){ - Files.copy(scriptStream, Paths.get(scriptPath), StandardCopyOption.REPLACE_EXISTING); - } - } } /**