From 11037477bc87b3e294fbb8b801d8275968d936ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E6=97=BB=E4=BD=91?= Date: Sun, 30 Nov 2025 14:09:02 +0800 Subject: [PATCH 1/3] Fix #8366: Add strict shape validation to sliding_window_inference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 林旻佑 --- monai/inferers/utils.py | 20 +++++++++++++++++-- .../inferers/test_sliding_window_inference.py | 20 +++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 766486a807..c5baa2a522 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -131,11 +131,27 @@ def sliding_window_inference( kwargs: optional keyword args to be passed to ``predictor``. Note: - - input must be channel-first and have a batch dim, supports N-D sliding window. + - Inputs must be channel-first and have a batch dim (NCHW / NCDHW). + - If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream. """ - buffered = buffer_steps is not None and buffer_steps > 0 num_spatial_dims = len(inputs.shape) - 2 + + # Only perform strict shape validation if roi_size is a sequence (explicit dimensions). + # If roi_size is an integer, it is broadcast to all dimensions, so we cannot + # infer the expected dimensionality to enforce a strict check here. + if isinstance(roi_size, Sequence): + roi_dims = len(roi_size) + if num_spatial_dims != roi_dims: + raise ValueError( + f"inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size " + f"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), " + f"but got inputs shape {inputs.shape}.\n" + "If you have channel-last data (e.g. B, D, H, W, C), please use " + "monai.transforms.EnsureChannelFirst or EnsureChannelFirstd upstream." 
+ ) + # ----------------------------------------------------------------- + buffered = buffer_steps is not None and buffer_steps > 0 if buffered: if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims: raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.") diff --git a/tests/inferers/test_sliding_window_inference.py b/tests/inferers/test_sliding_window_inference.py index f97cbb9299..bada80d006 100644 --- a/tests/inferers/test_sliding_window_inference.py +++ b/tests/inferers/test_sliding_window_inference.py @@ -372,6 +372,26 @@ def compute_dict(data): for rr, _ in zip(result_dict, expected_dict): np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4) + def test_strict_shape_validation(self): + """Test strict shape validation to ensure inputs match roi_size dimensions.""" + device = "cpu" + roi_size = (16, 16, 16) + sw_batch_size = 4 + + def predictor(data): + return data + + # Case 1: Input has fewer dimensions than expected (e.g., missing Batch or Channel) + # 3D roi_size requires 5D input (B, C, D, H, W), giving 4D here. 
+ inputs_4d = torch.randn((1, 16, 16, 16), device=device) + with self.assertRaisesRegex(ValueError, "inputs must have 5 dimensions"): + sliding_window_inference(inputs_4d, roi_size, sw_batch_size, predictor) + + # Case 2: Input is 3D (missing Batch AND Channel) + inputs_3d = torch.randn((16, 16, 16), device=device) + with self.assertRaisesRegex(ValueError, "inputs must have 5 dimensions"): + sliding_window_inference(inputs_3d, roi_size, sw_batch_size, predictor) + class TestSlidingWindowInferenceCond(unittest.TestCase): @parameterized.expand(TEST_CASES) From 0b3ffae80eddd3a4427e18e2a44cae165268f0c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E6=97=BB=E4=BD=91?= Date: Sat, 6 Dec 2025 00:29:50 +0800 Subject: [PATCH 2/3] Fix: Capitalize error messages and update docstrings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 林旻佑 --- monai/inferers/utils.py | 8 ++++++-- tests/inferers/test_sliding_window_inference.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index c5baa2a522..3dfbc2032c 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -76,7 +76,8 @@ def sliding_window_inference( Args: inputs: input image to be processed (assuming NCHW[D]) - roi_size: the spatial window size for inferences. + roi_size: the spatial window size for inferences; this must be a single value or a tuple with values + for each spatial dimension (e.g. 2 for 2D, 3 for 3D). When its components have None or non-positives, the corresponding inputs dimension will be used. if the components of the `roi_size` are non-positive values, the transform will use the corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted @@ -134,6 +135,9 @@ def sliding_window_inference( - Inputs must be channel-first and have a batch dim (NCHW / NCDHW). 
- If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream. + Raises: + ValueError: When the input dimensions do not match the expected dimensions based on ``roi_size``. + """ num_spatial_dims = len(inputs.shape) - 2 @@ -144,7 +148,7 @@ def sliding_window_inference( roi_dims = len(roi_size) if num_spatial_dims != roi_dims: raise ValueError( - f"inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size " + f"Inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size " f"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), " f"but got inputs shape {inputs.shape}.\n" "If you have channel-last data (e.g. B, D, H, W, C), please use " diff --git a/tests/inferers/test_sliding_window_inference.py b/tests/inferers/test_sliding_window_inference.py index bada80d006..8700c4fcd0 100644 --- a/tests/inferers/test_sliding_window_inference.py +++ b/tests/inferers/test_sliding_window_inference.py @@ -384,12 +384,12 @@ def predictor(data): # Case 1: Input has fewer dimensions than expected (e.g., missing Batch or Channel) # 3D roi_size requires 5D input (B, C, D, H, W), giving 4D here. inputs_4d = torch.randn((1, 16, 16, 16), device=device) - with self.assertRaisesRegex(ValueError, "inputs must have 5 dimensions"): + with self.assertRaisesRegex(ValueError, "Inputs must have 5 dimensions"): sliding_window_inference(inputs_4d, roi_size, sw_batch_size, predictor) # Case 2: Input is 3D (missing Batch AND Channel) inputs_3d = torch.randn((16, 16, 16), device=device) - with self.assertRaisesRegex(ValueError, "inputs must have 5 dimensions"): + with self.assertRaisesRegex(ValueError, "Inputs must have 5 dimensions"): sliding_window_inference(inputs_3d, roi_size, sw_batch_size, predictor) From bdb78a9f333d330da8fb2491e4e22a474a4473a3 Mon Sep 17 00:00:00 2001 From: Matt Lin Date: Sat, 6 Dec 2025 11:24:23 +0800 Subject: [PATCH 3/3] Trigger CI re-run Signed-off-by: Matt Lin