Merge pull request #112 from wolny/prediction_options
Add 'save_segmentation' option to prediction
wolny authored Apr 9, 2024
2 parents 63a9e63 + a468a34 commit 7a2c00b
Showing 5 changed files with 126 additions and 77 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -91,7 +91,9 @@ predict3dunet --config <CONFIG>
In order to predict on your own data, just provide the path to your model as well as paths to HDF5 test files (see example [test_config_segmentation.yaml](resources/3DUnet_confocal_boundary/test_config.yml)).

### Prediction tips
In order to avoid patch boundary artifacts in the output prediction masks the patch predictions are averaged, so make sure that `patch/stride` params lead to overlapping blocks, e.g. `patch: [64, 128, 128] stride: [32, 96, 96]` will give you a 'halo' of 32 voxels in each direction.
1. To avoid patch boundary artifacts in the output prediction masks, the patch predictions are averaged, so make sure that the `patch/stride` params lead to overlapping blocks, e.g. `patch: [64, 128, 128] stride: [32, 96, 96]` will give you a 'halo' of 32 voxels in each direction.
2. If your model predicts multiple classes (see e.g. [train_config_multiclass](resources/3DUnet_multiclass/train_config.yaml)), consider saving only the final segmentation instead of the probability maps, which can be time- and space-consuming.
To do so, set `save_segmentation: true` in the `predictor` section of the config (see [test_config_multiclass](resources/3DUnet_multiclass/test_config.yaml) and the sketch below).
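
Putting both tips together, a prediction config could look roughly like the sketch below; the file paths, `output_dir` and the exact `slice_builder` layout are illustrative assumptions, not taken from this PR:

```yaml
# illustrative sketch only -- adapt model settings and paths to your own data
predictor:
  name: 'StandardPredictor'
  # store the argmax segmentation instead of per-class probability maps
  save_segmentation: true
loaders:
  output_dir: 'predictions'        # hypothetical output location
  test:
    file_paths:
      - 'path/to/test_volume.h5'   # placeholder path
    slice_builder:
      name: SliceBuilder
      # overlapping patches give a 32-voxel halo in each direction
      patch_shape: [64, 128, 128]
      stride_shape: [32, 96, 96]
```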

## Data Parallelism
By default, if multiple GPUs are available training/prediction will be run on all the GPUs using [DataParallel](https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html).
19 changes: 10 additions & 9 deletions pytorch3dunet/predict.py
@@ -12,14 +12,20 @@
logger = utils.get_logger('UNet3DPredict')


def get_predictor(model, output_dir, config):
def get_predictor(model, config):
output_dir = config['loaders'].get('output_dir', None)
# override output_dir if provided in the 'predictor' section of the config
output_dir = config.get('predictor', {}).get('output_dir', output_dir)
if output_dir is not None:
os.makedirs(output_dir, exist_ok=True)

predictor_config = config.get('predictor', {})
class_name = predictor_config.get('name', 'StandardPredictor')

m = importlib.import_module('pytorch3dunet.unet3d.predictor')
predictor_class = getattr(m, class_name)

return predictor_class(model, output_dir, config, **predictor_config)
out_channels = config['model'].get('out_channels')
return predictor_class(model, output_dir, out_channels, **predictor_config)
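
For reference, a minimal sketch of the config keys the refactored `get_predictor` reads; the concrete values are placeholders, not part of this PR:

```python
# sketch of the config sections get_predictor consumes (all values are placeholders)
config = {
    'model': {'out_channels': 2},               # used to size the prediction maps
    'loaders': {'output_dir': '/tmp/preds'},    # default output directory
    'predictor': {
        # class name resolved dynamically from pytorch3dunet.unet3d.predictor
        'name': 'StandardPredictor',
        # remaining keys are forwarded to the predictor as keyword arguments
        'save_segmentation': True,
    },
}
# predictor = get_predictor(model, config)     # assuming a trained `model` is available
```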


def main():
@@ -41,13 +47,8 @@ def main():
if torch.cuda.is_available() and not config['device'] == 'cpu':
model = model.cuda()

output_dir = config['loaders'].get('output_dir', None)
if output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
logger.info(f'Saving predictions to: {output_dir}')

# create predictor instance
predictor = get_predictor(model, output_dir, config)
predictor = get_predictor(model, config)

for test_loader in get_test_loaders(config):
# run the model prediction on the test_loader and save the results in the output_dir
151 changes: 88 additions & 63 deletions pytorch3dunet/unet3d/predictor.py
@@ -1,6 +1,7 @@
import os
import time
from concurrent import futures
from pathlib import Path

import h5py
import numpy as np
@@ -21,12 +22,8 @@ def _get_output_file(dataset, suffix='_predictions', output_dir=None):
input_dir, file_name = os.path.split(dataset.file_path)
if output_dir is None:
output_dir = input_dir
output_file = os.path.join(output_dir, os.path.splitext(file_name)[0] + suffix + '.h5')
return output_file


def _get_dataset_name(config, prefix='predictions'):
return config.get('dest_dataset_name', 'predictions')
output_filename = os.path.splitext(file_name)[0] + suffix + '.h5'
return Path(output_dir) / output_filename


def _is_2d_model(model):
@@ -36,11 +33,35 @@ def _is_2d_model(model):


class _AbstractPredictor:
def __init__(self, model, output_dir, config, **kwargs):
def __init__(self,
model: nn.Module,
output_dir: str,
out_channels: int,
output_dataset: str = 'predictions',
save_segmentation: bool = False,
prediction_channel: int = None,
patch_halo: tuple[int, int, int] = (4, 4, 4),
**kwargs):
"""
Base class for predictors.
Args:
model: segmentation model
output_dir: directory where the predictions will be saved
out_channels: number of output channels of the model
output_dataset: name of the dataset in the H5 file where the predictions will be saved
save_segmentation: if true the segmentation will be saved instead of the probability maps
prediction_channel: save only the specified channel from the network output
"""
self.model = model
self.output_dir = output_dir
self.config = config
self.predictor_config = kwargs
self.out_channels = out_channels
self.output_dataset = output_dataset
self.save_segmentation = save_segmentation
self.prediction_channel = prediction_channel
# every patch will be mirror-padded with the following halo
self.patch_halo = list(patch_halo)
if _is_2d_model(self.model):
self.patch_halo[0] = 0
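
As a side note, here is a small numpy sketch of the mirror-pad / un-pad round trip that `patch_halo` describes; the repository's actual `_pad`/`_unpad` operate on torch tensors, so treat this as an illustration only:

```python
import numpy as np

# hypothetical patch and halo, matching the default patch_halo=(4, 4, 4)
patch_halo = (4, 4, 4)                         # (z, y, x); the z halo is zeroed for 2D models
patch = np.random.rand(1, 1, 64, 128, 128)     # (N, C, D, H, W)

hz, hy, hx = patch_halo
# mirror-pad the spatial dimensions by the halo before the forward pass
padded = np.pad(patch, ((0, 0), (0, 0), (hz, hz), (hy, hy), (hx, hx)), mode='reflect')
assert padded.shape == (1, 1, 64 + 2 * hz, 128 + 2 * hy, 128 + 2 * hx)

# after prediction the halo is cropped away again, so the result aligns with the
# original patch coordinates and boundary artifacts stay inside the discarded halo
unpadded = padded[..., hz:-hz or None, hy:-hy or None, hx:-hx or None]
assert unpadded.shape == patch.shape
```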

@staticmethod
def volume_shape(dataset):
@@ -60,45 +81,33 @@ class StandardPredictor(_AbstractPredictor):
Predictions from the network are kept in memory. If the results from the network don't fit into RAM
use `LazyPredictor` instead.
The output dataset names inside the H5 is given by `dest_dataset_name` config argument. If the argument is
not present in the config 'predictions' is used as a default dataset name.
Args:
model (Unet3D): trained 3D UNet model used for prediction
output_dir (str): path to the output directory (optional)
config (dict): global config dict
The output dataset name inside the H5 is given by the `output_dataset` config argument.
"""

def __init__(self, model, output_dir, config, **kwargs):
super().__init__(model, output_dir, config, **kwargs)
def __init__(self,
model: nn.Module,
output_dir: str,
out_channels: int,
output_dataset: str = 'predictions',
save_segmentation: bool = False,
prediction_channel: int = None,
**kwargs):
super().__init__(model, output_dir, out_channels, output_dataset, save_segmentation, prediction_channel,
**kwargs)

def __call__(self, test_loader):
assert isinstance(test_loader.dataset, AbstractHDF5Dataset)
logger.info(f"Processing '{test_loader.dataset.file_path}'...")
start = time.time()

prediction_channel = self.config.get('prediction_channel', None)
if prediction_channel is not None:
logger.info(f"Saving only channel '{prediction_channel}' from the network output")
start = time.perf_counter()

logger.info(f'Running inference on {len(test_loader)} batches')

# dimensionality of the output predictions
volume_shape = self.volume_shape(test_loader.dataset)
out_channels = self.config['model'].get('out_channels')
if prediction_channel is None:
prediction_maps_shape = (out_channels,) + volume_shape
else:
if self.prediction_channel is not None:
# single channel prediction map
prediction_maps_shape = (1,) + volume_shape

logger.info(f'The shape of the output prediction maps (CDHW): {prediction_maps_shape}')

# evey patch will be mirror-padded with the following halo
patch_halo = self.predictor_config.get('patch_halo', (4, 4, 4))
if _is_2d_model(self.model):
patch_halo = list(patch_halo)
patch_halo[0] = 0
else:
prediction_maps_shape = (self.out_channels,) + volume_shape

# create destination H5 file
output_file = _get_output_file(dataset=test_loader.dataset, output_dir=self.output_dir)
@@ -115,9 +124,10 @@ def __call__(self, test_loader):
for input, indices in tqdm(test_loader):
# send batch to gpu
if torch.cuda.is_available():
input = input.cuda(non_blocking=True)
input = input.pin_memory().cuda(non_blocking=True)

input = _pad(input, patch_halo)
# pad the input patch
input = _pad(input, self.patch_halo)

if _is_2d_model(self.model):
# remove the singleton z-dimension from the input
@@ -130,19 +140,19 @@ def __call__(self, test_loader):
# forward pass
prediction = self.model(input)

# unpad
prediction = _unpad(prediction, patch_halo)
# unpad the input patch
prediction = _unpad(prediction, self.patch_halo)
# convert to numpy array
prediction = prediction.cpu().numpy()
# for each batch sample
for pred, index in zip(prediction, indices):
# save patch index: (C,D,H,W)
if prediction_channel is None:
channel_slice = slice(0, out_channels)
if self.prediction_channel is None:
channel_slice = slice(0, self.out_channels)
else:
# use only the specified channel
channel_slice = slice(0, 1)
pred = np.expand_dims(pred[prediction_channel], axis=0)
pred = np.expand_dims(pred[self.prediction_channel], axis=0)

# add channel dimension to the index
index = (channel_slice,) + tuple(index)
@@ -151,9 +161,10 @@ def __call__(self, test_loader):
# count voxel visits for normalization
normalization_mask[index] += 1

logger.info(f'Finished inference in {time.time() - start:.2f} seconds')
logger.info(f'Finished inference in {time.perf_counter() - start:.2f} seconds')
# save results
logger.info(f'Saving predictions to: {output_file}')
output_type = 'segmentation' if self.save_segmentation else 'probability maps'
logger.info(f'Saving {output_type} to: {output_file}')
self._save_results(prediction_map, normalization_mask, h5_output_file, test_loader.dataset)
# close the output H5 file
h5_output_file.close()
@@ -166,9 +177,10 @@ def _allocate_prediction_maps(self, output_shape, output_file):
return prediction_map, normalization_mask

def _save_results(self, prediction_map, normalization_mask, output_file, dataset):
dataset_name = _get_dataset_name(self.config)
prediction_map = prediction_map / normalization_mask
output_file.create_dataset(dataset_name, data=prediction_map, compression="gzip")
result = prediction_map / normalization_mask
if self.save_segmentation:
result = np.argmax(result, axis=0).astype('uint16')
output_file.create_dataset(self.output_dataset, data=result, compression="gzip")
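
A toy numpy sketch of the accumulate / normalize / argmax scheme implemented above; the shapes and values are illustrative, not taken from the repository:

```python
import numpy as np

# illustrative shapes only
out_channels, volume_shape = 2, (4, 8, 8)
prediction_map = np.zeros((out_channels,) + volume_shape, dtype='float32')
normalization_mask = np.zeros((out_channels,) + volume_shape, dtype='uint8')

# two overlapping patch predictions written into the same output volume
for z_slice in (slice(0, 3), slice(1, 4)):
    index = (slice(0, out_channels), z_slice, slice(0, 8), slice(0, 8))
    prediction_map[index] += np.random.rand(out_channels, z_slice.stop - z_slice.start, 8, 8)
    normalization_mask[index] += 1          # count voxel visits

# voxels covered by both patches are averaged rather than double-counted
probability_maps = prediction_map / normalization_mask

# with save_segmentation=True only the argmax label volume ends up in the H5 file
segmentation = np.argmax(probability_maps, axis=0).astype('uint16')
print(segmentation.shape)                   # (4, 8, 8) -- channel dimension removed
```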


def _pad(m, patch_halo):
@@ -193,47 +205,60 @@ class LazyPredictor(StandardPredictor):
Applies the model on the given dataset and saves the result in the `output_file` in the H5 format.
Predicted patches are directly saved into the H5 and they won't be stored in memory. Since this predictor
is slower than the `StandardPredictor` it should only be used when the predicted volume does not fit into RAM.
The output dataset names inside the H5 is given by `des_dataset_name` config argument. If the argument is
not present in the config 'predictions{n}' is used as a default dataset name, where `n` denotes the number
of the output head from the network.
Args:
model (Unet3D): trained 3D UNet model used for prediction
output_dir (str): path to the output directory (optional)
config (dict): global config dict
"""

def __init__(self, model, output_dir, config, **kwargs):
super().__init__(model, output_dir, config, **kwargs)
def __init__(self,
model: nn.Module,
output_dir: str,
out_channels: int,
output_dataset: str = 'predictions',
save_segmentation: bool = False,
prediction_channel: int = None,
**kwargs):
super().__init__(model, output_dir, out_channels, output_dataset, save_segmentation, prediction_channel,
**kwargs)

def _allocate_prediction_maps(self, output_shape, output_file):
dataset_name = _get_dataset_name(self.config)
# allocate datasets for probability maps
prediction_map = output_file.create_dataset(dataset_name, shape=output_shape, dtype='float32', chunks=True,
prediction_map = output_file.create_dataset(self.output_dataset,
shape=output_shape,
dtype='float32',
chunks=True,
compression='gzip')
# allocate datasets for normalization masks
normalization_mask = output_file.create_dataset('normalization', shape=output_shape, dtype='uint8', chunks=True,
normalization_mask = output_file.create_dataset('normalization',
shape=output_shape,
dtype='uint8',
chunks=True,
compression='gzip')
return prediction_map, normalization_mask

def _save_results(self, prediction_map, normalization_mask, output_file, dataset):
z, y, x = prediction_map.shape[1:]
# take slices which are 1/27 of the original volume
patch_shape = (z // 3, y // 3, x // 3)
if self.save_segmentation:
output_file.create_dataset('segmentation', shape=(z, y, x), dtype='uint16', chunks=True, compression='gzip')

for index in SliceBuilder._build_slices(prediction_map, patch_shape=patch_shape, stride_shape=patch_shape):
logger.info(f'Normalizing slice: {index}')
prediction_map[index] /= normalization_mask[index]
# make sure to reset the slice that has been visited already in order to avoid 'double' normalization
# when the patches overlap with each other
normalization_mask[index] = 1
# save segmentation
if self.save_segmentation:
output_file['segmentation'][index[1:]] = np.argmax(prediction_map[index], axis=0).astype('uint16')

del output_file['normalization']
if self.save_segmentation:
del output_file[self.output_dataset]
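
A compact numpy sketch of why the mask is reset to 1 after a region has been normalized: a second visit to an overlapping slice then divides by 1 and leaves the values untouched (this illustrates the intent of the loop above, not the exact `SliceBuilder` slicing):

```python
import numpy as np

pm = np.full((1, 6), 2.0)                  # accumulated predictions, every voxel visited twice
nm = np.full((1, 6), 2, dtype='uint8')     # normalization counts

for sl in (np.s_[:, 0:4], np.s_[:, 2:6]):  # two overlapping normalization slices
    pm[sl] = pm[sl] / nm[sl]
    nm[sl] = 1                             # mark the region as already normalized

assert np.allclose(pm, 1.0)                # no voxel was divided twice
```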


class DSB2018Predictor(_AbstractPredictor):
def __init__(self, model, output_dir, config, save_segmentation=True, pmaps_thershold=0.5, **kwargs):
super().__init__(model, output_dir, config, **kwargs)
self.pmaps_thershold = pmaps_thershold
self.pmaps_threshold = pmaps_thershold
self.save_segmentation = save_segmentation

def _slice_from_pad(self, pad):
2 changes: 2 additions & 0 deletions resources/3DUnet_multiclass/test_config.yaml
@@ -22,6 +22,8 @@ model:
predictor:
# standard in memory predictor
name: 'StandardPredictor'
# save the output segmentation instead of probability maps, i.e. apply argmax to the output
save_segmentation: true
# specify the test datasets
loaders:
# batch dimension; if number of GPUs is N > 1, then a batch_size of N * batch_size will automatically be taken for DataParallel
27 changes: 23 additions & 4 deletions tests/test_predictor.py
@@ -1,4 +1,5 @@
import os
from pathlib import Path
from tempfile import NamedTemporaryFile

import h5py
@@ -24,19 +25,37 @@ def _run_prediction(test_config, tmpdir, shape):
if torch.cuda.is_available():
model.cuda()
for test_loader in get_test_loaders(test_config):
predictor = get_predictor(model, tmpdir, test_config)
predictor = get_predictor(model, test_config)
# run the model prediction on the entire dataset and save to the 'output_file' H5
predictor(test_loader)
return tmp


def _get_result_shape(result_path: str, dataset_name: str = 'predictions'):
with h5py.File(result_path, 'r') as f:
return f[dataset_name].shape


class TestPredictor:
def test_3d_predictor(self, tmpdir, test_config):
tmp = _run_prediction(test_config, tmpdir, shape=(32, 64, 64))

assert os.path.exists(os.path.join(tmpdir, os.path.split(tmp.name)[1] + '_predictions.h5'))
output_filename = os.path.split(tmp.name)[1] + '_predictions.h5'
output_path = Path(tmpdir) / output_filename
assert output_path.exists()
assert _get_result_shape(output_path) == (2, 32, 64, 64)

def test_2d_predictor(self, tmpdir, test_config_2d):
tmp = _run_prediction(test_config_2d, tmpdir, shape=(3, 1, 256, 256))
output_filename = os.path.split(tmp.name)[1] + '_predictions.h5'
output_path = Path(tmpdir) / output_filename
assert output_path.exists()
assert _get_result_shape(output_path) == (2, 1, 256, 256)

assert os.path.exists(os.path.join(tmpdir, os.path.split(tmp.name)[1] + '_predictions.h5'))
def test_predictor_save_segmentation(self, tmpdir, test_config):
test_config['predictor']['save_segmentation'] = True
tmp = _run_prediction(test_config, tmpdir, shape=(32, 64, 64))
output_filename = os.path.split(tmp.name)[1] + '_predictions.h5'
output_path = Path(tmpdir) / output_filename
assert output_path.exists()
# check that the output segmentation is saved, with the channel dimension reduced via argmax operation
assert _get_result_shape(output_path) == (32, 64, 64)
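
As a usage note, the resulting H5 file can be inspected with `h5py`; the file name and the default `predictions` dataset below are assumptions based on the defaults in this PR:

```python
import h5py

# hypothetical output file produced by predict3dunet
with h5py.File('test_volume_predictions.h5', 'r') as f:
    result = f['predictions'][...]

# probability maps have shape (C, D, H, W); with save_segmentation=True the stored
# array is the uint16 argmax label volume of shape (D, H, W)
print(result.shape, result.dtype)
```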
