From ff6a6b72752250ff9854490a1208eb89af8b7cbc Mon Sep 17 00:00:00 2001 From: Karl5766 Date: Wed, 11 Dec 2024 16:01:29 -0500 Subject: [PATCH] update docs for nnunet training --- docs/GettingStarted/nnunet.rst | 111 +++++++++++++++++- .../examples/mousebrain_processing.py | 12 +- .../examples/mousebrain_processing_inspect.py | 4 +- src/cvpl_tools/nnunet/annotate.py | 27 +++-- src/cvpl_tools/nnunet/cli.py | 2 +- 5 files changed, 135 insertions(+), 21 deletions(-) diff --git a/docs/GettingStarted/nnunet.rst b/docs/GettingStarted/nnunet.rst index 3465b18..5bd3f51 100644 --- a/docs/GettingStarted/nnunet.rst +++ b/docs/GettingStarted/nnunet.rst @@ -6,7 +6,7 @@ nn-UNet Overview ******** -nn-UNet is a UNet based library designed to segment medical images, refer to +nn-UNet is a 2d/3d U-NET library designed to segment medical images, refer to `github `_ and the following citation: - Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring @@ -100,8 +100,9 @@ In this next part, we discuss the annotation (part 2), training (part 3) and pre Annotation ********** -Data quality is the most crucial to accurate predictions, in which case this is relevant to us in terms of how -well we can annotate the 3d image volume at hand. Our annotation is the negative masking of edge areas of the +Data quality is the most crucial to accurate predictions when training supervised models, in which case this is +relevant to us in terms of how well we can annotate 3d image volumes at hand. +Our annotation is the negative masking of edge areas of the brain to remove edges before applying simple thresholding. We model how good an annotation of negative mask by looking at: @@ -114,7 +115,8 @@ looking at: 3. The number of voxels covered by brain edge areas above threshold t, and how many of them are correctly annotated as 1, and how many of them are incorrectly annotated as 0 -these metrics are best summarized as IOU or DICE scores. We look at an example segmentation below. +these metrics are best summarized as IOU or DICE scores. A DICE score curve can be obtained in training process, +automatically generated by nn-UNet. We look at an example segmentation below. .. figure:: ../assets/mb_unmasked.png :alt: Slice of mouse brain, unsegmented @@ -129,4 +131,103 @@ these metrics are best summarized as IOU or DICE scores. We look at an example s Here the algorithm, as intended, marks not only the outer edges of the brain but also some of the brighter inner structures as edge areas to be removed, since they can't be plaques. The bright spots on the upper left of the images are left as is, for they are all plaques. Overall, the annotation requires quite a bit of labour and it is preferred -to obtain a high quality annotated volume over many low quality ones. \ No newline at end of file +to obtain a high quality annotated volume over many low quality ones. + +In :code:`cvpl_tools`, the annotation is done using a Napari based GUI with a 2d cross-sectional viewer and +ball-shaped paint brush. Follow the following steps to get started: + +1. In a Python script, prepare an image you would like to annotate :code:`im_annotate` in Numpy array format, + which may requires downsample the original image: + +.. code-block:: Python + + import cvpl_tools.nnunet.lightsheet_preprocess as lightsheet_preprocess + + # original image is, say, an OME ZARR image of size (3, 1610, 9653, 9634) + OME_ZARR_PATH = 'gcs://khanlab-lightsheet/data/mouse_appmaptapoe/bids/sub-F4A1Te3/micr/sub-F4A1Te3_sample-brain_acq-blaze4x_SPIM.ome.zarr' + BA_CHANNEL = 0 # only the first channel is relevant to Beta-Amyloid detection + + FIRST_DOWNSAMPLE_PATH = 'o22/first_downsample.ome.zarr' # path to be saved + first_downsample = lightsheet_preprocess.downsample( + OME_ZARR_PATH, reduce_fn=np.max, ndownsample_level=(1, 2, 2), ba_channel=BA_CHANNEL, + write_loc=FIRST_DOWNSAMPLE_PATH + ) + print(f'Shape of image after downsampling: {first_downsample.shape}') + +Ideally the downsampled image should also go through n4 bias correction before the next step. + +2. Next, convert the image you just downsampled to a numpy array, and use :code:`annotate` function to add + layers to a napari viewer and start annotation: + +.. code-block:: Python + + from cvpl_tools.nnunet.annotate import annotate + import cvpl_tools.ome_zarr.io as ome_io + import napari + + viewer = napari.Viewer(ndisplay=2) + im_annotate = first_downsample.compute() # this is a numpy array, to be annotated + ndownsample_level = (1, 1, 1) # downsample by 2 ^ 1 on three axes + + # image layer and canvas layer will be added here + annotate(viewer, im_annotate, 'o22/annotated.tiff', ndownsample_level) + + viewer.show(block=True) + +Note saving is manual, press :code:`ctrl+shift+s` to save what's annotated (which creates a tiff +file "o22/annotated.tiff"). :code:`im_annotate` is lightsheet image first corrected by bias, +then downsampled by levels (1, 2, 2) i.e. a factor of (2, 4, 4) in three directions to a size +that can be conveniently displayed locally, in real-time and without latency. + +In this example, we choose to use a binary annotation volume of shape (2, 2, 2) times smaller than the +original image in all three directions. This is to save space during data transfer. Later nn-UNet will +also need image of same shape as the annotation, so we also want to keep a further downsampled image +file that is the same size as the annotation. We will see this in the training section below. + +3. Due to the large image size, you may need multiple sessions in order to completely annotate one + scan. This can be done by running the same code in step 2, which will automatically load the annotation + back up, and you can overwrite the old tiff file with updated annotation by, again, :code:`ctrl+shift+s` + +Training +******** + +In the above annotation phase, we obtained two dataset: one is the annotated tiff volume at path +:code:`'o22/annotated.tiff'`, the other is the downsampled image at path 'o22/first_downsample.ome.zarr'. We +will use the latter as the training images and the former as the training labels for nn-UNet training. +Here the images need to be once further downsampled in order to match image and label volume shapes: + +.. code-block:: Python + + import cvpl_tools.nnunet.lightsheet_preprocess as lightsheet_preprocess + + FIRST_DOWNSAMPLE_PATH = 'o22/first_downsample.ome.zarr' # path to be saved + SECOND_DOWNSAMPLE_PATH = 'o22/second_downsample.ome.zarr' + second_downsample = lightsheet_preprocess.downsample( + FIRST_DOWNSAMPLE_PATH, reduce_fn=np.max, ndownsample_level=(1, 1, 1), ba_channel=BA_CHANNEL, + write_loc=SECOND_DOWNSAMPLE_PATH + ) + +Next, we feed the images to nn-UNet for training. This requires torch installation and a GPU on the +computer. + +.. code-block:: Python + + import cvpl_tools.nnunet.triplanar as triplanar + + train_args = { + "cache_url": 'nnunet_trained', # this is the path to which training files and trained model will be saved + "train_im": SECOND_DOWNSAMPLE_PATH, # image + "train_seg": 'o22/annotated.tiff', # label + "nepoch": 250, + "stack_channels": 0, + "triplanar": False, + "dataset_id": 1, + "fold": '0', + "max_threshold": 7500., + } + triplanar.train_triplanar(train_args) + +Prediction +********** + +TODO diff --git a/src/cvpl_tools/examples/mousebrain_processing.py b/src/cvpl_tools/examples/mousebrain_processing.py index 7f9c3b0..53f62da 100644 --- a/src/cvpl_tools/examples/mousebrain_processing.py +++ b/src/cvpl_tools/examples/mousebrain_processing.py @@ -99,7 +99,7 @@ def get_subject(SUBJECT_ID): def main(subject: Subject, run_nnunet: bool = True, run_coiled_process: bool = True): import numpy as np - import cvpl_tools.nnunet.lightsheet_preprocess as current_im_py + import cvpl_tools.nnunet.lightsheet_preprocess as lightsheet_preprocess import cvpl_tools.nnunet.n4 as n4 import cvpl_tools.ome_zarr.io as ome_io import cvpl_tools.im.algs.dask_ndinterp as dask_ndinterp @@ -107,17 +107,17 @@ def main(subject: Subject, run_nnunet: bool = True, run_coiled_process: bool = T import cvpl_tools.nnunet.triplanar as triplanar print(f'first downsample: from path {subject.OME_ZARR_PATH}') - first_downsample = current_im_py.downsample( + first_downsample = lightsheet_preprocess.downsample( subject.OME_ZARR_PATH, reduce_fn=np.max, ndownsample_level=(1, 2, 2), ba_channel=subject.BA_CHANNEL, write_loc=subject.FIRST_DOWNSAMPLE_PATH ) print(f'first downsample done. result is of shape {first_downsample.shape}') - second_downsample = current_im_py.downsample( + second_downsample = lightsheet_preprocess.downsample( first_downsample, reduce_fn=np.max, ndownsample_level=(1,) * 3, write_loc=subject.SECOND_DOWNSAMPLE_PATH ) - third_downsample = current_im_py.downsample( + third_downsample = lightsheet_preprocess.downsample( second_downsample, reduce_fn=np.max, ndownsample_level=(1,) * 3, write_loc=subject.THIRD_DOWNSAMPLE_PATH ) @@ -131,7 +131,7 @@ def main(subject: Subject, run_nnunet: bool = True, run_coiled_process: bool = T second_downsample_bias = dask_ndinterp.scale_nearest(third_downsample_bias, scale=(2, 2, 2), output_shape=second_downsample.shape, output_chunks=(4, 4096, 4096)).persist() - second_downsample_corr = current_im_py.apply_bias(second_downsample, (1,) * 3, second_downsample_bias, (1,) * 3) + second_downsample_corr = lightsheet_preprocess.apply_bias(second_downsample, (1,) * 3, second_downsample_bias, (1,) * 3) asyncio.run(ome_io.write_ome_zarr_image(subject.SECOND_DOWNSAMPLE_CORR_PATH, da_arr=second_downsample_corr, MAX_LAYER=1)) print('second downsample corrected image done') @@ -139,7 +139,7 @@ def main(subject: Subject, run_nnunet: bool = True, run_coiled_process: bool = T # first_downsample_bias = dask_ndinterp.scale_nearest(third_downsample_bias, scale=(4, 4, 4), # output_shape=first_downsample.shape, # output_chunks=(4, 4096, 4096)).persist() - # first_downsample_corr = current_im_py.apply_bias(first_downsample, (1,) * 3, first_downsample_bias, (1,) * 3) + # first_downsample_corr = lightsheet_preprocess.apply_bias(first_downsample, (1,) * 3, first_downsample_bias, (1,) * 3) # asyncio.run(ome_io.write_ome_zarr_image(first_downsample_correct_path, da_arr=first_downsample_corr, MAX_LAYER=2)) if run_nnunet is False: diff --git a/src/cvpl_tools/examples/mousebrain_processing_inspect.py b/src/cvpl_tools/examples/mousebrain_processing_inspect.py index f09613b..77becbb 100644 --- a/src/cvpl_tools/examples/mousebrain_processing_inspect.py +++ b/src/cvpl_tools/examples/mousebrain_processing_inspect.py @@ -112,11 +112,11 @@ def annotate_neg_mask(SUBJECT_ID): print('second downsample corrected image done') viewer = napari.Viewer(ndisplay=2) + canvas_shape = ome_io.load_dask_array_from_path(subject.SECOND_DOWNSAMPLE_PATH, mode='r', level=0).shape annotate.annotate(viewer, first_downsample_corr, - annotation_folder=subject.SUBJECT_FOLDER, canvas_path=subject.NNUNET_OUTPUT_TIFF_PATH, - SUBJECT_ID=subject.SUBJECT_ID) + canvas_shape=canvas_shape) viewer.show(block=True) diff --git a/src/cvpl_tools/nnunet/annotate.py b/src/cvpl_tools/nnunet/annotate.py index c1df1f8..a073564 100644 --- a/src/cvpl_tools/nnunet/annotate.py +++ b/src/cvpl_tools/nnunet/annotate.py @@ -24,19 +24,32 @@ def get_canvas(canvas_path, canvas_ref_path, canvas_shape): return canvas -def annotate(viewer, im_annotate, annotation_folder, canvas_path, SUBJECT_ID: str): - """ - usage: - import cvpl_tools.nnunet.annotate as ann - ann.annotate() +def annotate(viewer, im_annotate, canvas_path, ndownsample_level: int | tuple = None): + """Produce a new annotation mask, or fix an existing one if one already exists under canvas_path + + canvas refers to the annotated binary mask, which is painted manually using 2d/3d brushes in Napari viewer. + + Args: + viewer: napari.Viewer object that will be responsible for GUI display, and annotation using paint brush + im_annotate: Single channel 3d image volume to be annotated + canvas_path: Path to which the annotated tiff file will be or has been saved to + ndownsample_level: Downsample level from im_annotate to the canvas; integer or tuple of 3 integers """ + import magicgui import cvpl_tools.nnunet.lightsheet_preprocess as lightsheet_preprocess im_layer = viewer.add_image(im_annotate, name='im', **lightsheet_preprocess.calc_tr_sc_args(voxel_scale=(1,) * 3, display_shape=im_annotate.shape)) - canvas = get_canvas(canvas_path, None, im_annotate.shape) - canvas_layer = viewer.add_labels(canvas, name='canvas', **lightsheet_preprocess.calc_tr_sc_args(voxel_scale=(2,) * 3, display_shape=im_annotate.shape)) + if ndownsample_level is not None: + if isinstance(ndownsample_level, int): + ndownsample_level = (ndownsample_level,) * 3 + canvas_shape = tuple(im_annotate.shape[i] // (2 ** ndownsample_level[i]) for i in range(3)) + else: + canvas_shape = im_annotate.shape + canvas = get_canvas(canvas_path, None, canvas_shape) + canvas_layer = viewer.add_labels(canvas, name='canvas', **lightsheet_preprocess.calc_tr_sc_args( + voxel_scale=(2,) * 3, display_shape=im_annotate.shape)) for path in tuple( # 'C:/Users/than83/Documents/progtools/datasets/annotated/canvas_o22_ref.tiff', diff --git a/src/cvpl_tools/nnunet/cli.py b/src/cvpl_tools/nnunet/cli.py index 28e6180..753407a 100644 --- a/src/cvpl_tools/nnunet/cli.py +++ b/src/cvpl_tools/nnunet/cli.py @@ -63,7 +63,7 @@ def parse_args(): "fold": args.fold, "max_threshold": args.max_threshold, } - model = triplanar.train_triplanar(train_args) + triplanar.train_triplanar(train_args) elif args.command == "predict": pred_args = {