feat: add activation checkpointing to unet #8554
base: dev
Changes from 10 commits: de2b6bd, 66edcb5, e66e357, feefcaa, e112457, f673ca1, 69540ff, 42ec757, a2e8474, 4c4782e, 515c659
```diff
@@ -13,9 +13,11 @@
 import warnings
 from collections.abc import Sequence
+from typing import cast

 import torch
 import torch.nn as nn
+from torch.utils.checkpoint import checkpoint

 from monai.networks.blocks.convolutions import Convolution, ResidualUnit
 from monai.networks.layers.factories import Act, Norm
```
|
```diff
@@ -24,6 +26,17 @@
 __all__ = ["UNet", "Unet"]


+class _ActivationCheckpointWrapper(nn.Module):
+    """Apply activation checkpointing to the wrapped module during training."""
+
+    def __init__(self, module: nn.Module) -> None:
+        super().__init__()
+        self.module = module
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```
Comment on lines +29 to +37:

🛠️ Refactor suggestion | 🟠 Major

Add Google-style docstrings. The class and `forward` docstrings need Args/Returns sections per the coding guidelines: document the wrapped module, the checkpoint guard details, and the returned tensor.
```diff
+
+
 class UNet(nn.Module):
     """
     Enhanced version of UNet which has residual units implemented with the ResidualUnit class.
```
|
```diff
@@ -298,4 +311,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+class CheckpointUNet(UNet):
+    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
+        subblock = _ActivationCheckpointWrapper(subblock)
+        return super()._get_connection_block(down_path, up_path, subblock)
```
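The subclass works because `UNet` assembles each encoder–decoder level through `_get_connection_block`, so wrapping the recursive `subblock` checkpoints every nested level without touching the rest of the architecture. The same hook pattern can be sketched on a toy stand-in network (all class names below are illustrative, not MONAI's):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class _CkptWrapper(nn.Module):
    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return checkpoint(self.module, x, use_reentrant=False)


class TinyEncDec(nn.Module):
    """Stand-in base network that assembles levels via a hook method."""

    def __init__(self) -> None:
        super().__init__()
        down, up, sub = nn.Linear(8, 8), nn.Linear(8, 8), nn.Tanh()
        self.model = self._get_connection_block(down, up, sub)

    def _get_connection_block(self, down_path, up_path, subblock):
        return nn.Sequential(down_path, subblock, up_path)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class CheckpointTinyEncDec(TinyEncDec):
    # Same override shape as CheckpointUNet: wrap the subblock, defer the rest.
    def _get_connection_block(self, down_path, up_path, subblock):
        return super()._get_connection_block(down_path, up_path, _CkptWrapper(subblock))


torch.manual_seed(0)
net = CheckpointTinyEncDec()
out = net(torch.randn(2, 8))
```

Because the wrapper adds no parameters, the checkpointed subclass is numerically identical to the base network under the same initialization.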
Comment on lines 314 to 319:

Add targeted tests for checkpointing behavior. The new subclass lacks coverage. Add unit tests ensuring that training mode uses checkpointing (with gradients enabled), that eval/no-grad mode bypasses it, and that outputs match the base `UNet`, as per the coding guidelines.

🛠️ Refactor suggestion | 🟠 Major

Document `CheckpointUNet`. Provide a Google-style class docstring noting the checkpointing behavior, inherited arguments, and trade-offs, as per the coding guidelines.
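The parity check the reviewer asks for can be sketched without MONAI by comparing the diff's wrapper against the bare module; a real test would instead build `UNet` and `CheckpointUNet` with identical seeds and constructor arguments from this PR's branch:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class _Wrapper(nn.Module):
    """Mirrors the diff's _ActivationCheckpointWrapper."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return checkpoint(self.module, x, use_reentrant=False)


torch.manual_seed(0)
module = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
wrapped = _Wrapper(module)

x = torch.randn(4, 8, requires_grad=True)

# Forward parity: checkpointing must not change the computed values.
y_ref = module(x)
y_ckpt = wrapped(x)

# Gradient parity: backward through the recomputed graph matches the stored one.
g_ref = torch.autograd.grad(y_ref.sum(), x, retain_graph=True)[0]
g_ckpt = torch.autograd.grad(y_ckpt.sum(), x)[0]
```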
```diff


 Unet = UNet
```
Comment on `Unet = UNet`:

Add `CheckpointUNet` to the `__all__` exports. `CheckpointUNet` is a public class but is not exported in `__all__`. Apply this diff: