diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 37c78fae60..d9a3050223 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -233,10 +233,14 @@ def __init__( self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"]) self.num_bins = num_bins self.kernel_type = kernel_type + # declared as buffers so they move with the module (e.g. ``.to(device)``); only populated for the + # gaussian kernel, hence the ``Tensor`` annotation reflects the type at the use sites in that path. + self.preterm: torch.Tensor | None self.bin_centers: torch.Tensor | None + self.register_buffer("preterm", None, persistent=False) self.register_buffer("bin_centers", None, persistent=False) if self.kernel_type == "gaussian": - self.preterm = 1 / (2 * sigma**2) + self.register_buffer("preterm", 1 / (2 * sigma**2), persistent=False) self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False) self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) @@ -316,8 +320,8 @@ def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, to """ img = torch.clamp(img, 0, 1) img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1) - if self.bin_centers is None: - raise ValueError("bin_centers must be defined for gaussian parzen windowing.") + if self.bin_centers is None or self.preterm is None: + raise ValueError("bin_centers and preterm must be defined for gaussian parzen windowing.") weight = torch.exp( -self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2 ) # (batch, num_sample, num_bin) diff --git a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py index a16499ac11..19a60f7219 100644 --- a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +++ b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py @@ -164,5 
+164,58 @@ def test_ill_opts(self, num_bins, reduction, expected_exception, expected_messag
             GlobalMutualInformationLoss(num_bins=num_bins, reduction=reduction)(pred, target)
 
 
+class TestGlobalMutualInformationLossBuffers(unittest.TestCase):
+    """Checks on the ``preterm`` / ``bin_centers`` buffers of ``GlobalMutualInformationLoss``."""
+
+    def test_gaussian_kernel_registers_buffers(self):
+        """Verify gaussian kernel registers preterm and bin_centers as non-trainable, non-persistent buffers."""
+        loss = GlobalMutualInformationLoss(kernel_type="gaussian")
+        self.assertIn("preterm", loss._buffers)
+        self.assertIn("bin_centers", loss._buffers)
+        self.assertFalse(loss.preterm.requires_grad)
+        self.assertFalse(loss.bin_centers.requires_grad)
+        self.assertEqual(loss.bin_centers.ndim, 3)
+        # non-persistent buffers are expected to be left out of the serialized state
+        state = loss.state_dict()
+        self.assertNotIn("preterm", state)
+        self.assertNotIn("bin_centers", state)
+
+    def test_bspline_kernel_has_no_gaussian_buffers(self):
+        """Verify b-spline kernel does not populate gaussian-specific buffers."""
+        loss = GlobalMutualInformationLoss(kernel_type="b-spline")
+        self.assertIsNone(loss.preterm)
+        self.assertIsNone(loss.bin_centers)
+        state = loss.state_dict()
+        self.assertNotIn("preterm", state)
+        self.assertNotIn("bin_centers", state)
+
+    def test_gaussian_kernel_forward_correct(self):
+        """Verify gaussian kernel forward pass returns a finite scalar loss tensor."""
+        pred = torch.rand(2, 1, 8, 8, dtype=torch.float32)
+        target = torch.rand(2, 1, 8, 8, dtype=torch.float32)
+        loss = GlobalMutualInformationLoss(kernel_type="gaussian")
+        result = loss(pred, target)
+        self.assertEqual(result.shape, torch.Size([]))
+        # a shape check alone would also pass for a NaN/inf loss
+        self.assertTrue(torch.isfinite(result).item())
+
+    def test_gaussian_buffers_start_on_cpu(self):
+        """Verify preterm and bin_centers buffers live on the CPU for a freshly constructed module."""
+        loss = GlobalMutualInformationLoss(kernel_type="gaussian")
+        self.assertEqual(loss.preterm.device.type, "cpu")
+        self.assertEqual(loss.bin_centers.device.type, "cpu")
+
+    @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available")
+    def test_gaussian_buffers_move_with_module(self):
+        """Verify preterm and bin_centers buffers move to the target device with the module."""
+        loss = GlobalMutualInformationLoss(kernel_type="gaussian").cuda()
+        self.assertEqual(loss.preterm.device.type, "cuda")
+        self.assertEqual(loss.bin_centers.device.type, "cuda")
+        pred = torch.rand(2, 1, 8, 8, device="cuda")
+        target = torch.rand(2, 1, 8, 8, device="cuda")
+        result = loss(pred, target)
+        self.assertEqual(result.device.type, "cuda")
+
+
 if __name__ == "__main__":
     unittest.main()