Commit

update docs for nnunet training
Karl5766 committed Dec 11, 2024
1 parent ca246a1 commit ff6a6b7
Showing 5 changed files with 135 additions and 21 deletions.
111 changes: 106 additions & 5 deletions docs/GettingStarted/nnunet.rst
@@ -6,7 +6,7 @@ nn-UNet
Overview
********

nn-UNet is a UNet based library designed to segment medical images, refer to
nn-UNet is a 2d/3d U-Net based library designed to segment medical images; refer to
`github <https://github.com/MIC-DKFZ/nnUNet>`_ and the following citation:

- Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring
@@ -100,8 +100,9 @@ In this next part, we discuss the annotation (part 2), training (part 3) and pre
Annotation
**********

Data quality is the most crucial to accurate predictions, in which case this is relevant to us in terms of how
well we can annotate the 3d image volume at hand. Our annotation is the negative masking of edge areas of the
Data quality is the most crucial factor for accurate predictions when training supervised models; for us, this
comes down to how well we can annotate the 3d image volumes at hand.
Our annotation is the negative masking of edge areas of the
brain, used to remove these edges before applying simple thresholding. We assess how good an annotation of the
negative mask is by looking at:

@@ -114,7 +115,8 @@ looking at:
3. The number of voxels covered by brain edge areas above threshold t, how many of them are correctly annotated
as 1, and how many are incorrectly annotated as 0

these metrics are best summarized as IOU or DICE scores. We look at an example segmentation below.
These metrics are best summarized as IOU or DICE scores. A DICE score curve is automatically generated by
nn-UNet during the training process. We look at an example segmentation below.
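
As an aside, both scores are straightforward to compute directly for a pair of binary masks. Below is a
minimal Numpy sketch (illustrative only, not part of :code:`cvpl_tools`):

.. code-block:: Python

    import numpy as np

    def dice_and_iou(pred: np.ndarray, gt: np.ndarray) -> tuple[float, float]:
        """Compute DICE and IoU between two binary masks of the same shape (not both empty)."""
        pred, gt = pred.astype(bool), gt.astype(bool)
        intersection = np.logical_and(pred, gt).sum()
        dice = 2 * intersection / (pred.sum() + gt.sum())  # 2|A∩B| / (|A| + |B|)
        iou = intersection / np.logical_or(pred, gt).sum()  # |A∩B| / |A∪B|
        return float(dice), float(iou)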

.. figure:: ../assets/mb_unmasked.png
    :alt: Slice of mouse brain, unsegmented
@@ -129,4 +131,103 @@
Here the algorithm, as intended, marks not only the outer edges of the brain but also some of the brighter inner
structures as edge areas to be removed, since they can't be plaques. The bright spots on the upper left of the images
are left as is, for they are all plaques. Overall, the annotation requires quite a bit of labour, and it is
preferable to obtain one high quality annotated volume rather than many low quality ones.

In :code:`cvpl_tools`, the annotation is done using a Napari based GUI with a 2d cross-sectional viewer and
ball-shaped paint brush. Follow the steps below to get started:

1. In a Python script, prepare the image you would like to annotate, :code:`im_annotate`, in Numpy array format,
which may require downsampling the original image:

.. code-block:: Python

    import numpy as np
    import cvpl_tools.nnunet.lightsheet_preprocess as lightsheet_preprocess

    # original image is, say, an OME ZARR image of size (3, 1610, 9653, 9634)
    OME_ZARR_PATH = 'gcs://khanlab-lightsheet/data/mouse_appmaptapoe/bids/sub-F4A1Te3/micr/sub-F4A1Te3_sample-brain_acq-blaze4x_SPIM.ome.zarr'
    BA_CHANNEL = 0  # only the first channel is relevant to Beta-Amyloid detection
    FIRST_DOWNSAMPLE_PATH = 'o22/first_downsample.ome.zarr'  # path the downsampled image will be saved to
    first_downsample = lightsheet_preprocess.downsample(
        OME_ZARR_PATH, reduce_fn=np.max, ndownsample_level=(1, 2, 2), ba_channel=BA_CHANNEL,
        write_loc=FIRST_DOWNSAMPLE_PATH
    )
    print(f'Shape of image after downsampling: {first_downsample.shape}')

Ideally the downsampled image should also go through n4 bias correction before the next step.
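
The bias correction helper used by this repository lives in :code:`cvpl_tools.nnunet.n4`, whose exact API is
not shown on this page; as a hypothetical stand-in, here is a minimal N4 sketch using SimpleITK's built-in
filter:

.. code-block:: Python

    import SimpleITK as sitk

    im = sitk.GetImageFromArray(first_downsample.compute().astype('float32'))
    mask = sitk.OtsuThreshold(im, 0, 1)  # rough foreground mask for the corrector
    corrected = sitk.N4BiasFieldCorrectionImageFilter().Execute(im, mask)
    first_downsample_corr = sitk.GetArrayFromImage(corrected)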

2. Next, convert the image you just downsampled to a numpy array, and use the :code:`annotate` function to add
layers to a napari viewer and start annotating:

.. code-block:: Python

    from cvpl_tools.nnunet.annotate import annotate
    import cvpl_tools.ome_zarr.io as ome_io  # only needed if you reload the downsampled image from disk
    import napari

    viewer = napari.Viewer(ndisplay=2)
    im_annotate = first_downsample.compute()  # this is a numpy array, to be annotated
    ndownsample_level = (1, 1, 1)  # downsample by 2 ^ 1 on three axes
    # image layer and canvas layer will be added here
    annotate(viewer, im_annotate, 'o22/annotated.tiff', ndownsample_level)
    viewer.show(block=True)

Note that saving is manual: press :code:`ctrl+shift+s` to save what's annotated (which creates the tiff
file "o22/annotated.tiff"). :code:`im_annotate` is the lightsheet image first corrected for bias,
then downsampled by levels (1, 2, 2), i.e. a factor of (2, 4, 4), in three directions, to a size
that can be conveniently displayed locally, in real time and without latency.

In this example, we choose to use a binary annotation volume that is (2, 2, 2) times smaller than the
image being annotated in all three directions. This saves space during data transfer. Later, nn-UNet will
also need an image of the same shape as the annotation, so we want to keep a further downsampled image
file that matches the annotation's size. We will see this in the training section below.
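
For reference, the canvas shape follows from the annotated image's shape by integer division, which is the
same computation :code:`annotate` performs internally (see the annotate.py diff further down this page):

.. code-block:: Python

    # each downsample level halves the corresponding axis (integer division)
    canvas_shape = tuple(im_annotate.shape[i] // (2 ** ndownsample_level[i]) for i in range(3))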

3. Due to the large image size, you may need multiple sessions to completely annotate one
scan. Simply run the same code as in step 2, which will automatically load the existing annotation
back up; you can then overwrite the old tiff file with the updated annotation by, again, pressing :code:`ctrl+shift+s`.

Training
********

In the above annotation phase, we obtained two datasets: one is the annotated tiff volume at path
:code:`'o22/annotated.tiff'`, the other is the downsampled image at path :code:`'o22/first_downsample.ome.zarr'`. We
will use the latter as the training images and the former as the training labels for nn-UNet training.
Here the images need to be downsampled once more in order to match the image and label volume shapes:

.. code-block:: Python

    import numpy as np
    import cvpl_tools.nnunet.lightsheet_preprocess as lightsheet_preprocess

    FIRST_DOWNSAMPLE_PATH = 'o22/first_downsample.ome.zarr'  # saved during the annotation step above
    SECOND_DOWNSAMPLE_PATH = 'o22/second_downsample.ome.zarr'
    second_downsample = lightsheet_preprocess.downsample(
        FIRST_DOWNSAMPLE_PATH, reduce_fn=np.max, ndownsample_level=(1, 1, 1), ba_channel=BA_CHANNEL,  # BA_CHANNEL as before
        write_loc=SECOND_DOWNSAMPLE_PATH
    )

Next, we feed the images to nn-UNet for training. This requires a torch installation and a GPU on the
machine.
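
A quick sanity check that torch can see the GPU before launching a long training run:

.. code-block:: Python

    import torch
    assert torch.cuda.is_available(), 'nn-UNet training requires a CUDA-capable GPU'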

.. code-block:: Python

    import cvpl_tools.nnunet.triplanar as triplanar

    train_args = {
        "cache_url": 'nnunet_trained',  # path to which training files and the trained model will be saved
        "train_im": SECOND_DOWNSAMPLE_PATH,  # image
        "train_seg": 'o22/annotated.tiff',  # label
        "nepoch": 250,
        "stack_channels": 0,
        "triplanar": False,
        "dataset_id": 1,
        "fold": '0',
        "max_threshold": 7500.,
    }
    triplanar.train_triplanar(train_args)
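
Of these arguments, :code:`fold` is passed through to nn-UNet, which uses it to select the cross-validation
fold to train; :code:`nepoch` presumably caps the number of training epochs. :code:`triplanar` and
:code:`stack_channels` are specific to :code:`cvpl_tools`; see the :code:`cvpl_tools.nnunet.triplanar`
module before changing them.
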
Prediction
**********

TODO
12 changes: 6 additions & 6 deletions src/cvpl_tools/examples/mousebrain_processing.py
@@ -99,25 +99,25 @@ def get_subject(SUBJECT_ID):

def main(subject: Subject, run_nnunet: bool = True, run_coiled_process: bool = True):
    import numpy as np
    import cvpl_tools.nnunet.lightsheet_preprocess as current_im_py
    import cvpl_tools.nnunet.lightsheet_preprocess as lightsheet_preprocess
    import cvpl_tools.nnunet.n4 as n4
    import cvpl_tools.ome_zarr.io as ome_io
    import cvpl_tools.im.algs.dask_ndinterp as dask_ndinterp
    import asyncio
    import cvpl_tools.nnunet.triplanar as triplanar

    print(f'first downsample: from path {subject.OME_ZARR_PATH}')
    first_downsample = current_im_py.downsample(
    first_downsample = lightsheet_preprocess.downsample(
        subject.OME_ZARR_PATH, reduce_fn=np.max, ndownsample_level=(1, 2, 2), ba_channel=subject.BA_CHANNEL,
        write_loc=subject.FIRST_DOWNSAMPLE_PATH
    )
    print(f'first downsample done. result is of shape {first_downsample.shape}')

    second_downsample = current_im_py.downsample(
    second_downsample = lightsheet_preprocess.downsample(
        first_downsample, reduce_fn=np.max, ndownsample_level=(1,) * 3,
        write_loc=subject.SECOND_DOWNSAMPLE_PATH
    )
    third_downsample = current_im_py.downsample(
    third_downsample = lightsheet_preprocess.downsample(
        second_downsample, reduce_fn=np.max, ndownsample_level=(1,) * 3,
        write_loc=subject.THIRD_DOWNSAMPLE_PATH
    )
@@ -131,15 +131,15 @@ def main(subject: Subject, run_nnunet: bool = True, run_coiled_process: bool = T
    second_downsample_bias = dask_ndinterp.scale_nearest(third_downsample_bias, scale=(2, 2, 2),
                                                         output_shape=second_downsample.shape, output_chunks=(4, 4096, 4096)).persist()

    second_downsample_corr = current_im_py.apply_bias(second_downsample, (1,) * 3, second_downsample_bias, (1,) * 3)
    second_downsample_corr = lightsheet_preprocess.apply_bias(second_downsample, (1,) * 3, second_downsample_bias, (1,) * 3)
    asyncio.run(ome_io.write_ome_zarr_image(subject.SECOND_DOWNSAMPLE_CORR_PATH, da_arr=second_downsample_corr, MAX_LAYER=1))
    print('second downsample corrected image done')

    # first_downsample_correct_path = f'C:/Users/than83/Documents/progtools/datasets/lightsheet_downsample/sub-{SUBJECT_ID}_corrected.ome.zarr'
    # first_downsample_bias = dask_ndinterp.scale_nearest(third_downsample_bias, scale=(4, 4, 4),
    #                                                     output_shape=first_downsample.shape,
    #                                                     output_chunks=(4, 4096, 4096)).persist()
    # first_downsample_corr = current_im_py.apply_bias(first_downsample, (1,) * 3, first_downsample_bias, (1,) * 3)
    # first_downsample_corr = lightsheet_preprocess.apply_bias(first_downsample, (1,) * 3, first_downsample_bias, (1,) * 3)
    # asyncio.run(ome_io.write_ome_zarr_image(first_downsample_correct_path, da_arr=first_downsample_corr, MAX_LAYER=2))

    if run_nnunet is False:
4 changes: 2 additions & 2 deletions src/cvpl_tools/examples/mousebrain_processing_inspect.py
@@ -112,11 +112,11 @@ def annotate_neg_mask(SUBJECT_ID):
    print('second downsample corrected image done')

    viewer = napari.Viewer(ndisplay=2)
    canvas_shape = ome_io.load_dask_array_from_path(subject.SECOND_DOWNSAMPLE_PATH, mode='r', level=0).shape
    annotate.annotate(viewer,
                      first_downsample_corr,
                      annotation_folder=subject.SUBJECT_FOLDER,
                      canvas_path=subject.NNUNET_OUTPUT_TIFF_PATH,
                      SUBJECT_ID=subject.SUBJECT_ID)
                      canvas_shape=canvas_shape)
    viewer.show(block=True)


27 changes: 20 additions & 7 deletions src/cvpl_tools/nnunet/annotate.py
@@ -24,19 +24,32 @@ def get_canvas(canvas_path, canvas_ref_path, canvas_shape):
    return canvas


def annotate(viewer, im_annotate, annotation_folder, canvas_path, SUBJECT_ID: str):
    """
    usage:
    import cvpl_tools.nnunet.annotate as ann
    ann.annotate()
def annotate(viewer, im_annotate, canvas_path, ndownsample_level: int | tuple = None):
    """Produce a new annotation mask, or fix an existing one if one already exists under canvas_path

    canvas refers to the annotated binary mask, which is painted manually using 2d/3d brushes in the Napari viewer.

    Args:
        viewer: napari.Viewer object that will be responsible for GUI display and annotation using the paint brush
        im_annotate: Single channel 3d image volume to be annotated
        canvas_path: Path to which the annotated tiff file will be or has been saved
        ndownsample_level: Downsample level from im_annotate to the canvas; an integer or a tuple of 3 integers
    """

    import magicgui
    import cvpl_tools.nnunet.lightsheet_preprocess as lightsheet_preprocess

    im_layer = viewer.add_image(im_annotate, name='im', **lightsheet_preprocess.calc_tr_sc_args(voxel_scale=(1,) * 3, display_shape=im_annotate.shape))

    canvas = get_canvas(canvas_path, None, im_annotate.shape)
    canvas_layer = viewer.add_labels(canvas, name='canvas', **lightsheet_preprocess.calc_tr_sc_args(voxel_scale=(2,) * 3, display_shape=im_annotate.shape))
    if ndownsample_level is not None:
        if isinstance(ndownsample_level, int):
            ndownsample_level = (ndownsample_level,) * 3
        # the canvas is smaller than im_annotate by a factor of 2 ** level along each axis
        canvas_shape = tuple(im_annotate.shape[i] // (2 ** ndownsample_level[i]) for i in range(3))
    else:
        canvas_shape = im_annotate.shape
    canvas = get_canvas(canvas_path, None, canvas_shape)
    canvas_layer = viewer.add_labels(canvas, name='canvas', **lightsheet_preprocess.calc_tr_sc_args(
        voxel_scale=(2,) * 3, display_shape=im_annotate.shape))

    for path in tuple(
            # 'C:/Users/than83/Documents/progtools/datasets/annotated/canvas_o22_ref.tiff',
2 changes: 1 addition & 1 deletion src/cvpl_tools/nnunet/cli.py
@@ -63,7 +63,7 @@ def parse_args():
"fold": args.fold,
"max_threshold": args.max_threshold,
}
model = triplanar.train_triplanar(train_args)
triplanar.train_triplanar(train_args)

elif args.command == "predict":
pred_args = {
