19 changes: 19 additions & 0 deletions monai/networks/nets/unet.py
@@ -13,9 +13,11 @@

import warnings
from collections.abc import Sequence
from typing import cast

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
@@ -24,6 +26,17 @@
__all__ = ["UNet", "Unet"]

⚠️ Potential issue | 🟡 Minor

Add CheckpointUNet to __all__ exports.

CheckpointUNet is a public class but not exported in __all__.

Apply this diff:

-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around line 26, update the module export list to
include the public class CheckpointUNet by adding "CheckpointUNet" to the
__all__ array; modify the existing __all__ = ["UNet", "Unet"] to include
"CheckpointUNet" so the final export list contains "UNet", "Unet", and
"CheckpointUNet".



class _ActivationCheckpointWrapper(nn.Module):
"""Apply activation checkpointing to the wrapped module during training."""

def __init__(self, module: nn.Module) -> None:
super().__init__()
self.module = module

def forward(self, x: torch.Tensor) -> torch.Tensor:
return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
Comment on lines +29 to +37

🛠️ Refactor suggestion | 🟠 Major

Add Google-style docstrings.

Class and forward docstrings need Args/Returns sections per guidelines. Document the wrapped module, checkpoint guard details, and returned tensor.

As per coding guidelines.

🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 29 to 37, the
_ActivationCheckpointWrapper class and its forward method lack Google-style
docstrings; add a class-level docstring that briefly describes purpose, document
the module parameter as "module: nn.Module — module to wrap for activation
checkpointing", mention that checkpointing is applied during training to save
memory and that use_reentrant=False is used as the checkpoint guard, and add a
forward method docstring with Args: x (torch.Tensor): input tensor to the
wrapped module and Returns: torch.Tensor: output tensor from the wrapped module
(with activations checkpointed); keep wording concise and follow Google-style
"Args/Returns" formatting.



class UNet(nn.Module):
"""
Enhanced version of UNet which has residual units implemented with the ResidualUnit class.
@@ -298,4 +311,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

return x


class CheckpointUNet(UNet):
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
subblock = _ActivationCheckpointWrapper(subblock)
return super()._get_connection_block(down_path, up_path, subblock)
Comment on lines 314 to 319

⚠️ Potential issue | 🟠 Major

Add targeted tests for checkpointing behavior.

New subclass lacks coverage. Add unit tests ensuring training mode uses checkpoint (grad-enabled) and eval/no-grad bypasses it, plus parity with base UNet.

As per coding guidelines.

🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 314 to 317, the new CheckpointUNet
subclass wraps subblocks with _ActivationCheckpointWrapper but has no unit
tests; add targeted tests that (1) construct a small UNet and CheckpointUNet
with identical weights and assert forward outputs match in eval mode (no_grad),
(2) in training mode run a forward+backward pass for both models and assert that
gradients are non-zero and that gradient patterns between UNet and
CheckpointUNet are equivalent, and (3) verify that when checkpointing is active
(training) the wrapped subblocks execute under autograd (e.g., by asserting
requires_grad on intermediary tensors or by checking that backward succeeds) and
when in eval/no_grad the wrapper bypasses torch.autograd.checkpoint (i.e., no
grad graph created). Use small random inputs, deterministic seeds, and parameter
copying to ensure parity; place tests under the networks/nets test module and
follow existing test conventions.
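The parity checks the reviewer asks for can be sketched without MONAI itself. The snippet below uses a small `nn.Sequential` as a stand-in for a UNet subblock (the real test would construct `UNet` and `CheckpointUNet` with identical constructor arguments and copied weights); it verifies output parity in eval/no-grad mode and gradient parity through a training forward+backward.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)


def make_block() -> nn.Sequential:
    # Small stand-in for a UNet subblock; the real test would build
    # UNet and CheckpointUNet with identical args instead.
    return nn.Sequential(
        nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(), nn.Conv2d(4, 1, 3, padding=1)
    )


plain = make_block()
ckpt = make_block()
ckpt.load_state_dict(plain.state_dict())  # identical weights for parity

x = torch.randn(2, 1, 8, 8)

# 1. Eval / no-grad: outputs must match, checkpointed or not.
plain.eval()
ckpt.eval()
with torch.no_grad():
    assert torch.allclose(plain(x), checkpoint(ckpt, x, use_reentrant=False))

# 2. Training: backward through the checkpointed path must yield the
# same input gradients as the plain path.
plain.train()
ckpt.train()
xa = x.clone().requires_grad_(True)
xb = x.clone().requires_grad_(True)
plain(xa).sum().backward()
checkpoint(ckpt, xb, use_reentrant=False).sum().backward()
assert torch.allclose(xa.grad, xb.grad, atol=1e-6)
```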

🛠️ Refactor suggestion | 🟠 Major

Document CheckpointUNet.

Provide a Google-style class docstring noting the checkpointing behavior, inherited args, and trade-offs.

As per coding guidelines.

🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 314 to 317, the CheckpointUNet class
lacks a Google-style docstring; add a concise class docstring immediately under
the class declaration that (1) states this class enables activation
checkpointing by wrapping subblocks with _ActivationCheckpointWrapper, (2)
indicates it inherits the same constructor arguments and behavior as UNet
(mention "See UNet for args"), and (3) documents trade-offs: reduced memory at
the cost of extra compute/recomputation and potential differences in backward
behavior; keep wording brief and follow Google docstring format (one-line
summary, extended description, Args inherited note, and a short
“Raises/Warnings” or “Notes” line about the trade-offs).
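Following that suggestion, the class docstring might read as below. The wording is an assumption, and the `UNet` stub here merely stands in for `monai.networks.nets.UNet` to keep the sketch self-contained.

```python
import torch.nn as nn


class UNet(nn.Module):
    """Stub standing in for monai.networks.nets.UNet in this sketch."""


class CheckpointUNet(UNet):
    """UNet variant that checkpoints connection subblocks to reduce memory.

    Each subblock passed to ``_get_connection_block`` is wrapped in
    ``_ActivationCheckpointWrapper``, so its activations are recomputed
    during the backward pass instead of being stored.

    Args:
        See ``UNet``; constructor arguments are unchanged.

    Notes:
        Trades extra compute (forward recomputation during backward) for
        lower peak activation memory; outputs and gradients are otherwise
        equivalent to the base ``UNet``.
    """
```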



Unet = UNet