-
Notifications
You must be signed in to change notification settings - Fork 1.4k
8627 perceptual loss errors out after hitting the maximum number of downloads #8652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
virginiafdez
wants to merge
16
commits into
Project-MONAI:dev
Choose a base branch
from
virginiafdez:8627-perceptual-loss-errors-out-after-hitting-the-maximum-number-of-downloads
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+52
−21
Open
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
b1e4a50
Perceptual loss changes.
0aeb4d9
Merge branch '8627-perceptual-loss-errors-out-after-hitting-the-maxim…
fa0639b
Fixes
685aee2
Merge branch 'dev' into 8627-perceptual-loss-errors-out-after-hitting…
virginiafdez 915de5f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 5594bfe
Unnecessary import
b065de7
Merge branch '8627-perceptual-loss-errors-out-after-hitting-the-maxim…
c99e16e
Add check of network name
717b99b
Update monai/losses/perceptual.py
ericspod b276f3c
Update monai/losses/perceptual.py
ericspod 2156b84
Update monai/losses/perceptual.py
ericspod e2b982e
Update monai/losses/perceptual.py
virginiafdez e3be8de
Bug
b02053b
Merge branch '8627-perceptual-loss-errors-out-after-hitting-the-maxim…
6dfc209
DCO Remediation Commit for Virginia Fernandez <virginia.fernandez@kcl…
d258390
Reformatting
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,11 +19,18 @@ | |
| from monai.utils import optional_import | ||
| from monai.utils.enums import StrEnum | ||
|
|
||
| # Valid model name to download from the repository | ||
| HF_MONAI_MODELS = frozenset( | ||
| ("medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets", "radimagenet_resnet50") | ||
| ) | ||
|
|
||
| LPIPS, _ = optional_import("lpips", name="LPIPS") | ||
| torchvision, _ = optional_import("torchvision") | ||
|
|
||
|
|
||
| class PercetualNetworkType(StrEnum): | ||
| class PerceptualNetworkType(StrEnum): | ||
| """Types of neural networks that are supported by perceptual loss.""" | ||
|
|
||
| alex = "alex" | ||
| vgg = "vgg" | ||
| squeeze = "squeeze" | ||
|
|
@@ -70,7 +77,7 @@ class PerceptualLoss(nn.Module): | |
| def __init__( | ||
| self, | ||
| spatial_dims: int, | ||
| network_type: str = PercetualNetworkType.alex, | ||
| network_type: str = PerceptualNetworkType.alex, | ||
| is_fake_3d: bool = True, | ||
| fake_3d_ratio: float = 0.5, | ||
| cache_dir: str | None = None, | ||
|
|
@@ -84,19 +91,25 @@ def __init__( | |
| if spatial_dims not in [2, 3]: | ||
| raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") | ||
|
|
||
| if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type: | ||
| raise ValueError( | ||
| "MedicalNet networks are only compatible with ``spatial_dims=3``." | ||
| "Argument is_fake_3d must be set to False." | ||
| ) | ||
|
|
||
| if channel_wise and "medicalnet_" not in network_type: | ||
| # Strict validation for MedicalNet | ||
| if "medicalnet_" in network_type: | ||
| if spatial_dims == 2 or is_fake_3d: | ||
| raise ValueError( | ||
| "MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False." | ||
| ) | ||
| if not channel_wise: | ||
| warnings.warn( | ||
| "MedicalNet networks supp, ort channel-wise loss. Consider setting channel_wise=True.", stacklevel=2 | ||
| ) | ||
|
|
||
| # Channel-wise only for MedicalNet | ||
| elif channel_wise: | ||
| raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.") | ||
|
|
||
| if network_type.lower() not in list(PercetualNetworkType): | ||
| if network_type.lower() not in list(PerceptualNetworkType): | ||
| raise ValueError( | ||
| "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s" | ||
| % ", ".join(PercetualNetworkType) | ||
| % ", ".join(PerceptualNetworkType) | ||
| ) | ||
|
Comment on lines
+109
to
113
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Error message still mentions “Adversarial Loss”. The validation is for - if network_type.lower() not in list(PerceptualNetworkType):
- raise ValueError(
- "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
- % ", ".join(PerceptualNetworkType)
- )
+ if network_type.lower() not in list(PerceptualNetworkType):
+ raise ValueError(
+ "Unrecognised network_type for PerceptualLoss. Must be one of: %s"
+ % ", ".join(PerceptualNetworkType)
+ )🤖 Prompt for AI Agents |
||
|
|
||
| if cache_dir: | ||
|
|
@@ -108,12 +121,16 @@ def __init__( | |
|
|
||
| self.spatial_dims = spatial_dims | ||
| self.perceptual_function: nn.Module | ||
|
|
||
| # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used. | ||
| if spatial_dims == 3 and is_fake_3d is False: | ||
| self.perceptual_function = MedicalNetPerceptualSimilarity( | ||
| net=network_type, verbose=False, channel_wise=channel_wise | ||
| net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir | ||
| ) | ||
| elif "radimagenet_" in network_type: | ||
| self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) | ||
| self.perceptual_function = RadImageNetPerceptualSimilarity( | ||
| net=network_type, verbose=False, cache_dir=cache_dir | ||
| ) | ||
| elif network_type == "resnet50": | ||
| self.perceptual_function = TorchvisionModelPerceptualSimilarity( | ||
| net=network_type, | ||
|
|
@@ -122,7 +139,9 @@ def __init__( | |
| pretrained_state_dict_key=pretrained_state_dict_key, | ||
| ) | ||
| else: | ||
| # VGG, AlexNet and SqueezeNet are independently handled by LPIPS. | ||
| self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False) | ||
|
|
||
| self.is_fake_3d = is_fake_3d | ||
| self.fake_3d_ratio = fake_3d_ratio | ||
| self.channel_wise = channel_wise | ||
|
|
@@ -194,7 +213,7 @@ class MedicalNetPerceptualSimilarity(nn.Module): | |
| """ | ||
| Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer | ||
| Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from | ||
| "Warvito/MedicalNet-models". | ||
| "Project-MONAI/perceptual-models". | ||
|
|
||
| Args: | ||
| net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} | ||
|
|
@@ -205,11 +224,19 @@ class MedicalNetPerceptualSimilarity(nn.Module): | |
| """ | ||
|
|
||
| def __init__( | ||
virginiafdez marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False | ||
| self, | ||
| net: str = "medicalnet_resnet10_23datasets", | ||
| verbose: bool = False, | ||
| channel_wise: bool = False, | ||
| cache_dir: str | None = None, | ||
| ) -> None: | ||
| super().__init__() | ||
| torch.hub._validate_not_a_forked_repo = lambda a, b, c: True | ||
| self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True) | ||
| if net not in HF_MONAI_MODELS: | ||
| raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.") | ||
|
|
||
| self.model = torch.hub.load( | ||
virginiafdez marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True | ||
| ) | ||
| self.eval() | ||
|
|
||
| self.channel_wise = channel_wise | ||
|
|
@@ -258,7 +285,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
| for i in range(input.shape[1]): | ||
| l_idx = i * feats_per_ch | ||
| r_idx = (i + 1) * feats_per_ch | ||
| results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1) | ||
| results[:, i, ...] = feats_diff[:, l_idx:r_idx, ...].sum(dim=1) | ||
| else: | ||
| results = feats_diff.sum(dim=1, keepdim=True) | ||
|
|
||
|
|
@@ -287,17 +314,21 @@ class RadImageNetPerceptualSimilarity(nn.Module): | |
| """ | ||
| Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et | ||
| al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class | ||
| uses torch Hub to download the networks from "Warvito/radimagenet-models". | ||
| uses torch Hub to download the networks from "Project-MONAI/perceptual-models". | ||
|
|
||
| Args: | ||
| net: {``"radimagenet_resnet50"``} | ||
| Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``. | ||
| verbose: if false, mute messages from torch Hub load function. | ||
| """ | ||
|
|
||
| def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None: | ||
| def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: | ||
| super().__init__() | ||
| self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True) | ||
| if net not in HF_MONAI_MODELS: | ||
| raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.") | ||
| self.model = torch.hub.load( | ||
| "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True | ||
| ) | ||
| self.eval() | ||
|
|
||
| for param in self.parameters(): | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Restrict model validation per family and guard 3D path to MedicalNet only.
As written,
HF_MONAI_MODELSis shared by both MedicalNet and RadImageNet, and the 3D branch inPerceptualLossalways instantiatesMedicalNetPerceptualSimilaritywhenspatial_dims == 3 and is_fake_3d is False, regardless ofnetwork_type. This leads to:network_type="radimagenet_resnet50"withspatial_dims=3andis_fake_3d=Falsebeing passed intoMedicalNetPerceptualSimilarity, which will attempt to run a 2D RadImageNet backbone in a 3D MedicalNet path (shape/device errors at runtime instead of a clean validation failure).MedicalNetPerceptualSimilarityandRadImageNetPerceptualSimilarityboth accepting each other’s model names because they both useHF_MONAI_MODELSdirectly.Recommend:
network_typeis a MedicalNet variant before constructingMedicalNetPerceptualSimilarity, otherwise raise a clearValueError.Example patch:
Also applies to: 94-107, 125-133, 234-239, 325-331