Skip to content

Commit 02106a1

Browse files
林旻佑林旻佑
authored andcommitted
Fix #8366: Add strict shape validation to sliding_window_inference
1 parent e267705 commit 02106a1

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

monai/inferers/utils.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,51 @@
3333
optional_import,
3434
)
3535

36+
3637
tqdm, _ = optional_import("tqdm", name="tqdm")
3738
_nearest_mode = "nearest-exact"
3839

3940
__all__ = ["sliding_window_inference"]
4041

4142

43+
44+
def assert_channel_first(
45+
t: torch.Tensor,
46+
name: str,
47+
num_classes: Optional[int] = None,
48+
allow_binary_two_channel: bool = False,
49+
) -> None:
50+
"""
51+
Enforce channel-first layout without guessing.
52+
Accepts only:
53+
- 4D: NCHW (channel at dim=1)
54+
- 5D: NCDHW (channel at dim=1)
55+
If not satisfied, raise with a clear message asking users to apply
56+
EnsureChannelFirst / EnsureChannelFirstd upstream.
57+
"""
58+
if not isinstance(t, torch.Tensor):
59+
return
60+
if t.ndim not in (4, 5):
61+
return
62+
63+
c = int(t.shape[1])
64+
layout = "NCHW" if t.ndim == 4 else "NCDHW"
65+
layout_last = "NHWC" if t.ndim == 4 else "NDHWC"
66+
67+
if num_classes is not None:
68+
ok = (c == num_classes) or (num_classes == 1 and c == 1)
69+
if allow_binary_two_channel and num_classes == 2:
70+
ok = ok or (c == 2)
71+
if not ok:
72+
raise ValueError(
73+
f"{name}: expected {layout} with C(dim=1)==num_classes, "
74+
f"but got shape={tuple(t.shape)} (C={c}) and num_classes={num_classes}. "
75+
f"If your data is {layout_last}, please apply EnsureChannelFirst/EnsureChannelFirstd upstream."
76+
)
77+
# No guessing when num_classes is None; we simply require channel at dim=1.
78+
# If callers provided NHWC/NDHWC, they must convert upstream.
79+
80+
4281
def sliding_window_inference(
4382
inputs: torch.Tensor | MetaTensor,
4483
roi_size: Sequence[int] | int,
@@ -131,9 +170,29 @@ def sliding_window_inference(
131170
kwargs: optional keyword args to be passed to ``predictor``.
132171
133172
Note:
134-
- input must be channel-first and have a batch dim, supports N-D sliding window.
135-
173+
- Inputs must be channel-first and have a batch dim (NCHW / NCDHW).
174+
- If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream.
175+
136176
"""
177+
num_spatial_dims = len(inputs.shape) - 2
178+
179+
# Only perform strict shape validation if roi_size is a sequence (explicit dimensions).
180+
# If roi_size is an integer, it is broadcast to all dimensions, so we cannot
181+
# infer the expected dimensionality to enforce a strict check here.
182+
if not isinstance(roi_size, int):
183+
roi_dims = len(roi_size)
184+
if num_spatial_dims != roi_dims:
185+
raise ValueError(
186+
f"inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size "
187+
f"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), "
188+
f"but got inputs shape {inputs.shape}.\n"
189+
"If you have channel-last data (e.g. B, D, H, W, C), please use "
190+
"monai.transforms.EnsureChannelFirst or EnsureChannelFirstd upstream."
191+
)
192+
# -----------------------------------------------------------------
193+
# ---- Strict validation: do NOT guess or permute layouts ----
194+
if isinstance(inputs, torch.Tensor):
195+
assert_channel_first(inputs, "inputs")
137196
buffered = buffer_steps is not None and buffer_steps > 0
138197
num_spatial_dims = len(inputs.shape) - 2
139198
if buffered:

tests/inferers/test_sliding_window_inference.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,26 @@ def compute_dict(data):
372372
for rr, _ in zip(result_dict, expected_dict):
373373
np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)
374374

375+
def test_strict_shape_validation(self):
376+
"""Test strict shape validation to ensure inputs match roi_size dimensions."""
377+
device = "cpu"
378+
roi_size = (16, 16, 16)
379+
sw_batch_size = 4
380+
381+
def predictor(data):
382+
return data
383+
384+
# Case 1: Input has fewer dimensions than expected (e.g., missing Batch or Channel)
385+
# 3D roi_size requires 5D input (B, C, D, H, W), giving 4D here.
386+
inputs_4d = torch.randn((1, 16, 16, 16), device=device)
387+
with self.assertRaisesRegex(ValueError, "inputs must have 5 dimensions"):
388+
sliding_window_inference(inputs_4d, roi_size, sw_batch_size, predictor)
389+
390+
# Case 2: Input is 3D (missing Batch AND Channel)
391+
inputs_3d = torch.randn((16, 16, 16), device=device)
392+
with self.assertRaisesRegex(ValueError, "inputs must have 5 dimensions"):
393+
sliding_window_inference(inputs_3d, roi_size, sw_batch_size, predictor)
394+
375395

376396
class TestSlidingWindowInferenceCond(unittest.TestCase):
377397
@parameterized.expand(TEST_CASES)

0 commit comments

Comments
 (0)