diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py
index 2ae03bc8dc..b2563aaf57 100644
--- a/monai/losses/perceptual.py
+++ b/monai/losses/perceptual.py
@@ -19,11 +19,18 @@
 from monai.utils import optional_import
 from monai.utils.enums import StrEnum
 
+# Valid model names to download from the Project-MONAI/perceptual-models repository
+HF_MONAI_MODELS = frozenset(
+    ("medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets", "radimagenet_resnet50")
+)
+
 LPIPS, _ = optional_import("lpips", name="LPIPS")
 torchvision, _ = optional_import("torchvision")
 
 
-class PercetualNetworkType(StrEnum):
+class PerceptualNetworkType(StrEnum):
+    """Types of neural networks that are supported by the perceptual loss."""
+
     alex = "alex"
     vgg = "vgg"
     squeeze = "squeeze"
@@ -70,7 +77,7 @@ class PerceptualLoss(nn.Module):
     def __init__(
         self,
         spatial_dims: int,
-        network_type: str = PercetualNetworkType.alex,
+        network_type: str = PerceptualNetworkType.alex,
         is_fake_3d: bool = True,
         fake_3d_ratio: float = 0.5,
         cache_dir: str | None = None,
@@ -84,19 +91,25 @@ def __init__(
         if spatial_dims not in [2, 3]:
             raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")
 
-        if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type:
-            raise ValueError(
-                "MedicalNet networks are only compatible with ``spatial_dims=3``."
-                "Argument is_fake_3d must be set to False."
-            )
-
-        if channel_wise and "medicalnet_" not in network_type:
+        # Strict validation for MedicalNet
+        if "medicalnet_" in network_type:
+            if spatial_dims == 2 or is_fake_3d:
+                raise ValueError(
+                    "MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False."
+                )
+            if not channel_wise:
+                warnings.warn(
+                    "MedicalNet networks support channel-wise loss. Consider setting channel_wise=True.", stacklevel=2
+                )
+
+        # Channel-wise loss is only supported for MedicalNet networks
+        elif channel_wise:
             raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.")
 
-        if network_type.lower() not in list(PercetualNetworkType):
+        if network_type.lower() not in list(PerceptualNetworkType):
             raise ValueError(
                 "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
-                % ", ".join(PercetualNetworkType)
+                % ", ".join(PerceptualNetworkType)
             )
 
         if cache_dir:
@@ -108,12 +121,16 @@ def __init__(
 
         self.spatial_dims = spatial_dims
         self.perceptual_function: nn.Module
+
+        # If spatial_dims is 3, only MedicalNet supports 3D models; otherwise, spatial_dims=2 or fake 3D must be used.
         if spatial_dims == 3 and is_fake_3d is False:
             self.perceptual_function = MedicalNetPerceptualSimilarity(
-                net=network_type, verbose=False, channel_wise=channel_wise
+                net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir
             )
         elif "radimagenet_" in network_type:
-            self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
+            self.perceptual_function = RadImageNetPerceptualSimilarity(
+                net=network_type, verbose=False, cache_dir=cache_dir
+            )
         elif network_type == "resnet50":
             self.perceptual_function = TorchvisionModelPerceptualSimilarity(
                 net=network_type,
@@ -122,7 +139,9 @@ def __init__(
                 pretrained_state_dict_key=pretrained_state_dict_key,
             )
         else:
+            # VGG, AlexNet and SqueezeNet are handled directly by LPIPS.
             self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)
+
         self.is_fake_3d = is_fake_3d
         self.fake_3d_ratio = fake_3d_ratio
         self.channel_wise = channel_wise
@@ -194,7 +213,7 @@ class MedicalNetPerceptualSimilarity(nn.Module):
     """
     Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
     Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
-    "Warvito/MedicalNet-models".
+    "Project-MONAI/perceptual-models".
 
     Args:
         net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
@@ -205,11 +224,19 @@ class MedicalNetPerceptualSimilarity(nn.Module):
     """
 
     def __init__(
-        self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False
+        self,
+        net: str = "medicalnet_resnet10_23datasets",
+        verbose: bool = False,
+        channel_wise: bool = False,
+        cache_dir: str | None = None,
     ) -> None:
         super().__init__()
-        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
-        self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True)
+        if net not in HF_MONAI_MODELS:
+            raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.")
+
+        self.model = torch.hub.load(
+            "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True
+        )
         self.eval()
 
         self.channel_wise = channel_wise
@@ -258,7 +285,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             for i in range(input.shape[1]):
                 l_idx = i * feats_per_ch
                 r_idx = (i + 1) * feats_per_ch
-                results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1)
+                results[:, i, ...] = feats_diff[:, l_idx:r_idx, ...].sum(dim=1)
         else:
             results = feats_diff.sum(dim=1, keepdim=True)
 
@@ -287,7 +314,7 @@ class RadImageNetPerceptualSimilarity(nn.Module):
     """
     Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
    al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
-    uses torch Hub to download the networks from "Warvito/radimagenet-models".
+    uses torch Hub to download the networks from "Project-MONAI/perceptual-models".
 
     Args:
         net: {``"radimagenet_resnet50"``}
@@ -295,9 +322,13 @@ class RadImageNetPerceptualSimilarity(nn.Module):
         verbose: if false, mute messages from torch Hub load function.
     """
 
-    def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
+    def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None:
         super().__init__()
-        self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True)
+        if net not in HF_MONAI_MODELS:
+            raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.")
+        self.model = torch.hub.load(
+            "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True
+        )
         self.eval()
 
         for param in self.parameters():
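
Usage sketch (not part of the diff): a minimal example of how the updated PerceptualLoss could be constructed with the new cache_dir argument, assuming the public import from monai.losses and the loss_fn(input, target) forward call; the cache path and tensor shapes are illustrative only, and the first call still downloads the pretrained weights from Project-MONAI/perceptual-models via torch Hub.

import torch
from monai.losses import PerceptualLoss

# Hypothetical local cache directory; per the diff, cache_dir is forwarded to torch.hub.load.
loss_fn = PerceptualLoss(
    spatial_dims=3,
    network_type="medicalnet_resnet10_23datasets",
    is_fake_3d=False,       # required for MedicalNet networks (true 3D)
    channel_wise=True,      # suggested by the new warning for MedicalNet
    cache_dir="./perceptual_model_cache",
)

pred = torch.randn(2, 1, 32, 32, 32)    # example (B, C, H, W, D) volumes
target = torch.randn(2, 1, 32, 32, 32)
loss = loss_fn(pred, target)            # per-channel values when channel_wise=True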