diff --git a/tests/test_models.py b/tests/test_models.py index b4686a3efe..58dad77e27 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -186,6 +186,18 @@ def test_model_forward(model_name, batch_size): assert outputs.shape[0] == batch_size assert not torch.isnan(outputs).any(), 'Output included NaNs' + # Test that grad-checkpointing, if supported, doesn't cause model failures or change in output + try: + model.set_grad_checkpointing() + except: + # throws if not supported, that's fine + pass + else: + outputs2 = model(inputs) + if isinstance(outputs, tuple): + outputs2 = torch.cat(outputs2) + assert torch.allclose(outputs, outputs2, rtol=1e-4, atol=1e-5), 'Output does not match' + @pytest.mark.base @pytest.mark.timeout(timeout120) @@ -529,6 +541,20 @@ def test_model_forward_intermediates(model_name, batch_size): output2 = model.forward_features(inpt) assert torch.allclose(output, output2) + # Test that grad-checkpointing, if supported + try: + model.set_grad_checkpointing() + except: + # throws if not supported, that's fine + pass + else: + output3, _ = model.forward_intermediates( + inpt, + output_fmt=output_fmt, + ) + assert torch.allclose(output, output3, rtol=1e-4, atol=1e-5), 'Output does not match' + + def _create_fx_model(model, train=False): # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode @@ -717,4 +743,4 @@ def test_model_forward_torchscript_with_features_fx(model_name, batch_size): for tensor in outputs: assert tensor.shape[0] == batch_size - assert not torch.isnan(tensor).any(), 'Output included NaNs' \ No newline at end of file + assert not torch.isnan(tensor).any(), 'Output included NaNs' diff --git a/timm/models/beit.py b/timm/models/beit.py index 2ee5fbb01e..5b5973801d 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -615,7 +615,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x, shared_rel_pos_bias=rel_pos_bias) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) + else: + x = blk(x, shared_rel_pos_bias=rel_pos_bias) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 18da1f2dc7..e37d25b6ef 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -1508,7 +1508,10 @@ def forward_intermediates( stages = self.stages[:max_index] for stage in stages: feat_idx += 1 - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(stage, x) + else: + x = stage(x) if not exclude_final_conv and feat_idx == last_idx: # default feature_info for this model uses final_conv as the last feature output (if present) x = self.final_conv(x) diff --git a/timm/models/cait.py b/timm/models/cait.py index 28e14ec756..2c500ec3dd 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -18,7 +18,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs __all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn'] @@ -373,7 +373,10 @@ def forward_intermediates( else: blocks = 
self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index f3d52f8e49..0e1de2fadd 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -14,21 +14,16 @@ NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408 Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ # Copyright IBM All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 - -""" -Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py - -""" from functools import partial from typing import List, Optional, Tuple import torch -import torch.hub import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD diff --git a/timm/models/davit.py b/timm/models/davit.py index a82f2e5fae..22b4a1a05f 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -25,7 +25,7 @@ from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['DaVit'] @@ -671,7 +671,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: if norm and feat_idx == last_idx: x_inter = self.norm_pre(x) # applying final norm to last intermediate diff --git a/timm/models/dla.py b/timm/models/dla.py index 666acd9d9c..197060e4e6 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -10,7 +10,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import create_classifier diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 245f544066..7f2f5aa341 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -259,9 +259,11 @@ def forward_intermediates( blocks = self.blocks else: blocks = self.blocks[:max_index] - for blk in blocks: - feat_idx += 1 - x = blk(x) + for feat_idx, blk in enumerate(blocks, start=1): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(blk, x) + else: + x = blk(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index 80fbe43314..8b35a04c87 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -789,7 +789,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) @@ -943,7 +946,10 @@ def forward_intermediates( stages = self.stages[:max_index + 
1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index f3c8db74cd..80c5d99995 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -18,7 +18,7 @@ from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -510,7 +510,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/eva.py b/timm/models/eva.py index 61301616d6..21f0b9dc35 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -731,7 +731,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x, rope=rot_pos_embed) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, rope=rot_pos_embed) + else: + x = blk(x, rope=rot_pos_embed) if i in take_indices: intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/fasternet.py b/timm/models/fasternet.py index d73f49a265..b9e4aed249 100644 --- a/timm/models/fasternet.py +++ b/timm/models/fasternet.py @@ -142,7 +142,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x, flatten=True) + x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) return x diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index aa3237925f..3c2bd75643 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -274,7 +274,7 @@ def forward(self, x): x = self.downsample(x) for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint.checkpoint(blk, x) + x = checkpoint(blk, x) else: x = blk(x) return x diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index 214619de9b..367e5dfff5 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -361,7 +361,7 @@ def forward(self, x): global_query = self.global_norm(global_query.permute(0, 2, 3, 1)) for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint.checkpoint(blk, x) + x = checkpoint(blk, x, global_query) else: x = blk(x, global_query) x = self.norm(x) diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index a2ddad4696..126d638f43 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -727,7 +727,10 @@ def forward_intermediates( stages = self.blocks[:max_index + 1] for feat_idx, stage in enumerate(stages, start=1): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py index 212cbb58ff..5c49f9ddca 100644 --- a/timm/models/hgnet.py +++ b/timm/models/hgnet.py @@ -345,7 +345,7 
@@ def __init__( def forward(self, x): x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x, flatten=False) + x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) return x diff --git a/timm/models/hiera.py b/timm/models/hiera.py index 2c16a9d63e..fa9d6d2833 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -24,7 +24,7 @@ # -------------------------------------------------------- import math from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -719,7 +719,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: x_int = self.reroll(x, i, mask=mask) intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int) diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 6cd2592a95..fbd7ce28ec 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -1,12 +1,11 @@ import math from copy import deepcopy from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F -from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \ @@ -14,8 +13,8 @@ from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv -from ._registry import generate_default_cfgs, register_model, register_model_deprecations +from ._manipulate import named_apply, checkpoint +from ._registry import generate_default_cfgs, register_model def window_partition(x, window_size: Tuple[int, int]): @@ -289,6 +288,7 @@ def __init__( norm_layer = get_norm_layer(norm_layer) act_layer = get_act_layer(act_layer) assert len(stages) == len(window_spec) + self.grad_checkpointing = False self.num_classes = num_classes self.window_spec = window_spec self.output_fmt = 'NHWC' @@ -471,7 +471,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x intermediates.append(x_out) @@ -503,8 +506,11 @@ def prune_intermediate_layers( def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) # BHWC x = self._pos_embed(x) - for i, blk in enumerate(self.blocks): - x = blk(x) + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) return x def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor: diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 75b157d67d..92ee3511cf 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -13,7 +13,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import create_classifier diff --git 
a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index 7fdfee41ed..d691be7a8f 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -5,7 +5,6 @@ from functools import partial import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import create_classifier, ConvNormAct diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index 8cb1a151df..a55521c3de 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -4,7 +4,6 @@ Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE """ from functools import partial -from typing import Optional import torch import torch.nn as nn diff --git a/timm/models/levit.py b/timm/models/levit.py index 577fc5f2d7..a4c9ce628a 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -34,7 +34,7 @@ from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['Levit'] @@ -671,7 +671,10 @@ def forward_intermediates( else: stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: if self.use_conv: intermediates.append(x) diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index 2e93e01b16..3364a79563 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -41,7 +41,7 @@ use_fused_attn from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['MetaFormer'] @@ -631,7 +631,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index b024fba478..838e0e0117 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -49,7 +49,7 @@ from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import named_apply, checkpoint_seq +from ._manipulate import named_apply, checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model, register_model_deprecations __all__ = ['MixerBlock', 'MlpMixer'] # model_registry will add each entrypoint fn to this @@ -406,7 +406,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 
8e25674b63..eb87bb38d8 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -227,9 +227,11 @@ def forward_intermediates( blocks = self.blocks else: blocks = self.blocks[:max_index] - for blk in blocks: - feat_idx += 1 - x = blk(x) + for feat_idx, blk in enumerate(blocks, start=1): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(blk, x) + else: + x = blk(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index c048a07277..01c4550ed8 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -681,7 +681,7 @@ def __init__( def forward(self, x, feat_size: List[int]): for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x, feat_size = checkpoint.checkpoint(blk, x, feat_size) + x, feat_size = checkpoint(blk, x, feat_size) else: x, feat_size = blk(x, feat_size) return x, feat_size diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py index 29e8ba28e7..5794fdf5e0 100644 --- a/timm/models/naflexvit.py +++ b/timm/models/naflexvit.py @@ -20,14 +20,14 @@ import math from dataclasses import dataclass, fields, replace from functools import partial -from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Final, Any, Literal +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Any import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import ( AttentionPoolLatent, Mlp, @@ -35,14 +35,13 @@ get_act_layer, get_norm_layer, LayerNorm, - LayerType, _assert, ) from timm.models._builder import build_model_with_cfg from timm.models._features import feature_take_indices from timm.models._features_fx import register_notrace_function, register_notrace_module from timm.models._registry import register_model, generate_default_cfgs -from timm.models._manipulate import checkpoint_seq, named_apply +from timm.models._manipulate import checkpoint, checkpoint_seq, named_apply from .vision_transformer import Block, global_pool_nlc @@ -1202,7 +1201,7 @@ def forward_intermediates( output_dict: bool = False, patch_coord: Optional[torch.Tensor] = None, patch_valid: Optional[torch.Tensor] = None, - mask: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]: """ Forward features that returns intermediates. 
@@ -1217,7 +1216,7 @@ def forward_intermediates( output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex - mask: Optional attention mask + attn_mask: Optional attention mask for masked attention Returns: A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix') @@ -1241,8 +1240,8 @@ def forward_intermediates( H, W = self.embeds.dynamic_feat_size((height, width)) # Create attention mask if patch_type is provided and mask is not - if mask is None and patch_valid is not None: - mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype) + if attn_mask is None and patch_valid is not None: + attn_mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype) # Forward pass through embedding x = self.embeds(patches, patch_coord=patch_coord) @@ -1255,7 +1254,12 @@ def forward_intermediates( blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x, attn_mask=mask) + if attn_mask is not None: + x = blk(x, attn_mask=attn_mask) + elif self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index a0d2d8a734..8edee45c86 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -3,11 +3,9 @@ https://github.com/Cadene/pretrained-models.pytorch """ from functools import partial -from typing import Optional import torch import torch.nn as nn -import torch.nn.functional as F from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier from ._builder import build_model_with_cfg diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 2f232e2990..402a9d76ea 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -17,8 +17,7 @@ from timm.layers import ClassifierHead from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._features_fx import register_notrace_function -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['NextViT'] @@ -595,7 +594,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: if feat_idx == last_idx: x_inter = self.norm(x) if norm else x diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index d848f81eba..68b8b1b6d6 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -19,7 +19,7 @@ from collections import OrderedDict from dataclasses import dataclass, replace from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple import torch import torch.nn as nn @@ -304,7 +304,7 @@ def create_stem( if 'deep' in stem_type: if 'quad' in stem_type: # 4 deep conv stack as in NFNet-F models - assert not 'pool' in stem_type + assert 'pool' not in stem_type stem_chs = 
(out_chs // 8, out_chs // 4, out_chs // 2, out_chs) strides = (2, 1, 1, 2) stem_stride = 4 diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 20d17945b5..7f33aaeabb 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -10,7 +10,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier from ._builder import build_model_with_cfg diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index bb1baf6645..0259a1f64e 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -267,7 +267,7 @@ def forward(self, x): x = x.reshape(B, -1, C) for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint.checkpoint(blk, x, feat_size) + x = checkpoint(blk, x, feat_size) else: x = blk(x, feat_size) x = self.norm(x) diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py index a3a205fff6..5764b6ed82 100644 --- a/timm/models/rdnet.py +++ b/timm/models/rdnet.py @@ -281,6 +281,27 @@ def __init__( named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + @torch.jit.ignore + def group_matcher(self, coarse=False): + assert not coarse, "coarse grouping is not implemented for RDNet" + return dict( + stem=r'^stem', + blocks=r'^dense_stages\.(\d+)', + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.dense_stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head.fc + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) + def forward_intermediates( self, x: torch.Tensor, @@ -350,14 +371,6 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - @torch.jit.ignore - def get_classifier(self) -> nn.Module: - return self.head.fc - - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): - self.num_classes = num_classes - self.head.reset(num_classes, global_pool) - def forward_features(self, x): x = self.stem(x) x = self.dense_stages(x) @@ -372,19 +385,6 @@ def forward(self, x): x = self.forward_head(x) return x - @torch.jit.ignore - def group_matcher(self, coarse=False): - assert not coarse, "coarse grouping is not implemented for RDNet" - return dict( - stem=r'^stem', - blocks=r'^dense_stages\.(\d+)', - ) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - for s in self.dense_stages: - s.grad_checkpointing = enable - def _init_weights(module, name=None, head_init_scale=1.0): if isinstance(module, nn.Conv2d): diff --git a/timm/models/repghost.py b/timm/models/repghost.py index a75d9d8506..c5a7d93a4f 100644 --- a/timm/models/repghost.py +++ b/timm/models/repghost.py @@ -336,7 +336,10 @@ def forward_intermediates( stages = self.blocks[:max_index + 1] for feat_idx, stage in enumerate(stages, start=1): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/repvit.py b/timm/models/repvit.py index 190f4b5298..3641d6f70c 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -23,7 +23,7 @@ from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq 
+from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs __all__ = ['RepVit'] @@ -367,7 +367,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 38cd89ce11..0b78e7b44d 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -397,6 +397,8 @@ def __init__( **block_kwargs: Any, ): super(ResNetStage, self).__init__() + self.grad_checkpointing = False + first_dilation = 1 if dilation in (1, 2) else 2 layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer) proj_layer = DownsampleAvg if avg_down else DownsampleConv @@ -431,7 +433,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: Output tensor. """ - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) return x @@ -604,7 +609,6 @@ def __init__( ) self.init_weights(zero_init_last=zero_init_last) - self.grad_checkpointing = False @torch.jit.ignore def init_weights(self, zero_init_last: bool = True) -> None: @@ -631,7 +635,8 @@ def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True) -> None: """Enable or disable gradient checkpointing.""" - self.grad_checkpointing = enable + for s in self.stages: + s.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: @@ -731,10 +736,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: Feature tensor. 
""" x = self.stem(x) - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.stages, x, flatten=True) - else: - x = self.stages(x) + x = self.stages(x) x = self.norm(x) return x diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 04e284158c..77b801db87 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -22,7 +22,7 @@ from ._builder import build_model_with_cfg from ._efficientnet_builder import efficientnet_init_weights from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['RexNet'] # model_registry will add each entrypoint fn to this @@ -382,7 +382,10 @@ def forward_intermediates( stages = self.features[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) @@ -426,7 +429,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: """ x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.features, x, flatten=True) + x = checkpoint_seq(self.features, x) else: x = self.features(x) return x diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index fdfa16c318..dc19ff41d1 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -13,7 +13,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import create_classifier diff --git a/timm/models/shvit.py b/timm/models/shvit.py index be3e206ee8..c165f1a280 100644 --- a/timm/models/shvit.py +++ b/timm/models/shvit.py @@ -11,7 +11,6 @@ year={2024} } """ -import re from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch @@ -245,7 +244,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x, flatten=True) + x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) return x @@ -429,7 +428,7 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) state_dict = state_dict.get('model', state_dict) # out_dict = {} - # + # import re # replace_rules = [ # (re.compile(r'^blocks1\.'), 'stages.0.blocks.'), # (re.compile(r'^blocks2\.'), 'stages.1.blocks.'), diff --git a/timm/models/starnet.py b/timm/models/starnet.py index 646fd324b2..9ed32a85d7 100644 --- a/timm/models/starnet.py +++ b/timm/models/starnet.py @@ -198,7 +198,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(stage, x) + else: + x = stage(x) if feat_idx in take_indices: if norm and feat_idx == last_idx: x_inter = self.norm(x) # applying final norm last intermediate @@ -233,7 +236,7 @@ def prune_intermediate_layers( def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.stages, x, flatten=True) + x = checkpoint_seq(self.stages, x) else: x = self.stages(x) x = self.norm(x) diff --git a/timm/models/swiftformer.py b/timm/models/swiftformer.py index 5998c233fd..38df6f1638 100644 --- 
a/timm/models/swiftformer.py +++ b/timm/models/swiftformer.py @@ -304,7 +304,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x, flatten=True) + x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) return x diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 7eeae8316b..e17f16746c 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -24,7 +24,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \ - _assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid + use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index f7b758aa8b..35c0daa8ac 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -619,7 +619,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint.checkpoint(blk, x) + x = checkpoint(blk, x) else: x = blk(x) return x diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index c490fa23ca..1ef3164fd9 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -29,7 +29,7 @@ # -------------------------------------------------------- import logging import math -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -636,7 +636,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for block in self.blocks: # Perform checkpointing if utilized if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint.checkpoint(block, x) + x = checkpoint(block, x) else: x = block(x) x = bhwc_to_bchw(x) diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index 366eef7092..39bacc850c 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -22,7 +22,7 @@ from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_module -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -570,7 +570,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index fa6e1fc9e7..0ecd8e72a4 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -386,7 +386,10 @@ def forward_intermediates( blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - pixel_embed, patch_embed = blk(pixel_embed, patch_embed) + if self.grad_checkpointing and not torch.jit.is_scripting(): + pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed) + else: + pixel_embed, patch_embed = 
blk(pixel_embed, patch_embed) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(patch_embed) if norm else patch_embed) diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 0fb76fa40c..2c452e4707 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -15,7 +15,7 @@ from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs, register_model_deprecations __all__ = ['TResNet'] # model_registry will add each entrypoint fn to this @@ -263,7 +263,10 @@ def forward_intermediates( stages = self.body[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index b57e2f213c..05e435e9e9 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -68,7 +68,7 @@ ) from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv +from ._manipulate import named_apply, checkpoint, checkpoint_seq, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations __all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this @@ -824,7 +824,12 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x, attn_mask=attn_mask) + if attn_mask is not None: + x = blk(x, attn_mask=attn_mask) + elif self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 4cf3a7664b..0ff4823497 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -13,15 +13,14 @@ Hacked together by / Copyright 2020, Ross Wightman """ -import math from functools import partial -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Dict, Tuple, Type, Union import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, HybridEmbed +from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_ntuple, HybridEmbed from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model, register_model_deprecations diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 030c24dc69..dcccba73ba 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -427,7 +427,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x, shared_rel_pos=shared_rel_pos) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, 
x, shared_rel_pos=shared_rel_pos) + else: + x = blk(x, shared_rel_pos=shared_rel_pos) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index 75bb12e56f..df70f4a251 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -24,7 +24,7 @@ from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model # model_registry will add each entrypoint fn to this @@ -579,7 +579,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: # make output BCHW if norm: diff --git a/timm/models/volo.py b/timm/models/volo.py index f76a8361a3..f417dc6df1 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -1039,7 +1039,10 @@ def forward_intermediates( # add positional encoding after outlooker blocks x = x + self.pos_embed x = self.pos_drop(x) - x = block(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(block, x) + else: + x = block(x) if idx in take_indices: if norm and idx >= 2: x_inter = self.norm(x) diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 250749f1cf..271578adf8 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -478,7 +478,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x, Hp, Wp) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, Hp, Wp) + else: + x = blk(x, Hp, Wp) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) @@ -494,7 +497,10 @@ def forward_intermediates( # NOTE not supporting return of class tokens x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) for blk in self.cls_attn_blocks: - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) x = self.norm(x)