-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
start creating new stardist wrapper in java with appose
- Loading branch information
1 parent
d2a65f2
commit 2af780d
Showing
6 changed files
with
341 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
155 changes: 155 additions & 0 deletions
155
src/main/java/io/bioimage/modelrunner/model/stardist/AbstractStardist.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
package io.bioimage.modelrunner.model.stardist; | ||
|
||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import net.imglib2.RandomAccessibleInterval; | ||
|
||
public class AbstractStardist { | ||
|
||
private StardistConfig config; | ||
|
||
private float probThres; | ||
|
||
private float nmsThres; | ||
|
||
private void predict_instances(RandomAccessibleInterval img, Map<String, Object> kwargs) { | ||
Map<String, Object> predictKwargs = new HashMap<String, Object>(); | ||
if (kwargs.get("predictKwargs") != null && kwargs.get("predictKwargs") instanceof Map) | ||
predictKwargs = (Map<String, Object>) kwargs.get("predictKwargs"); | ||
Map<String, Object> nmsKwargs = new HashMap<String, Object>(); | ||
if (kwargs.get("nmsKwargs") != null && kwargs.get("nmsKwargs") instanceof Map) | ||
nmsKwargs = (Map<String, Object>) kwargs.get("nmsKwargs"); | ||
boolean returnPredict = false; | ||
boolean sparse = true; | ||
if (kwargs.get("returnPredict") != null && kwargs.get("returnPredict") instanceof Boolean) | ||
returnPredict = (boolean) kwargs.get("returnPredict"); | ||
if (kwargs.get("sparse") != null && kwargs.get("sparse") instanceof Boolean) | ||
sparse = (boolean) kwargs.get("sparse"); | ||
if (returnPredict && sparse) | ||
sparse = false; | ||
String axes = ""; | ||
if (kwargs.get("axes") != null && kwargs.get("axes") instanceof String) | ||
axes = (String) kwargs.get("axes"); | ||
|
||
String _axes = normalizeAxes(img, axes); | ||
String axesNet = config.axes; | ||
String permuteAxes = permuteAxes(_axes, axesNet); | ||
int[] shapeInst; | ||
|
||
Number scale = null; | ||
if (kwargs.get("scale") != null && kwargs.get("scale") instanceof Number) | ||
scale = (Number) kwargs.get("scale"); | ||
|
||
if (scale != null) { | ||
// TODO | ||
for (String ax : _axes.split("")) { | ||
} | ||
} | ||
|
||
Normalizer normalizer = null; | ||
if (kwargs.get("normalizer") != null && kwargs.get("normalizer") instanceof Normalizer) | ||
normalizer = (Normalizer) kwargs.get("normalizer"); | ||
List<Integer> nTiles = null; | ||
if (kwargs.get("nTiles") != null && kwargs.get("nTiles") instanceof List) | ||
nTiles = ((List<Integer>) kwargs.get("nTiles")); | ||
Float probThresh = null; | ||
if (kwargs.get("probThresh") != null && kwargs.get("probThresh") instanceof Number) | ||
probThresh = ((Number) kwargs.get("probThresh")).floatValue(); | ||
|
||
if (sparse) { | ||
predictSparseGenerator(img, axes, normalizer, nTiles, probThresh); | ||
} else { | ||
|
||
} | ||
|
||
|
||
} | ||
|
||
private void predictSparseGenerator(RandomAccessibleInterval img, String axes, Normalizer normalizer, List<Integer> nTiles, | ||
Float probThresh) { | ||
if (probThresh == null) probThresh = this.probThres; | ||
|
||
Map<String, Object> args = predictSetup(img, axes, normalizer, nTiles); | ||
|
||
} | ||
|
||
private Map<String, Object> predictSetup(RandomAccessibleInterval img, String axes, Normalizer normalizer, List<Integer> nTiles) { | ||
if (nTiles == null) { | ||
nTiles = new ArrayList<Integer>(); | ||
for (int i = 0; i < img.dimensionsAsLongArray().length; i ++) nTiles.add(1); | ||
} | ||
if (nTiles.size() != img.dimensionsAsLongArray().length) | ||
throw new IllegalArgumentException("The number of image dimensions (" + img.dimensionsAsLongArray().length | ||
+ ") should be the same as the tile list lenght (" + nTiles.size() + ")."); | ||
axes = normalizeAxes(img, axes); | ||
String axesNet = this.config.axes; | ||
// TODO permuteAxes | ||
RandomAccessibleInterval x = null; // TODO | ||
int channel = axesDict(axesNet).get("C"); | ||
if (this.config.nChannelIn != x.dimensionsAsLongArray()[channel]) | ||
throw new IllegalArgumentException("The number of channels of the image (" | ||
+ x.dimensionsAsLongArray()[channel] + ") should be the same as the model config (" | ||
+ config.nChannelIn + ")."); | ||
int[] axesNetDivBy = axesDivBy(axesNet); | ||
int[] grid = config.grid; | ||
if (grid.length != axesNet.length() - 1) | ||
throw new IllegalArgumentException(); | ||
Map<String, Integer> gridDict = new HashMap<String, Integer>(); | ||
int i = 0; | ||
for (String a : axesNet.toUpperCase().split("")) { | ||
if (a == "C") | ||
continue; | ||
gridDict.put(a, grid[i ++]); | ||
} | ||
|
||
Resizer resizer = new Resizer(gridDict); | ||
return null; | ||
} | ||
|
||
private int[] axesDivBy(String queryAxes) { | ||
if (this.config.backbone.equals("unet")) | ||
throw new IllegalArgumentException("Backbone '" + config.backbone + "' not implemented."); | ||
String[] strs = Utils.axesCheckAndNormalize(queryAxes, null, null); | ||
queryAxes = strs[0]; | ||
|
||
if (config.unet_pool.length != config.grid.length) | ||
throw new IllegalArgumentException(); | ||
int i = 0; | ||
Map<String, Integer> divBy = new HashMap<String, Integer>(); | ||
for (String a : config.axes.split("")) { | ||
if (a.toUpperCase().equals("C")) | ||
continue; | ||
int val = (int) (Math.pow(config.unet_pool[i], config.unet_n_depth) * config.grid[i]); | ||
divBy.put(a.toUpperCase(), val); | ||
i ++; | ||
} | ||
int[] arr = new int[queryAxes.length()]; | ||
i = 0; | ||
for (String a : queryAxes.split("")) { | ||
arr[i ++] = divBy.keySet().contains(a) ? divBy.get(a) : 1; | ||
} | ||
return arr; | ||
} | ||
|
||
private Map<String, Integer> axesDict(String axes) { | ||
String[] strs = Utils.axesCheckAndNormalize(axes, null, null); | ||
axes = strs[0]; | ||
String allowed = strs[1]; | ||
Map<String, Integer> map = new HashMap<String, Integer>(); | ||
for (String a : allowed.split("")) | ||
map.put(a, axes.indexOf(a) == -1 ? null : axes.indexOf(a)); | ||
return map; | ||
} | ||
|
||
private String normalizeAxes(RandomAccessibleInterval img, String axes) { | ||
return null; | ||
} | ||
|
||
private String permuteAxes(String axes, String axesNet) { | ||
return null; | ||
} | ||
|
||
} |
5 changes: 5 additions & 0 deletions
5
src/main/java/io/bioimage/modelrunner/model/stardist/Normalizer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
package io.bioimage.modelrunner.model.stardist; | ||
|
||
public class Normalizer { | ||
|
||
} |
73 changes: 73 additions & 0 deletions
73
src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
package io.bioimage.modelrunner.model.stardist; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import net.imglib2.FinalInterval; | ||
import net.imglib2.RandomAccessibleInterval; | ||
import net.imglib2.type.NativeType; | ||
import net.imglib2.type.numeric.RealType; | ||
import net.imglib2.view.Views; | ||
|
||
public class Resizer { | ||
|
||
|
||
private final Map<String, Integer> grid; | ||
|
||
private Map<String, Integer> pad; | ||
|
||
private Map<String, Integer> paddedShape; | ||
|
||
protected Resizer(Map<String, Integer> grid) { | ||
this.grid = grid; | ||
} | ||
|
||
protected <T extends NativeType<T> & RealType<T>> RandomAccessibleInterval<T> | ||
before(RandomAccessibleInterval<T> x, String axes, int[] axesDivBy) { | ||
for (int i = 0; i < axes.length(); i ++) { | ||
String ax = axes.split("")[i]; | ||
int g = grid.keySet().contains(ax) ? grid.get(ax) : 1; | ||
int a = axesDivBy[i]; | ||
if (a % g != 0) | ||
throw new IllegalArgumentException(); | ||
} | ||
|
||
String[] strs = Utils.axesCheckAndNormalize(axes, x.numDimensions(), null); | ||
axes = strs[0]; | ||
pad = new HashMap<String, Integer>(); | ||
for (int i = 0; i < axes.length(); i ++) { | ||
long val = (axesDivBy[i] - x.dimensionsAsLongArray()[i] % axesDivBy[i]) % axesDivBy[i]; | ||
String a = axes.split("")[i]; | ||
pad.put(a, (int) val); | ||
} | ||
long[] minLim = x.minAsLongArray(); | ||
long[] maxLim = x.maxAsLongArray(); | ||
int i = 0; | ||
for (String a : axes.split("")) { | ||
maxLim[ i ++] += pad.get(a); | ||
} | ||
RandomAccessibleInterval<T> xPad = Views.interval( | ||
Views.extendMirrorDouble(x), new FinalInterval( minLim, maxLim )); | ||
paddedShape = new HashMap<String, Integer>(); | ||
for (int j = 0; j < axes.length(); j ++) { | ||
String ax = axes.split("")[j]; | ||
if (ax.toUpperCase().equals("C")) | ||
continue; | ||
paddedShape.put(ax.toUpperCase(), (int) xPad.dimensionsAsLongArray()[j]); | ||
} | ||
return xPad; | ||
} | ||
|
||
protected <T extends NativeType<T> & RealType<T>> RandomAccessibleInterval<T> | ||
after(RandomAccessibleInterval<T> x, String axes) { | ||
String[] strs = Utils.axesCheckAndNormalize(axes, x.numDimensions(), null); | ||
axes = strs[0]; | ||
RandomAccessibleInterval<T> crop = null; | ||
return crop; | ||
} | ||
|
||
protected void filterPoints() { | ||
|
||
} | ||
} |
38 changes: 38 additions & 0 deletions
38
src/main/java/io/bioimage/modelrunner/model/stardist/StardistConfig.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/*- | ||
* #%L | ||
* Use deep learning frameworks from Java in an agnostic and isolated way. | ||
* %% | ||
* Copyright (C) 2022 - 2024 Institut Pasteur and BioImage.IO developers. | ||
* %% | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* #L% | ||
*/ | ||
package io.bioimage.modelrunner.model.stardist; | ||
|
||
|
||
/** | ||
* Implementation of the Stardist Config instance in Java. | ||
* | ||
*@author Carlos Garcia | ||
*/ | ||
public class StardistConfig { | ||
|
||
public int[] unet_pool; | ||
public String axes; | ||
public double unet_n_depth; | ||
public int[] grid; | ||
public int nChannelIn; | ||
public String backbone; | ||
|
||
|
||
} |
21 changes: 21 additions & 0 deletions
21
src/main/java/io/bioimage/modelrunner/model/stardist/Utils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package io.bioimage.modelrunner.model.stardist; | ||
|
||
public class Utils { | ||
|
||
protected static String[] axesCheckAndNormalize(String axes, Integer length, String disallowed) { | ||
String allowed = "STCZYX"; | ||
if (axes == null) | ||
throw new IllegalArgumentException("Axes cannot be null"); | ||
axes = axes.toUpperCase(); | ||
for (String a : axes.split("")) { | ||
if (!allowed.contains(a)) | ||
throw new IllegalArgumentException("Invalid axis: " + a + ", it must be one of " + allowed); | ||
if (axes.replace(a, "").length() + 1 != axes.length()) | ||
throw new IllegalArgumentException("Invalid axis: " + a + " can only appear once."); | ||
} | ||
if (length != null && axes.length() != length) | ||
throw new IllegalArgumentException("Axes (" + axes + ") must have length " + length); | ||
return new String[] {axes, allowed}; | ||
} | ||
|
||
} |