Skip to content

Commit

Permalink
Merge remote-tracking branch 'ancestor-mithril/dev3'
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Apr 9, 2024
2 parents 3d99b8d + ab34ef4 commit 5f363b3
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions nnunetv2/preprocessing/cropping/cropping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np

from scipy.ndimage import binary_fill_holes

# Hello! crop_to_nonzero is the function you are looking for. Ignore the rest.
from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, crop_to_bbox, bounding_box_to_slice
Expand All @@ -11,14 +11,11 @@ def create_nonzero_mask(data):
:param data:
:return: the mask is True where the data is nonzero
"""
from scipy.ndimage import binary_fill_holes
assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
nonzero_mask = np.zeros(data.shape[1:], dtype=bool)
for c in range(data.shape[0]):
this_mask = data[c] != 0
nonzero_mask = nonzero_mask | this_mask
nonzero_mask = binary_fill_holes(nonzero_mask)
return nonzero_mask
nonzero_mask = data[0] != 0
for c in range(1, data.shape[0]):
nonzero_mask |= data[c] != 0
return binary_fill_holes(nonzero_mask)


def crop_to_nonzero(data, seg=None, nonzero_label=-1):
Expand All @@ -31,21 +28,16 @@ def crop_to_nonzero(data, seg=None, nonzero_label=-1):
"""
nonzero_mask = create_nonzero_mask(data)
bbox = get_bbox_from_mask(nonzero_mask)

slicer = bounding_box_to_slice(bbox)
data = data[tuple([slice(None), *slicer])]

if seg is not None:
seg = seg[tuple([slice(None), *slicer])]

nonzero_mask = nonzero_mask[slicer][None]

slicer = (slice(None), ) + slicer
data = data[slicer]
if seg is not None:
seg = seg[slicer]
seg[(seg == 0) & (~nonzero_mask)] = nonzero_label
else:
nonzero_mask = nonzero_mask.astype(np.int8)
nonzero_mask[nonzero_mask == 0] = nonzero_label
nonzero_mask[nonzero_mask > 0] = 0
seg = nonzero_mask
seg = np.where(nonzero_mask, np.int8(0), np.int8(nonzero_label))
return data, seg, bbox


0 comments on commit 5f363b3

Please sign in to comment.