Skip to content

Commit

Permalink
Merge pull request #39 from bioimage-io/python-appose
Browse files Browse the repository at this point in the history
Python appose
  • Loading branch information
carlosuc3m authored Dec 12, 2023
2 parents ed4c088 + 039ff50 commit 52be2e1
Show file tree
Hide file tree
Showing 58 changed files with 8,964 additions and 372 deletions.
15 changes: 15 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,20 @@
<artifactId>jackson-dataformat-msgpack</artifactId>
<version>0.9.5</version>
</dependency>
<dependency>
<groupId>org.apposed</groupId>
<artifactId>appose</artifactId>
<version>0.1.1-SNAPSHOT</version>
</dependency>
<!--<dependency>
<groupId>org.msgpack</groupId>
<artifactId>jackson-dataformat-msgpack</artifactId>
<version>0.9.0</version>
</dependency> -->
<dependency>
<groupId>net.java.dev.jna</groupId>
<artifactId>jna</artifactId>
<version>5.13.0</version>
</dependency>
</dependencies>
</project>
15 changes: 15 additions & 0 deletions python/op_environments/stardist_environment.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: stardist_op
channels:
- conda-forge
- defaults
dependencies:
- bioimageio.core
- black
- conda-build
- dask
- mypy
- pip
- python==3.9.*
- stardist
- tensorflow==2.*
- xarray
10 changes: 10 additions & 0 deletions python/op_environments/stardist_postprocessing.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: stardist
channels:
- conda-forge
- defaults
dependencies:
- python=3.10
- appose
- numpy
- xarray
- stardist
55 changes: 55 additions & 0 deletions python/ops/stardist_fine_tune/stardist_fine_tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import xarray as xr
from stardist.models import StarDist2D

import os
import shutil
import os
from pathlib import Path


def assertions(model_path, images, ground_truth):
    """Validate the inputs for Stardist fine-tuning.

    Parameters
    ----------
    model_path : str
        Either the name of one of the default pre-trained Stardist models or
        the directory of a pre-trained Stardist model.
    images : xr.DataArray
        Training samples; must be 4-dimensional with axes order "byxc".
    ground_truth : xr.DataArray
        Ground-truth label images; must be 3-dimensional with axes order "byx".

    Raises
    ------
    AssertionError
        If any type, dimensionality, axes-order or per-axis shape check fails.
    """
    assert isinstance(model_path, str), (
        "The input argument 'model_path' must be a string, either the name"
        " of one of the default pre-trained Stardist models or the directory"
        " to a pre-trained Stardist model"
    )

    assert isinstance(images, xr.DataArray), "the training samples should be a xr.DataArray"
    # fixed typo: "thruth" -> "truth"
    assert isinstance(ground_truth, xr.DataArray), "the ground truth should be a xr.DataArray"

    assert images.ndim == 4, "the training samples array must have 4 dimensions"
    # fixed message: this check is about the ground truth, not the samples
    assert ground_truth.ndim == 3, "the ground truth array must have 3 dimensions"

    assert "".join(images.dims) == "byxc", (
        "the training samples axes order should be 'byxc', not '"
        + "".join(images.dims) + "' as provided."
    )
    assert "".join(ground_truth.dims) == "byx", (
        "the ground truth samples axes order should be 'byx', not '"
        + "".join(ground_truth.dims) + "' as provided."
    )

    # With axes "byx": index 1 is y (height) and index 2 is x (width).
    # The original labels had width/height swapped in the error messages.
    axes_dict = {"batch size": 0, "height": 1, "width": 2}

    # Samples and ground truth must agree on batch size and spatial shape
    # (channels excluded: the ground truth carries no channel axis).
    for ks, vs in axes_dict.items():
        assert images.shape[vs] == ground_truth.shape[vs], (
            "The training samples and the ground truth need to have the same "
            + ks + " : " + str(images.shape[vs]) + " vs " + str(ground_truth.shape[vs])
        )


def finetune_stardist(model_path, images, ground_truth, weights_file=None):
    """Fine-tune a pre-trained Stardist 2D model on new annotated data.

    Parameters
    ----------
    model_path : str
        Name of one of the default pre-trained Stardist models, or the
        directory of a pre-trained Stardist model. Must contain the files a
        Stardist model needs when it is a directory.
    images : xr.DataArray
        Training samples with axes order "byxc".
    ground_truth : xr.DataArray
        Ground-truth labels with axes order "byx"; must match `images` in
        batch size and spatial shape.
    weights_file : str, optional
        Path to an HDF5 weights file to load into the model before training.
        Defaults to None, meaning the weights shipped with `model_path` are
        used as-is.

    Returns
    -------
    dict
        The Keras training history (``history.history``).

    Notes
    -----
    NOTE(review): validation is currently performed on the training data
    itself — there is no held-out validation split; confirm this is intended.
    Large epochs/batch sizes may warrant a warning when running on CPU.
    """
    assertions(model_path, images, ground_truth)

    model = StarDist2D(None, model_path)
    if weights_file is not None:
        # Bug fix: the original ignored `weights_file` and always loaded the
        # hard-coded "weights_last.h5".
        model.load_weights(weights_file)

    # Fine-tune on the new data; the same data is reused for validation.
    history = model.train(images, ground_truth, validation_data=(images, ground_truth))

    # Export the fine-tuned model as a TF SavedModel zip next to the model dir.
    Path(model_path).mkdir(parents=True, exist_ok=True)
    model.export_TF(os.path.join(os.path.dirname(model_path), "TF_SavedModel.zip"))

    return history.history
252 changes: 252 additions & 0 deletions python/ops/stardist_inference/stardist_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
import tempfile
import warnings
from math import ceil
from os import PathLike
from os import path
from pathlib import Path
from typing import Dict, IO, List, Optional, Tuple, Union
from csbdeep.utils import axes_check_and_normalize, normalize, _raise
from bioimageio.spec import load_raw_resource_description

import xarray as xr
from stardist import import_bioimageio as stardist_import_bioimageio

from bioimageio.core import export_resource_package, load_resource_description
from bioimageio.core.prediction_pipeline._combined_processing import CombinedProcessing
from bioimageio.core.prediction_pipeline._measure_groups import compute_measures
from bioimageio.core.resource_io.utils import SourceNodeTransformer, resolve_source, RawNodeTypeTransformer
from bioimageio.core.resource_io.nodes import Model
from bioimageio.core.resource_io.io_ import nodes
from bioimageio.spec.model import raw_nodes
from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription

import numpy as np

from bioimageio.spec.shared.node_transformer import UriNodeTransformer

RDF_YAML_SUFF = 'rdf.yaml'
RDF_YAML_SUFF_DEPR = 'model.yaml'

STARDIST_OP_NAME = 'stardist_op'

def stardist_prediction_2d_mine(
    model_rdf: Union[str, PathLike, dict, IO, bytes, raw_nodes.URI, RawResourceDescription],
    input_tensor: xr.DataArray,
    tile: Optional[Dict[str, int]] = None,
) -> Tuple[xr.DataArray, dict]:
    """Run Stardist instance segmentation from a local bioimage.io RDF.

    Reconstructs a native Stardist model from the stardist-specific
    config/thresholds/weights stored in the RDF's ``config`` section (instead
    of exporting a zipped package), applies the model's declared
    preprocessing, and predicts instances.

    NOTE(review): the body only executes when `model_rdf` is an existing
    local path ending in 'rdf.yaml' or 'model.yaml'; for any other input the
    function falls through and implicitly returns None — confirm intended.

    :param model_rdf: the (source/raw) model RDF describing the stardist model
    :param input_tensor: raw input tensor; assumed to already use
        single-letter dims matching "byxc" — TODO confirm against callers
    :param tile: optional per-axis tile sizes; translated to stardist n_tiles
    :return: (labels DataArray with dims "byx", stardist polygons dict)
    """
    model = None
    if isinstance(model_rdf, str) \
            and (model_rdf.endswith(RDF_YAML_SUFF_DEPR) or model_rdf.endswith(RDF_YAML_SUFF)) \
            and path.exists(model_rdf):
        # Local imports: only needed on this branch.
        import shutil
        from csbdeep.utils import save_json
        from stardist.models import StarDist2D, StarDist3D
        biomodel = load_raw_resource_description(model_rdf, update_to_format="latest")
        # The following bare string is leftover experimentation kept by the
        # author; it is a no-op expression statement at runtime.
        """
        rd = UriNodeTransformer(root_path=biomodel.root_path, uri_only_if_in_package=True).transform(
        biomodel)
        rd2 = UriNodeTransformer(root_path=biomodel.root_path, uri_only_if_in_package=False).transform(
        biomodel)
        aa = isinstance(rd, Model)
        rd = SourceNodeTransformer().transform(rd)
        cc = isinstance(rd, Model)
        rd = RawNodeTypeTransformer(nodes).transform(rd)
        cc = isinstance(rd, Model)
        model = load_resource_description(model_rdf)
        """
        biomodel = RawNodeTypeTransformer(nodes).transform(biomodel)
        # read the stardist specific content
        if 'stardist' not in biomodel.config:
            raise(RuntimeError("bioimage.io model not compatible"))
        config = biomodel.config['stardist']['config']
        thresholds = biomodel.config['stardist']['thresholds']
        weights = biomodel.config['stardist']['weights']

        # make sure that the keras weights are in the attachments
        weights_file = None
        for f in biomodel.attachments.files:
            if str(f).endswith("/" + weights):
                weights_file = f
                break
        # `x or _raise(...)` idiom: raises when no matching attachment found.
        weights_file is not None or _raise(FileNotFoundError(f"couldn't find weights file '{weights}'"))


        # save the config and threshold to json, and weights to hdf5 to enable loading as stardist model
        # copy bioimageio files to separate sub-folder
        outpath = Path(Path(path.dirname(model_rdf)) / STARDIST_OP_NAME)

        outpath.mkdir(parents=True, exist_ok=True)
        save_json(config, str(outpath / 'config.json'))
        save_json(thresholds, str(outpath / 'thresholds.json'))
        if path.exists(Path(path.dirname(model_rdf)) / weights):
            # weights already local next to the RDF: plain file copy
            shutil.copy(str(weights_file), str(outpath / "weights_bioimageio.h5"))
        else:
            # weights referenced remotely/elsewhere: let bioimageio resolve them
            resolve_source(weights_file, Path(model_rdf), Path(str(outpath / "weights_bioimageio.h5")))

        # 2D vs 3D model chosen from the stardist config's n_dim
        model_class = (StarDist2D if config['n_dim'] == 2 else StarDist3D)
        imported_stardist_model = model_class(None, outpath.name, basedir=str(outpath.parent))

        #assert isinstance(biomodel, Model)
        if len(biomodel.inputs) != 1:
            raise NotImplementedError("Multiple inputs for stardist models not yet implemented")

        if len(biomodel.outputs) != 1:
            raise NotImplementedError("Multiple outputs for stardist models not yet implemented")

        # rename tensor axes to single letters to match model RDF
        #map_axes = {k: v for k, v in AXIS_NAME_TO_LETTER.items() if k in input_tensor.dims}
        # NOTE(review): intentionally disabled (None) — the tensor is assumed
        # to already carry single-letter dims; the `if` below is dead code.
        map_axes = None
        if map_axes:
            input_tensor = input_tensor.rename(map_axes)

        # Apply the model's declared preprocessing to the input sample.
        prep = CombinedProcessing.from_tensor_specs(biomodel.inputs)
        ipt_name = biomodel.inputs[0].name
        sample = {ipt_name: input_tensor}
        computed_measures = compute_measures(prep.required_measures, sample=sample)
        prep.apply(sample, computed_measures)

        preprocessed_input = sample[ipt_name]
        #map_axes_back = {k: v for k, v in AXIS_LETTER_TO_NAME.items() if k in preprocessed_input.dims}
        # NOTE(review): same as map_axes above — disabled, dead branch.
        map_axes_back = None
        if map_axes_back:
            preprocessed_input = preprocessed_input.rename(map_axes_back)

        #input_axis_order = [AXIS_LETTER_TO_NAME.get(a, a) for a in model.inputs[0].axes]
        # Hard-coded axis order; assumes the input dims are exactly b,y,x,c.
        input_axis_order = "byxc"
        if tile is None:
            n_tiles: Optional[List[int]] = None
        else:
            # Translate per-axis tile sizes into stardist tile *counts*.
            n_tiles = []
            for a in input_axis_order:
                t = tile[a]
                s = preprocessed_input.sizes[a]
                n_tiles.append(max(ceil(s / t), 1))

            warnings.warn(f"translated tile {tile} to n_tiles: {n_tiles} for stardist library.")

        img = preprocessed_input.transpose(*input_axis_order).to_numpy()
        # Stardist uses "S" for the sample/batch axis and upper-case letters
        # for the remaining axes.
        labels, polys = imported_stardist_model.predict_instances(
            img,
            axes="".join([{"b": "S"}.get(a[0], a[0].capitalize()) for a in biomodel.inputs[0].axes]),
            n_tiles=n_tiles,
        )

        if len(labels.shape) == 2:  # batch dim got squeezed
            labels = labels[None]

        output_axes_wo_channels = tuple(a for a in biomodel.outputs[0].axes if a != "c")
        assert output_axes_wo_channels == tuple("byx")
        return xr.DataArray(labels, dims=output_axes_wo_channels), polys


def stardist_prediction_2d(
    model_rdf: Union[str, PathLike, dict, IO, bytes, raw_nodes.URI, RawResourceDescription],
    input_tensor: xr.DataArray,
    tile: Optional[Dict[str, int]] = None,
) -> Tuple[xr.DataArray, dict]:
    """stardist prediction 2d
    A workflow to apply a stardist model and the stardist postprocessing.
    This workflow is loosely based on https://nbviewer.org/github/stardist/stardist/blob/master/examples/2D/3_prediction.ipynb
    .. code-block:: yaml
    authors: [{name: Fynn Beuttenmüller, github_user: fynnbe}]
    cite:
    - text: BioImage.IO
    doi: 10.1101/2022.06.07.495102
    - text: "Stardist: Cell Detection with Star-Convex Polygons"
    doi: 10.1007/978-3-030-00934-2_30
    - text: "Stardist: Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy"
    doi: 10.1109/WACV45572.2020.9093435
    Args:
        model_rdf: the (source/raw) model RDF that describes the stardist model to be used for inference
        input_tensor: raw input
            axes:
            - type: batch
            - type: channel
            - type: space
              name: y
            - type: space
              name: x
        tile: Tile shape for model input. Defaults to no tiling. Currently ignored for preprocessing.
    Returns:
        labels. Labels of detected objects
            axes:
            - type: batch
            - type: space
              name: y
            - type: space
              name: x
        polys. Dictionary describing the labeled object's polygons
    """
    # todo: use inference_with_dask for model inference and then apply stardist postprocessing.
    # outputs = await inference_with_dask(model_rdf, input_tensor, boundary_mode=boundary_mode, enable_preprocessing=enable_preprocessing, enable_postprocessing=True, tiles=[tile])
    # assert len(outputs) == 1
    # output = outputs["output"]

    # Export the RDF into a zipped bioimage.io package, then import it as a
    # native stardist model inside a temporary directory. Everything below
    # stays inside the `with` block so the imported model files remain on
    # disk while prediction runs.
    package_path = export_resource_package(model_rdf)
    with tempfile.TemporaryDirectory() as tmp_dir:
        import_dir = Path(tmp_dir) / "import_dir"
        imported_stardist_model = stardist_import_bioimageio(package_path, import_dir)

        model = load_resource_description(package_path)
        assert isinstance(model, Model)
        if len(model.inputs) != 1:
            raise NotImplementedError("Multiple inputs for stardist models not yet implemented")

        if len(model.outputs) != 1:
            raise NotImplementedError("Multiple outputs for stardist models not yet implemented")

        # rename tensor axes to single letters to match model RDF
        #map_axes = {k: v for k, v in AXIS_NAME_TO_LETTER.items() if k in input_tensor.dims}
        # Bug fix: the original assigned the string "byxc" here, and
        # `DataArray.rename("byxc")` renames the *array itself*, not its
        # dims. The tensor is expected to arrive with single-letter dims
        # already, so no renaming is needed — mirror the sibling
        # stardist_prediction_2d_mine, which uses None (dead branch).
        map_axes = None
        if map_axes:
            input_tensor = input_tensor.rename(map_axes)

        # Apply the model's declared preprocessing to the input sample.
        prep = CombinedProcessing.from_tensor_specs(model.inputs)
        ipt_name = model.inputs[0].name
        sample = {ipt_name: input_tensor}
        computed_measures = compute_measures(prep.required_measures, sample=sample)
        prep.apply(sample, computed_measures)

        preprocessed_input = sample[ipt_name]
        #map_axes_back = {k: v for k, v in AXIS_LETTER_TO_NAME.items() if k in preprocessed_input.dims}
        # Same fix as map_axes above: None instead of the string "byxc".
        map_axes_back = None
        if map_axes_back:
            preprocessed_input = preprocessed_input.rename(map_axes_back)

        #input_axis_order = [AXIS_LETTER_TO_NAME.get(a, a) for a in model.inputs[0].axes]
        # Hard-coded axis order; assumes input dims are exactly b,y,x,c.
        input_axis_order = "byxc"
        if tile is None:
            n_tiles: Optional[List[int]] = None
        else:
            # Translate per-axis tile sizes into stardist tile *counts*.
            n_tiles = []
            for a in input_axis_order:
                t = tile[a]
                s = preprocessed_input.sizes[a]
                n_tiles.append(max(ceil(s / t), 1))

            warnings.warn(f"translated tile {tile} to n_tiles: {n_tiles} for stardist library.")

        img = preprocessed_input.transpose(*input_axis_order).to_numpy()
        # Stardist uses "S" for the sample/batch axis and upper-case letters
        # for the remaining axes.
        labels, polys = imported_stardist_model.predict_instances(
            img,
            axes="".join([{"b": "S"}.get(a[0], a[0].capitalize()) for a in model.inputs[0].axes]),
            n_tiles=n_tiles,
        )

        if len(labels.shape) == 2:  # batch dim got squeezed
            labels = labels[None]

        output_axes_wo_channels = tuple(a for a in model.outputs[0].axes if a != "c")
        assert output_axes_wo_channels == tuple("byx")
        return xr.DataArray(labels, dims=output_axes_wo_channels), polys

#arr = np.zeros((1, 208, 208, 3))
#xarr = xr.DataArray(arr, dims=["b", "y", "x", "c"], name="input")
#model_path = "chatty-frog"
#model_path = r'C:\Users\angel\OneDrive\Documentos\pasteur\git\model-runner-java\models\StarDist H&E Nuclei Segmentation_06092023_020924\rdf.yaml'
#stardist_prediction_2d_mine(model_path, xarr)
Loading

0 comments on commit 52be2e1

Please sign in to comment.