
Support gradient checkpointing in forward_intermediates() #2501

Merged · 16 commits · Jun 21, 2025
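This PR makes `forward_intermediates()` honor a model's gradient-checkpointing flag, so intermediate feature extraction can trade extra compute for lower activation memory, just as `forward_features()` already does. A minimal usage sketch (the model name, indices, and loss are illustrative, not part of this PR):

```python
import torch
import timm

# Any timm model that implements forward_intermediates(); vit_small is illustrative.
model = timm.create_model('vit_small_patch16_224', pretrained=False)
model.train()
model.set_grad_checkpointing()  # raises on models without checkpointing support

x = torch.randn(2, 3, 224, 224)

# With this PR, the block/stage loop inside forward_intermediates() also runs
# under activation checkpointing when it is enabled.
final, intermediates = model.forward_intermediates(x, indices=4)
loss = final.mean() + sum(t.mean() for t in intermediates)
loss.backward()
```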
28 changes: 27 additions & 1 deletion tests/test_models.py
@@ -186,6 +186,18 @@ def test_model_forward(model_name, batch_size):
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'

# Test that grad-checkpointing, if supported, doesn't cause model failures or change in output
try:
model.set_grad_checkpointing()
except:
# throws if not supported, that's fine
pass
else:
outputs2 = model(inputs)
if isinstance(outputs, tuple):
outputs2 = torch.cat(outputs2)
assert torch.allclose(outputs, outputs2, rtol=1e-4, atol=1e-5), 'Output does not match'


@pytest.mark.base
@pytest.mark.timeout(timeout120)
@@ -529,6 +541,20 @@ def test_model_forward_intermediates(model_name, batch_size):
output2 = model.forward_features(inpt)
assert torch.allclose(output, output2)

# Test that grad-checkpointing, if supported, doesn't cause failures or change the output
try:
model.set_grad_checkpointing()
except:
# throws if not supported, that's fine
pass
else:
output3, _ = model.forward_intermediates(
inpt,
output_fmt=output_fmt,
)
assert torch.allclose(output, output3, rtol=1e-4, atol=1e-5), 'Output does not match'



def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
@@ -717,4 +743,4 @@ def test_model_forward_torchscript_with_features_fx(model_name, batch_size):

for tensor in outputs:
assert tensor.shape[0] == batch_size
assert not torch.isnan(tensor).any(), 'Output included NaNs'
assert not torch.isnan(tensor).any(), 'Output included NaNs'
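The per-model diffs that follow all insert the same guard into the block/stage loop of `forward_intermediates()`, mirroring what `forward_features()` already does. A generic sketch of that pattern, assuming a simple block-based module (this class is illustrative, not from timm, which uses its own `checkpoint`/`checkpoint_seq` helpers in `timm.models._manipulate`):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class BlockModel(nn.Module):
    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks
        self.grad_checkpointing = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def forward_intermediates(self, x, take_indices=(0, 1)):
        intermediates = []
        for i, blk in enumerate(self.blocks):
            # The guard added throughout this PR: checkpoint only when enabled
            # and not running under TorchScript.
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
            if i in take_indices:
                intermediates.append(x)
        return x, intermediates
```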
5 changes: 4 additions & 1 deletion timm/models/beit.py
@@ -615,7 +615,10 @@ def forward_intermediates(
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x, shared_rel_pos_bias=rel_pos_bias)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
else:
x = blk(x, shared_rel_pos_bias=rel_pos_bias)
if i in take_indices:
# normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x)
5 changes: 4 additions & 1 deletion timm/models/byobnet.py
@@ -1508,7 +1508,10 @@ def forward_intermediates(
stages = self.stages[:max_index]
for stage in stages:
feat_idx += 1
x = stage(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(stage, x)
else:
x = stage(x)
if not exclude_final_conv and feat_idx == last_idx:
# default feature_info for this model uses final_conv as the last feature output (if present)
x = self.final_conv(x)
7 changes: 5 additions & 2 deletions timm/models/cait.py
@@ -18,7 +18,7 @@
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._manipulate import checkpoint, checkpoint_seq
from ._registry import register_model, generate_default_cfgs

__all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn']
@@ -373,7 +373,10 @@ def forward_intermediates(
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x)
else:
x = blk(x)
if i in take_indices:
# normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x)
7 changes: 1 addition & 6 deletions timm/models/crossvit.py
@@ -14,21 +14,16 @@
NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""

# Copyright IBM All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


"""
Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from typing import List, Optional, Tuple

import torch
import torch.hub
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
7 changes: 5 additions & 2 deletions timm/models/davit.py
@@ -25,7 +25,7 @@
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._manipulate import checkpoint, checkpoint_seq
from ._registry import generate_default_cfgs, register_model

__all__ = ['DaVit']
@@ -671,7 +671,10 @@ def forward_intermediates(
stages = self.stages[:max_index + 1]

for feat_idx, stage in enumerate(stages):
x = stage(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(stage, x)
else:
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm_pre(x) # applying final norm to last intermediate
1 change: 0 additions & 1 deletion timm/models/dla.py
@@ -10,7 +10,6 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import create_classifier
8 changes: 5 additions & 3 deletions timm/models/efficientnet.py
@@ -259,9 +259,11 @@ def forward_intermediates(
blocks = self.blocks
else:
blocks = self.blocks[:max_index]
for blk in blocks:
feat_idx += 1
x = blk(x)
for feat_idx, blk in enumerate(blocks, start=1):
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(blk, x)
else:
x = blk(x)
if feat_idx in take_indices:
intermediates.append(x)

10 changes: 8 additions & 2 deletions timm/models/efficientvit_mit.py
@@ -789,7 +789,10 @@ def forward_intermediates(
stages = self.stages[:max_index + 1]

for feat_idx, stage in enumerate(stages):
x = stage(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(stage, x)
else:
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)

@@ -943,7 +946,10 @@ def forward_intermediates(
stages = self.stages[:max_index + 1]

for feat_idx, stage in enumerate(stages):
x = stage(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(stage, x)
else:
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)

7 changes: 5 additions & 2 deletions timm/models/efficientvit_msra.py
@@ -18,7 +18,7 @@
from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._manipulate import checkpoint, checkpoint_seq
from ._registry import register_model, generate_default_cfgs


@@ -510,7 +510,10 @@ def forward_intermediates(
stages = self.stages[:max_index + 1]

for feat_idx, stage in enumerate(stages):
x = stage(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(stage, x)
else:
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)

5 changes: 4 additions & 1 deletion timm/models/eva.py
@@ -731,7 +731,10 @@ def forward_intermediates(
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x, rope=rot_pos_embed)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, rope=rot_pos_embed)
else:
x = blk(x, rope=rot_pos_embed)
if i in take_indices:
intermediates.append(self.norm(x) if norm else x)

2 changes: 1 addition & 1 deletion timm/models/fasternet.py
@@ -142,7 +142,7 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x, flatten=True)
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
2 changes: 1 addition & 1 deletion timm/models/focalnet.py
@@ -274,7 +274,7 @@ def forward(self, x):
x = self.downsample(x)
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(blk, x)
x = checkpoint(blk, x)
else:
x = blk(x)
return x
2 changes: 1 addition & 1 deletion timm/models/gcvit.py
@@ -361,7 +361,7 @@ def forward(self, x):
global_query = self.global_norm(global_query.permute(0, 2, 3, 1))
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(blk, x)
x = checkpoint(blk, x, global_query)
else:
x = blk(x, global_query)
x = self.norm(x)
5 changes: 4 additions & 1 deletion timm/models/ghostnet.py
@@ -727,7 +727,10 @@ def forward_intermediates(
stages = self.blocks[:max_index + 1]

for feat_idx, stage in enumerate(stages, start=1):
x = stage(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(stage, x)
else:
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)

2 changes: 1 addition & 1 deletion timm/models/hgnet.py
@@ -345,7 +345,7 @@ def __init__(
def forward(self, x):
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x, flatten=False)
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
7 changes: 5 additions & 2 deletions timm/models/hiera.py
@@ -24,7 +24,7 @@
# --------------------------------------------------------
import math
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
@@ -719,7 +719,10 @@ def forward_intermediates(
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x)
else:
x = blk(x)
if i in take_indices:
x_int = self.reroll(x, i, mask=mask)
intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int)
20 changes: 13 additions & 7 deletions timm/models/hieradet_sam2.py
@@ -1,21 +1,20 @@
import math
from copy import deepcopy
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \
get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn

from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
from ._manipulate import named_apply, checkpoint
from ._registry import generate_default_cfgs, register_model


def window_partition(x, window_size: Tuple[int, int]):
@@ -289,6 +288,7 @@ def __init__(
norm_layer = get_norm_layer(norm_layer)
act_layer = get_act_layer(act_layer)
assert len(stages) == len(window_spec)
self.grad_checkpointing = False
self.num_classes = num_classes
self.window_spec = window_spec
self.output_fmt = 'NHWC'
@@ -471,7 +471,10 @@ def forward_intermediates(
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x)
else:
x = blk(x)
if i in take_indices:
x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x
intermediates.append(x_out)
@@ -503,8 +506,11 @@ def prune_intermediate_layers(
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x) # BHWC
x = self._pos_embed(x)
for i, blk in enumerate(self.blocks):
x = blk(x)
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x)
else:
x = blk(x)
return x

def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
1 change: 0 additions & 1 deletion timm/models/hrnet.py
@@ -13,7 +13,6 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import create_classifier
1 change: 0 additions & 1 deletion timm/models/inception_resnet_v2.py
@@ -5,7 +5,6 @@
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import create_classifier, ConvNormAct
1 change: 0 additions & 1 deletion timm/models/inception_v3.py
@@ -4,7 +4,6 @@
Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE
"""
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
7 changes: 5 additions & 2 deletions timm/models/levit.py
@@ -34,7 +34,7 @@
from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._manipulate import checkpoint, checkpoint_seq
from ._registry import generate_default_cfgs, register_model

__all__ = ['Levit']
@@ -671,7 +671,10 @@ def forward_intermediates(
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(stage, x)
else:
x = stage(x)
if feat_idx in take_indices:
if self.use_conv:
intermediates.append(x)