Skip to content

Commit

Permalink
give up on the Java wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Jan 15, 2025
1 parent c54a950 commit 6acc18e
Show file tree
Hide file tree
Showing 10 changed files with 281 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.processing.Processing;
import io.bioimage.modelrunner.model.stardist.StardistConfig;
import io.bioimage.modelrunner.model.stardist_java_deprecate.StardistConfig;
import io.bioimage.modelrunner.runmode.RunMode;
import io.bioimage.modelrunner.runmode.ops.GenericOp;
import io.bioimage.modelrunner.tensor.Tensor;
Expand Down

This file was deleted.

This file was deleted.

21 changes: 0 additions & 21 deletions src/main/java/io/bioimage/modelrunner/model/stardist/Utils.java

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.bioimage.modelrunner.model.stardist;
package io.bioimage.modelrunner.model.stardist_java_deprecate;

import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -9,7 +9,7 @@
import net.imglib2.img.array.ArrayImgs;

public class AbstractStardist {

/*
private StardistConfig config;
private float probThres;
Expand Down Expand Up @@ -39,8 +39,14 @@ private void predict_instances(RandomAccessibleInterval img, Map<String, Object>
String _axes = normalizeAxes(img, axes);
String axesNet = config.axes;
String permuteAxes = permuteAxes(_axes, axesNet);
int[] shapeInst;
long[] shape = permuteAxes(null, null, _axes, axesNet, null, true).dimensionsAsLongArray();
int[] shapeInst = new int[_axes.toUpperCase().replace("C", "").length()];
int i = 0;
for (long s : shape) {
if (_axes.toUpperCase().split("")[i].equals("C"))
continue;
shapeInst[i ++] = (int) s;
}
Number scale = null;
if (kwargs.get("scale") != null && kwargs.get("scale") instanceof Number)
Expand Down Expand Up @@ -78,7 +84,7 @@ private void predictSparseGenerator(RandomAccessibleInterval img, String axes, N
Map<String, Object> returns = predictSetup(img, axes, normalizer, nTiles);
RandomAccessibleInterval x = (RandomAccessibleInterval) returns.get("x");
List<Integer> nTiles = (List<Integer>) returns.get("nTiles");
nTiles = (List<Integer>) returns.get("nTiles");
axes = (String) returns.get("axes");
String axesNet = (String) returns.get("axesNet");
int[] axesNetDivBy = (int[]) returns.get("axesNetDivBy");
Expand All @@ -90,7 +96,9 @@ private void predictSparseGenerator(RandomAccessibleInterval img, String axes, N
if (product > 1) {
// TODO
} else {

// TODO
RandomAccessibleInterval prob = null; //(256, 256, 1)
RandomAccessibleInterval dist = null; //(256, 256, 32)
}
}
Expand All @@ -105,13 +113,12 @@ private Map<String, Object> predictSetup(RandomAccessibleInterval img, String ax
+ ") should be the same as the tile list lenght (" + nTiles.size() + ").");
axes = normalizeAxes(img, axes);
String axesNet = this.config.axes;
// TODO permuteAxes
RandomAccessibleInterval x = null; // TODO
channel = axesDict(axesNet).get("C");
if (this.config.nChannelIn != x.dimensionsAsLongArray()[channel])
RandomAccessibleInterval x = permuteAxes(img, axes, axesNet, null, null, true);
int channel = Utils.axesDict(axesNet).get("C");
if (this.config.n_channel_in != x.dimensionsAsLongArray()[channel])
throw new IllegalArgumentException("The number of channels of the image ("
+ x.dimensionsAsLongArray()[channel] + ") should be the same as the model config ("
+ config.nChannelIn + ").");
+ config.n_channel_in + ").");
int[] axesNetDivBy = axesDivBy(axesNet);
int[] grid = config.grid;
if (grid.length != axesNet.length() - 1)
Expand All @@ -131,7 +138,6 @@ private Map<String, Object> predictSetup(RandomAccessibleInterval img, String ax
returns.put("axes", axes);
returns.put("axesNet", axesNet);
returns.put("axesNetDivBy", axesNetDivBy);
returns.put("permuteAxes", permuteAxes);
returns.put("resizer", resizer);
returns.put("nTiles", nTiles);
returns.put("grid", grid);
Expand All @@ -140,16 +146,15 @@ private Map<String, Object> predictSetup(RandomAccessibleInterval img, String ax
return returns;
}
private int channel;
private Integer[] sh;
private void tilingSetup(RandomAccessibleInterval x, List<Integer> nTiles, int[] axesNetDivBy,
String axesNet, Map<String, Integer> gridDict) {
String axesNet, Map<String, Integer> gridDict, int channel) {
String tilingAxes = axesNet.replace("C", "");
int[] xTilingAxes = new int[tilingAxes.length()];
int c = 0;
for (String a : tilingAxes.split("")) {
xTilingAxes[c ++] = axesDict(axesNet).get(a);
xTilingAxes[c ++] = Utils.axesDict(axesNet).get(a);
}
int[] axesNetTileOverlaps = axesTileOverlap(axesNet);
// TODO nTiles = permuteAxes();
Expand All @@ -166,7 +171,13 @@ private void tilingSetup(RandomAccessibleInterval x, List<Integer> nTiles, int[]
}
}
private RandomAccessibleInterval createEmptyOutput(int nChannel) {
private void indProbThresh(RandomAccessibleInterval prob, float probThresh, Integer b) {
if (b == null)
b = 2;
}
private RandomAccessibleInterval createEmptyOutput(int nChannel, int channel) {
sh[channel] = nChannel;
long[] dims = new long[sh.length];
for (int i = 0; i < sh.length; i ++)
Expand Down Expand Up @@ -222,23 +233,40 @@ private int[] axesDivBy(String queryAxes) {
}
return arr;
}

private Map<String, Integer> axesDict(String axes) {
String[] strs = Utils.axesCheckAndNormalize(axes, null, null);
axes = strs[0];
String allowed = strs[1];
Map<String, Integer> map = new HashMap<String, Integer>();
for (String a : allowed.split(""))
map.put(a, axes.indexOf(a) == -1 ? null : axes.indexOf(a));
return map;
}
private String normalizeAxes(RandomAccessibleInterval img, String axes) {
return null;
if (axes == null) {
axes = this.config.axes;
if (!axes.toUpperCase().contains("C"))
throw new IllegalArgumentException();
if (img.numDimensions() == axes.length() - 1 && this.config.n_channel_in == 1)
axes = axes.replace("C", "");
}
return Utils.axesCheckAndNormalize(axes, img.numDimensions(), null)[0];
}
private String permuteAxes(String axes, String axesNet) {
private RandomAccessibleInterval permuteAxes(RandomAccessibleInterval data,
String imgAxesIn, String netAxesIn, String netAxesOut, String imgAxesOut,
boolean undo) {
if (data == null)
return null;
if (netAxesOut == null)
netAxesOut = netAxesIn;
if (imgAxesOut == null)
imgAxesOut = imgAxesIn;
if (imgAxesIn.toUpperCase().contains("C") || imgAxesOut.toUpperCase().contains("C"))
throw new IllegalArgumentException();
if (!netAxesIn.toUpperCase().contains("C") || !netAxesOut.toUpperCase().contains("C"))
throw new IllegalArgumentException();
if (undo && imgAxesIn.toUpperCase().contains("C")) {
return Utils.moveImageAxes(data, netAxesOut, imgAxesOut, true);
} else if (undo) {
// TODO
} else {
return Utils.moveImageAxes(data, imgAxesIn, netAxesIn, true);
}
return null;
}

*/
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.bioimage.modelrunner.model.stardist_java_deprecate;

public class Normalizer {

}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.bioimage.modelrunner.model.stardist;
package io.bioimage.modelrunner.model.stardist_java_deprecate;

import java.util.HashMap;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*-
* #%L
* Use deep learning frameworks from Java in an agnostic and isolated way.
* %%
* Copyright (C) 2022 - 2024 Institut Pasteur and BioImage.IO developers.
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package io.bioimage.modelrunner.model.stardist_java_deprecate;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Map;

import io.bioimage.modelrunner.utils.JSONUtils;

/**
* Implementation of the Stardist Config instance in Java.
*
*@author Carlos Garcia
*/
public class StardistConfig {

final public int n_dim;
final public String axes;
final public int n_channel_in;
final public int n_channel_out;
final public String train_checkpoint;
final public String train_checkpoint_last;
final public String train_checkpoint_epoch;
final public int n_rays;
final public int[] grid;
final public String backbone;
final public int n_classes;
final public int unet_n_depth;
final public int[] unet_kernel_size;
final public int unet_n_filter_base;
final public int unet_n_conv_per_depth;
final public int[] unet_pool;
final public String unet_activation;
final public String unet_last_activation;
final public boolean unet_batch_norm;
final public double unet_dropout;
final public String unet_prefix;
final public int net_conv_after_unet;
final public int[] net_input_shape;
final public int[] net_mask_shape;
final public boolean train_shape_completion;
final public int train_completion_crop;
final public int[] train_patch_size;
final public float train_background_reg;
final public float train_foreground_only;
final public boolean train_sample_cache;
final public String train_dist_loss;
final public double[] train_loss_weights;
final public double[] train_class_weights;
final public int train_epochs;


final public int train_steps_per_epoch;
final public double train_learning_rate;
final public int train_batch_size;
final public Integer train_n_val_patches;
final public boolean train_tensorboard;
final public Map<String, Object> train_reduce_lr;
final public int patience;
final public double min_delta;
final public double factor;
final public boolean use_gpu;

public static final String FNAME = "config.json";

private StardistConfig(String path) throws IOException {
Map<String, Object> config = JSONUtils.load(path);
n_dim = (int) config.get("n_dim");
axes = (String) config.get("axes");
n_channel_in = (int) config.get("n_channel_in");
n_channel_out = (int) config.get("n_channel_out");
train_checkpoint = (String) config.get("train_checkpoint");
train_checkpoint_last = (String) config.get("train_checkpoint_last");
train_checkpoint_epoch = (String) config.get("train_checkpoint_epoch");
n_rays = (int) config.get("n_rays");
grid = (int[]) config.get("grid");
backbone = (String) config.get("backbone");
n_classes = (int) config.get("n_classes");
unet_n_depth = (int) config.get("unet_n_depth");
unet_kernel_size = (int[]) config.get("unet_kernel_size");
unet_n_filter_base = (int) config.get("unet_n_filter_base");
unet_n_conv_per_depth = (int) config.get("unet_n_conv_per_depth");
unet_pool = (int[]) config.get("unet_pool");
unet_activation = (String) config.get("unet_activation");
unet_last_activation = (String) config.get("unet_last_activation");
unet_batch_norm = (boolean) config.get("unet_batch_norm");
unet_dropout = (double) config.get("unet_dropout");
unet_prefix = (String) config.get("unet_prefix");
net_conv_after_unet = (int) config.get("net_conv_after_unet");
net_input_shape = (int[]) config.get("net_input_shape");
net_mask_shape = (int[]) config.get("net_mask_shape");
train_shape_completion = (boolean) config.get("train_shape_completion");
train_completion_crop = (int) config.get("train_completion_crop");
train_patch_size = (int[]) config.get("train_patch_size");
train_background_reg = (float) config.get("train_background_reg");
train_foreground_only = (float) config.get("train_foreground_only");
train_sample_cache = (boolean) config.get("train_sample_cache");
train_dist_loss = (String) config.get("train_dist_loss");
train_loss_weights = (double[]) config.get("train_loss_weights");
train_class_weights = (double[]) config.get("train_class_weights");
train_epochs = (int) config.get("train_epochs");
train_steps_per_epoch = (int) config.get("train_steps_per_epoch");
train_learning_rate = (double) config.get("train_learning_rate");
train_batch_size = (int) config.get("train_batch_size");
train_n_val_patches = (Integer) config.get("train_n_val_patches");
train_tensorboard = (boolean) config.get("train_tensorboard");
train_reduce_lr = (Map<String, Object>) config.get("train_reduce_lr");
use_gpu = (boolean) config.get("use_gpu");
factor = (double) train_reduce_lr.get("factor");
patience = (int) train_reduce_lr.get("patience");
min_delta = (double) train_reduce_lr.get("min_delta");
}

public static StardistConfig create(String name, String baseDir) throws IOException {
if (new File(baseDir + File.separator + name + File.separator + FNAME).isFile() == false)
throw new IllegalArgumentException("No '" + FNAME + "' found at " + baseDir + File.separator + name);
return new StardistConfig(baseDir + File.separator + name + File.separator + FNAME);
}

}
Loading

0 comments on commit 6acc18e

Please sign in to comment.