From 6acc18e2c9b94373fe24819012c9abb7e23def58 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Wed, 15 Jan 2025 17:41:10 +0100 Subject: [PATCH] give up on the Java wrapper --- .../modelrunner/model/Stardist2D.java | 2 +- .../model/stardist/Normalizer.java | 5 - .../model/stardist/StardistConfig.java | 38 ----- .../modelrunner/model/stardist/Utils.java | 21 --- .../AbstractStardist.java | 86 +++++++---- .../stardist_java_deprecate/Normalizer.java | 5 + .../Resizer.java | 2 +- .../StardistConfig.java | 139 ++++++++++++++++++ .../model/stardist_java_deprecate/Utils.java | 78 ++++++++++ .../io/bioimage/modelrunner/tensor/Utils.java | 3 - 10 files changed, 281 insertions(+), 98 deletions(-) delete mode 100644 src/main/java/io/bioimage/modelrunner/model/stardist/Normalizer.java delete mode 100644 src/main/java/io/bioimage/modelrunner/model/stardist/StardistConfig.java delete mode 100644 src/main/java/io/bioimage/modelrunner/model/stardist/Utils.java rename src/main/java/io/bioimage/modelrunner/model/{stardist => stardist_java_deprecate}/AbstractStardist.java (77%) create mode 100644 src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/Normalizer.java rename src/main/java/io/bioimage/modelrunner/model/{stardist => stardist_java_deprecate}/Resizer.java (97%) create mode 100644 src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/StardistConfig.java create mode 100644 src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/Utils.java diff --git a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java index 246ae3ac..c6009713 100644 --- a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java +++ b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java @@ -49,7 +49,7 @@ import io.bioimage.modelrunner.exceptions.LoadModelException; import io.bioimage.modelrunner.exceptions.RunModelException; import io.bioimage.modelrunner.model.processing.Processing; -import io.bioimage.modelrunner.model.stardist.StardistConfig; +import io.bioimage.modelrunner.model.stardist_java_deprecate.StardistConfig; import io.bioimage.modelrunner.runmode.RunMode; import io.bioimage.modelrunner.runmode.ops.GenericOp; import io.bioimage.modelrunner.tensor.Tensor; diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/Normalizer.java b/src/main/java/io/bioimage/modelrunner/model/stardist/Normalizer.java deleted file mode 100644 index 1988cd04..00000000 --- a/src/main/java/io/bioimage/modelrunner/model/stardist/Normalizer.java +++ /dev/null @@ -1,5 +0,0 @@ -package io.bioimage.modelrunner.model.stardist; - -public class Normalizer { - -} diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/StardistConfig.java b/src/main/java/io/bioimage/modelrunner/model/stardist/StardistConfig.java deleted file mode 100644 index ce5241fb..00000000 --- a/src/main/java/io/bioimage/modelrunner/model/stardist/StardistConfig.java +++ /dev/null @@ -1,38 +0,0 @@ -/*- - * #%L - * Use deep learning frameworks from Java in an agnostic and isolated way. - * %% - * Copyright (C) 2022 - 2024 Institut Pasteur and BioImage.IO developers. - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ -package io.bioimage.modelrunner.model.stardist; - - -/** - * Implementation of the Stardist Config instance in Java. - * - *@author Carlos Garcia - */ -public class StardistConfig { - - public int[] unet_pool; - public String axes; - public double unet_n_depth; - public int[] grid; - public int nChannelIn; - public String backbone; - - -} diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/Utils.java b/src/main/java/io/bioimage/modelrunner/model/stardist/Utils.java deleted file mode 100644 index 0b10a384..00000000 --- a/src/main/java/io/bioimage/modelrunner/model/stardist/Utils.java +++ /dev/null @@ -1,21 +0,0 @@ -package io.bioimage.modelrunner.model.stardist; - -public class Utils { - - protected static String[] axesCheckAndNormalize(String axes, Integer length, String disallowed) { - String allowed = "STCZYX"; - if (axes == null) - throw new IllegalArgumentException("Axes cannot be null"); - axes = axes.toUpperCase(); - for (String a : axes.split("")) { - if (!allowed.contains(a)) - throw new IllegalArgumentException("Invalid axis: " + a + ", it must be one of " + allowed); - if (axes.replace(a, "").length() + 1 != axes.length()) - throw new IllegalArgumentException("Invalid axis: " + a + " can only appear once."); - } - if (length != null && axes.length() != length) - throw new IllegalArgumentException("Axes (" + axes + ") must have length " + length); - return new String[] {axes, allowed}; - } - -} diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/AbstractStardist.java b/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/AbstractStardist.java similarity index 77% rename from src/main/java/io/bioimage/modelrunner/model/stardist/AbstractStardist.java rename to src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/AbstractStardist.java index 700652f9..a2a3bbb4 100644 --- a/src/main/java/io/bioimage/modelrunner/model/stardist/AbstractStardist.java +++ b/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/AbstractStardist.java @@ -1,4 +1,4 @@ -package io.bioimage.modelrunner.model.stardist; +package io.bioimage.modelrunner.model.stardist_java_deprecate; import java.util.ArrayList; import java.util.HashMap; @@ -9,7 +9,7 @@ import net.imglib2.img.array.ArrayImgs; public class AbstractStardist { - + /* private StardistConfig config; private float probThres; @@ -39,8 +39,14 @@ private void predict_instances(RandomAccessibleInterval img, Map String _axes = normalizeAxes(img, axes); String axesNet = config.axes; - String permuteAxes = permuteAxes(_axes, axesNet); - int[] shapeInst; + long[] shape = permuteAxes(null, null, _axes, axesNet, null, true).dimensionsAsLongArray(); + int[] shapeInst = new int[_axes.toUpperCase().replace("C", "").length()]; + int i = 0; + for (long s : shape) { + if (_axes.toUpperCase().split("")[i].equals("C")) + continue; + shapeInst[i ++] = (int) s; + } Number scale = null; if (kwargs.get("scale") != null && kwargs.get("scale") instanceof Number) @@ -78,7 +84,7 @@ private void predictSparseGenerator(RandomAccessibleInterval img, String axes, N Map returns = predictSetup(img, axes, normalizer, nTiles); RandomAccessibleInterval x = (RandomAccessibleInterval) returns.get("x"); - List nTiles = (List) returns.get("nTiles"); + nTiles = (List) returns.get("nTiles"); axes = (String) returns.get("axes"); String axesNet = (String) returns.get("axesNet"); int[] axesNetDivBy = (int[]) returns.get("axesNetDivBy"); @@ -90,7 +96,9 @@ private void predictSparseGenerator(RandomAccessibleInterval img, String axes, N if (product > 1) { // TODO } else { - + // TODO + RandomAccessibleInterval prob = null; //(256, 256, 1) + RandomAccessibleInterval dist = null; //(256, 256, 32) } } @@ -105,13 +113,12 @@ private Map predictSetup(RandomAccessibleInterval img, String ax + ") should be the same as the tile list lenght (" + nTiles.size() + ")."); axes = normalizeAxes(img, axes); String axesNet = this.config.axes; - // TODO permuteAxes - RandomAccessibleInterval x = null; // TODO - channel = axesDict(axesNet).get("C"); - if (this.config.nChannelIn != x.dimensionsAsLongArray()[channel]) + RandomAccessibleInterval x = permuteAxes(img, axes, axesNet, null, null, true); + int channel = Utils.axesDict(axesNet).get("C"); + if (this.config.n_channel_in != x.dimensionsAsLongArray()[channel]) throw new IllegalArgumentException("The number of channels of the image (" + x.dimensionsAsLongArray()[channel] + ") should be the same as the model config (" - + config.nChannelIn + ")."); + + config.n_channel_in + ")."); int[] axesNetDivBy = axesDivBy(axesNet); int[] grid = config.grid; if (grid.length != axesNet.length() - 1) @@ -131,7 +138,6 @@ private Map predictSetup(RandomAccessibleInterval img, String ax returns.put("axes", axes); returns.put("axesNet", axesNet); returns.put("axesNetDivBy", axesNetDivBy); - returns.put("permuteAxes", permuteAxes); returns.put("resizer", resizer); returns.put("nTiles", nTiles); returns.put("grid", grid); @@ -140,16 +146,15 @@ private Map predictSetup(RandomAccessibleInterval img, String ax return returns; } - private int channel; private Integer[] sh; private void tilingSetup(RandomAccessibleInterval x, List nTiles, int[] axesNetDivBy, - String axesNet, Map gridDict) { + String axesNet, Map gridDict, int channel) { String tilingAxes = axesNet.replace("C", ""); int[] xTilingAxes = new int[tilingAxes.length()]; int c = 0; for (String a : tilingAxes.split("")) { - xTilingAxes[c ++] = axesDict(axesNet).get(a); + xTilingAxes[c ++] = Utils.axesDict(axesNet).get(a); } int[] axesNetTileOverlaps = axesTileOverlap(axesNet); // TODO nTiles = permuteAxes(); @@ -166,7 +171,13 @@ private void tilingSetup(RandomAccessibleInterval x, List nTiles, int[] } } - private RandomAccessibleInterval createEmptyOutput(int nChannel) { + private void indProbThresh(RandomAccessibleInterval prob, float probThresh, Integer b) { + if (b == null) + b = 2; + + } + + private RandomAccessibleInterval createEmptyOutput(int nChannel, int channel) { sh[channel] = nChannel; long[] dims = new long[sh.length]; for (int i = 0; i < sh.length; i ++) @@ -222,23 +233,40 @@ private int[] axesDivBy(String queryAxes) { } return arr; } - - private Map axesDict(String axes) { - String[] strs = Utils.axesCheckAndNormalize(axes, null, null); - axes = strs[0]; - String allowed = strs[1]; - Map map = new HashMap(); - for (String a : allowed.split("")) - map.put(a, axes.indexOf(a) == -1 ? null : axes.indexOf(a)); - return map; - } private String normalizeAxes(RandomAccessibleInterval img, String axes) { - return null; + if (axes == null) { + axes = this.config.axes; + if (!axes.toUpperCase().contains("C")) + throw new IllegalArgumentException(); + if (img.numDimensions() == axes.length() - 1 && this.config.n_channel_in == 1) + axes = axes.replace("C", ""); + } + return Utils.axesCheckAndNormalize(axes, img.numDimensions(), null)[0]; } - private String permuteAxes(String axes, String axesNet) { + private RandomAccessibleInterval permuteAxes(RandomAccessibleInterval data, + String imgAxesIn, String netAxesIn, String netAxesOut, String imgAxesOut, + boolean undo) { + if (data == null) + return null; + if (netAxesOut == null) + netAxesOut = netAxesIn; + if (imgAxesOut == null) + imgAxesOut = imgAxesIn; + if (imgAxesIn.toUpperCase().contains("C") || imgAxesOut.toUpperCase().contains("C")) + throw new IllegalArgumentException(); + if (!netAxesIn.toUpperCase().contains("C") || !netAxesOut.toUpperCase().contains("C")) + throw new IllegalArgumentException(); + + if (undo && imgAxesIn.toUpperCase().contains("C")) { + return Utils.moveImageAxes(data, netAxesOut, imgAxesOut, true); + } else if (undo) { + // TODO + } else { + return Utils.moveImageAxes(data, imgAxesIn, netAxesIn, true); + } return null; } - + */ } diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/Normalizer.java b/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/Normalizer.java new file mode 100644 index 00000000..7a641311 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/Normalizer.java @@ -0,0 +1,5 @@ +package io.bioimage.modelrunner.model.stardist_java_deprecate; + +public class Normalizer { + +} diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java b/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/Resizer.java similarity index 97% rename from src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java rename to src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/Resizer.java index 1ba2611d..c9e80f16 100644 --- a/src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java +++ b/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/Resizer.java @@ -1,4 +1,4 @@ -package io.bioimage.modelrunner.model.stardist; +package io.bioimage.modelrunner.model.stardist_java_deprecate; import java.util.HashMap; import java.util.List; diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/StardistConfig.java b/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/StardistConfig.java new file mode 100644 index 00000000..75d2ca46 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/StardistConfig.java @@ -0,0 +1,139 @@ +/*- + * #%L + * Use deep learning frameworks from Java in an agnostic and isolated way. + * %% + * Copyright (C) 2022 - 2024 Institut Pasteur and BioImage.IO developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package io.bioimage.modelrunner.model.stardist_java_deprecate; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.Map; + +import io.bioimage.modelrunner.utils.JSONUtils; + +/** + * Implementation of the Stardist Config instance in Java. + * + *@author Carlos Garcia + */ +public class StardistConfig { + + final public int n_dim; + final public String axes; + final public int n_channel_in; + final public int n_channel_out; + final public String train_checkpoint; + final public String train_checkpoint_last; + final public String train_checkpoint_epoch; + final public int n_rays; + final public int[] grid; + final public String backbone; + final public int n_classes; + final public int unet_n_depth; + final public int[] unet_kernel_size; + final public int unet_n_filter_base; + final public int unet_n_conv_per_depth; + final public int[] unet_pool; + final public String unet_activation; + final public String unet_last_activation; + final public boolean unet_batch_norm; + final public double unet_dropout; + final public String unet_prefix; + final public int net_conv_after_unet; + final public int[] net_input_shape; + final public int[] net_mask_shape; + final public boolean train_shape_completion; + final public int train_completion_crop; + final public int[] train_patch_size; + final public float train_background_reg; + final public float train_foreground_only; + final public boolean train_sample_cache; + final public String train_dist_loss; + final public double[] train_loss_weights; + final public double[] train_class_weights; + final public int train_epochs; + + + final public int train_steps_per_epoch; + final public double train_learning_rate; + final public int train_batch_size; + final public Integer train_n_val_patches; + final public boolean train_tensorboard; + final public Map train_reduce_lr; + final public int patience; + final public double min_delta; + final public double factor; + final public boolean use_gpu; + + public static final String FNAME = "config.json"; + + private StardistConfig(String path) throws IOException { + Map config = JSONUtils.load(path); + n_dim = (int) config.get("n_dim"); + axes = (String) config.get("axes"); + n_channel_in = (int) config.get("n_channel_in"); + n_channel_out = (int) config.get("n_channel_out"); + train_checkpoint = (String) config.get("train_checkpoint"); + train_checkpoint_last = (String) config.get("train_checkpoint_last"); + train_checkpoint_epoch = (String) config.get("train_checkpoint_epoch"); + n_rays = (int) config.get("n_rays"); + grid = (int[]) config.get("grid"); + backbone = (String) config.get("backbone"); + n_classes = (int) config.get("n_classes"); + unet_n_depth = (int) config.get("unet_n_depth"); + unet_kernel_size = (int[]) config.get("unet_kernel_size"); + unet_n_filter_base = (int) config.get("unet_n_filter_base"); + unet_n_conv_per_depth = (int) config.get("unet_n_conv_per_depth"); + unet_pool = (int[]) config.get("unet_pool"); + unet_activation = (String) config.get("unet_activation"); + unet_last_activation = (String) config.get("unet_last_activation"); + unet_batch_norm = (boolean) config.get("unet_batch_norm"); + unet_dropout = (double) config.get("unet_dropout"); + unet_prefix = (String) config.get("unet_prefix"); + net_conv_after_unet = (int) config.get("net_conv_after_unet"); + net_input_shape = (int[]) config.get("net_input_shape"); + net_mask_shape = (int[]) config.get("net_mask_shape"); + train_shape_completion = (boolean) config.get("train_shape_completion"); + train_completion_crop = (int) config.get("train_completion_crop"); + train_patch_size = (int[]) config.get("train_patch_size"); + train_background_reg = (float) config.get("train_background_reg"); + train_foreground_only = (float) config.get("train_foreground_only"); + train_sample_cache = (boolean) config.get("train_sample_cache"); + train_dist_loss = (String) config.get("train_dist_loss"); + train_loss_weights = (double[]) config.get("train_loss_weights"); + train_class_weights = (double[]) config.get("train_class_weights"); + train_epochs = (int) config.get("train_epochs"); + train_steps_per_epoch = (int) config.get("train_steps_per_epoch"); + train_learning_rate = (double) config.get("train_learning_rate"); + train_batch_size = (int) config.get("train_batch_size"); + train_n_val_patches = (Integer) config.get("train_n_val_patches"); + train_tensorboard = (boolean) config.get("train_tensorboard"); + train_reduce_lr = (Map) config.get("train_reduce_lr"); + use_gpu = (boolean) config.get("use_gpu"); + factor = (double) train_reduce_lr.get("factor"); + patience = (int) train_reduce_lr.get("patience"); + min_delta = (double) train_reduce_lr.get("min_delta"); + } + + public static StardistConfig create(String name, String baseDir) throws IOException { + if (new File(baseDir + File.separator + name + File.separator + FNAME).isFile() == false) + throw new IllegalArgumentException("No '" + FNAME + "' found at " + baseDir + File.separator + name); + return new StardistConfig(baseDir + File.separator + name + File.separator + FNAME); + } + +} diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/Utils.java b/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/Utils.java new file mode 100644 index 00000000..12d92023 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/model/stardist_java_deprecate/Utils.java @@ -0,0 +1,78 @@ +package io.bioimage.modelrunner.model.stardist_java_deprecate; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.RealType; + +public class Utils { + + protected static String[] axesCheckAndNormalize(String axes, Integer length, Boolean disallowed) { + String allowed = "STCZYX"; + if (axes == null) + throw new IllegalArgumentException("Axes cannot be null"); + axes = axes.toUpperCase(); + for (String a : axes.split("")) { + if (!allowed.contains(a)) + throw new IllegalArgumentException("Invalid axis: " + a + ", it must be one of " + allowed); + if (axes.replace(a, "").length() + 1 != axes.length()) + throw new IllegalArgumentException("Invalid axis: " + a + " can only appear once."); + } + if (length != null && axes.length() != length) + throw new IllegalArgumentException("Axes (" + axes + ") must have length " + length); + return new String[] {axes, allowed}; + } + + protected static & RealType> RandomAccessibleInterval + moveImageAxes(RandomAccessibleInterval x, String fr, String to, boolean adjustSingletons) { + String[] strs = axesCheckAndNormalize(fr, x.numDimensions(), false); + fr = strs[0]; + strs = axesCheckAndNormalize(to, null, false); + to = strs[0]; + String frInitial = fr; + long[] xShapeInitial = x.dimensionsAsLongArray(); + if (adjustSingletons) { + // TODO + } + Set toSet = Arrays.asList(to.split("")).stream().collect(Collectors.toSet()); + Set frSet = Arrays.asList(fr.split("")).stream().collect(Collectors.toSet()); + + if (!frSet.equals(toSet)) + throw new IllegalArgumentException("Image dims '" + fr + "' not compatible " + + "with target dims '" + to + "'."); + Map axFrom = axesDict(fr); + Map axTo = axesDict(to); + if (fr.equals(to)) + return x; + int[] src = new int[fr.length()]; + int[] dest = new int[fr.length()]; + for (int i = 0; i < fr.length(); i ++) { + String a = fr.split("")[i]; + src[i] = axFrom.get(a); + dest[i] = axTo.get(a); + } + int[] orderChange = new int[src.length]; + for (int i = 0; i < orderChange.length; i ++) { + int position = Arrays.asList(src).indexOf(dest[i]); + orderChange[i] = position; + } + + return io.bioimage.modelrunner.tensor.Utils.rearangeAxes(x, orderChange); + } + + protected static Map axesDict(String axes) { + String[] strs = Utils.axesCheckAndNormalize(axes, null, null); + axes = strs[0]; + String allowed = strs[1]; + Map map = new HashMap(); + for (String a : allowed.split("")) + map.put(a, axes.indexOf(a) == -1 ? null : axes.indexOf(a)); + return map; + } + +} diff --git a/src/main/java/io/bioimage/modelrunner/tensor/Utils.java b/src/main/java/io/bioimage/modelrunner/tensor/Utils.java index b6f84733..2b793717 100644 --- a/src/main/java/io/bioimage/modelrunner/tensor/Utils.java +++ b/src/main/java/io/bioimage/modelrunner/tensor/Utils.java @@ -19,10 +19,7 @@ */ package io.bioimage.modelrunner.tensor; -import java.util.ArrayList; -import java.util.List; -import net.imglib2.Point; import net.imglib2.RandomAccessibleInterval; import net.imglib2.transform.integer.MixedTransform; import net.imglib2.type.numeric.NumericType;