diff --git a/monai/networks/blocks/activation_checkpointing.py b/monai/networks/blocks/activation_checkpointing.py new file mode 100644 index 0000000000..283bcd19e1 --- /dev/null +++ b/monai/networks/blocks/activation_checkpointing.py @@ -0,0 +1,41 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import cast + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + + +class ActivationCheckpointWrapper(nn.Module): + """Wrapper applying activation checkpointing to a module during training. + + Args: + module: The module to wrap with activation checkpointing. + """ + + def __init__(self, module: nn.Module) -> None: + super().__init__() + self.module = module + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass with optional activation checkpointing. + + Args: + x: Input tensor. + + Returns: + Output tensor from the wrapped module. + """ + return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index eac0ddab39..1fa5cbf7f2 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -17,11 +17,12 @@ import torch import torch.nn as nn +from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper from monai.networks.blocks.convolutions import Convolution, ResidualUnit from monai.networks.layers.factories import Act, Norm from monai.networks.layers.simplelayers import SkipConnection -__all__ = ["UNet", "Unet"] +__all__ = ["UNet", "Unet", "CheckpointUNet"] class UNet(nn.Module): @@ -298,4 +299,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class CheckpointUNet(UNet): + """UNet variant that wraps internal connection blocks with activation checkpointing. + + See `UNet` for constructor arguments. During training with gradients enabled, + intermediate activations inside encoder–decoder connections are recomputed in + the backward pass to reduce peak memory usage at the cost of extra compute. + """ + + def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module: + subblock = ActivationCheckpointWrapper(subblock) + down_path = ActivationCheckpointWrapper(down_path) + up_path = ActivationCheckpointWrapper(up_path) + return super()._get_connection_block(down_path, up_path, subblock) + + Unet = UNet