Commit f670d98

Make a few more layers symbolically traceable (remove from FX leaf modules)

* remove dtype kwarg from .to() calls in EvoNorm as it messed up script + trace combo
* BatchNormAct2d always uses custom forward (cut & paste from original) instead of super().forward. Fixes huggingface#1176
* BlurPool groups==channels, no need to use input.shape[1]
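
For context, a minimal, hypothetical sketch (toy module, not code from this commit) of the constraint behind all three changes: torch.fx builds a graph by running forward() with Proxy objects, so any Python-level branch or built-in call that needs a concrete value derived from the input tensor fails during symbolic tracing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PadIfOdd(nn.Module):
    """Toy module: branches on the input's width, which fx only sees as a Proxy."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[-1] % 2:  # bool() on a Proxy raises a torch.fx TraceError
            x = F.pad(x, (0, 1))
        return x


try:
    torch.fx.symbolic_trace(PadIfOdd())
except Exception as e:
    print(f'tracing failed: {type(e).__name__}')
```

Layers with this kind of data-dependent logic either get rewritten so the offending value is a plain Python constant (as this commit does for BlurPool and the .to() calls) or stay registered as FX leaf modules.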
rwightman committed Mar 25, 2022
1 parent a9ecb88 commit f670d98
Showing 5 changed files with 31 additions and 48 deletions.
10 changes: 1 addition & 9 deletions timm/models/fx_features.py
@@ -15,25 +15,17 @@
has_fx_feature_extraction = False

# Layers we want to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2
from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame

# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
BatchNormAct2d, # reason: flow control for jit scripting
BilinearAttnTransform, # reason: flow control t <= 1
BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
# Reason: get_same_padding has a max which raises a control flow error
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
DropPath, # reason: TypeError: rand received Proxy in `size` argument
EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2, # to(dtype) use that causes tracing failure (on scripted models only?)
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a,

}

try:
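For reference, a hedged sketch (not timm's actual tracer) of how a leaf-module set like _leaf_modules is typically consumed: a torch.fx.Tracer subclass reports those classes as leaves so each instance is recorded as a single call_module node instead of being traced through.

```python
import torch
from torch import fx, nn


class LeafModuleTracer(fx.Tracer):
    """Treat any module whose class is in `leaf_classes` as an opaque leaf."""
    def __init__(self, leaf_classes):
        super().__init__()
        self.leaf_classes = tuple(leaf_classes)

    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        return isinstance(m, self.leaf_classes) or super().is_leaf_module(m, module_qualified_name)


def trace_with_leaves(model: nn.Module, leaf_classes) -> fx.GraphModule:
    # e.g. trace_with_leaves(model, _leaf_modules) in a setup like the one above
    graph = LeafModuleTracer(leaf_classes).trace(model)
    return fx.GraphModule(model, graph)
```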
2 changes: 1 addition & 1 deletion timm/models/layers/blur_pool.py
@@ -39,4 +39,4 @@ def __init__(self, channels, filt_size=3, stride=2) -> None:

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, self.padding, 'reflect')
return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])
return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
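
A hedged illustration of the blur_pool.py change (simplified stand-in, not the timm class): because the group count is a plain Python int captured at construction time, F.conv2d never receives an fx Proxy in its groups argument, which is exactly what x.shape[1] would become under symbolic tracing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthwiseBlur(nn.Module):
    """Simplified BlurPool-style anti-aliased downsample (assumed 3x3 binomial filter)."""
    def __init__(self, channels: int, stride: int = 2):
        super().__init__()
        self.channels = channels          # plain int, known at init
        self.stride = stride
        self.padding = [1, 1, 1, 1]
        k = torch.tensor([1., 2., 1.])
        filt = (k[:, None] * k[None, :]) / k.sum() ** 2   # 3x3 binomial kernel, sums to 1
        self.register_buffer('filt', filt[None, None].repeat(channels, 1, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, self.padding, 'reflect')
        # groups=self.channels (an int) rather than x.shape[1] (a Proxy under fx)
        return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)


gm = torch.fx.symbolic_trace(DepthwiseBlur(8))
print(gm(torch.randn(2, 8, 16, 16)).shape)  # torch.Size([2, 8, 8, 8])
```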
16 changes: 8 additions & 8 deletions timm/models/layers/drop.py
@@ -107,13 +107,13 @@ class DropBlock2d(nn.Module):

def __init__(
self,
drop_prob=0.1,
block_size=7,
gamma_scale=1.0,
with_noise=False,
inplace=False,
batchwise=False,
fast=True):
drop_prob: float = 0.1,
block_size: int = 7,
gamma_scale: float = 1.0,
with_noise: bool = False,
inplace: bool = False,
batchwise: bool = False,
fast: bool = True):
super(DropBlock2d, self).__init__()
self.drop_prob = drop_prob
self.gamma_scale = gamma_scale
@@ -157,7 +157,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None, scale_by_keep=True):
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
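The drop.py change only adds explicit type annotations (and a concrete 0. default for DropPath), but for readers unfamiliar with the layer, here is a sketch of the stochastic-depth semantics those signatures describe (it mirrors the documented behavior; not a verbatim copy of the timm function):

```python
import torch


def drop_path_sketch(x: torch.Tensor, drop_prob: float = 0., training: bool = False,
                     scale_by_keep: bool = True) -> torch.Tensor:
    """Zero each sample's residual branch with probability drop_prob; rescale survivors."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # one Bernoulli draw per sample, broadcast over all remaining dims
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0. and scale_by_keep:
        mask.div_(keep_prob)
    return x * mask


# typical use in the main path of a residual block:
# x = x + drop_path_sketch(block(x), drop_prob=0.1, training=True)
```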
30 changes: 15 additions & 15 deletions timm/models/layers/evo_norm.py
@@ -92,7 +92,7 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
_assert(C % groups == 0, '')
x_dtype = x.dtype
x = x.reshape(B, groups, C // groups, H, W)
rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype)
return rms.expand(x.shape).reshape(B, C, H, W)


@@ -160,14 +160,14 @@ def forward(self, x):
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
else:
var = self.running_var
var = var.to(dtype=x_dtype).view(v_shape)
var = var.to(x_dtype).view(v_shape)
left = var.add(self.eps).sqrt_()
right = (x + 1) * instance_rms(x, self.eps)
x = x / left.max(right)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dB2(nn.Module):
@@ -195,14 +195,14 @@ def forward(self, x):
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
else:
var = self.running_var
var = var.to(dtype=x_dtype).view(v_shape)
var = var.to(x_dtype).view(v_shape)
left = var.add(self.eps).sqrt_()
right = instance_rms(x, self.eps) - x
x = x / left.max(right)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dS0(nn.Module):
@@ -231,9 +231,9 @@ def forward(self, x):
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.v is not None:
v = self.v.view(v_shape).to(dtype=x_dtype)
v = self.v.view(v_shape).to(x_dtype)
x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dS0a(EvoNorm2dS0):
@@ -247,10 +247,10 @@ def forward(self, x):
v_shape = (1, -1, 1, 1)
d = group_std(x, self.groups, self.eps)
if self.v is not None:
v = self.v.view(v_shape).to(dtype=x_dtype)
v = self.v.view(v_shape).to(x_dtype)
x = x * (x * v).sigmoid()
x = x / d
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dS1(nn.Module):
@@ -284,7 +284,7 @@ def forward(self, x):
v_shape = (1, -1, 1, 1)
if self.apply_act:
x = self.act(x) / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dS1a(EvoNorm2dS1):
@@ -299,7 +299,7 @@ def forward(self, x):
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = self.act(x) / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dS2(nn.Module):
@@ -332,7 +332,7 @@ def forward(self, x):
v_shape = (1, -1, 1, 1)
if self.apply_act:
x = self.act(x) / group_rms(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dS2a(EvoNorm2dS2):
@@ -347,4 +347,4 @@ def forward(self, x):
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = self.act(x) / group_rms(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
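
A hedged illustration of the evo_norm.py change (toy module, not one of the EvoNorm classes): Tensor.to(dtype) and Tensor.to(dtype=dtype) are functionally equivalent, but per the commit message the keyword form broke tracing of already-scripted modules, so the positional form is used throughout. Something like the following exercises the script-then-trace combination the commit targets:

```python
import torch
import torch.nn as nn


class CastingAffine(nn.Module):
    """Toy affine layer (not an EvoNorm class) that casts params to the input dtype."""
    def __init__(self, channels: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_dtype = x.dtype
        # positional dtype argument, matching the change made in evo_norm.py here
        w = self.weight.view(1, -1, 1, 1).to(x_dtype)
        b = self.bias.view(1, -1, 1, 1).to(x_dtype)
        return x * w + b


# the combination the commit message refers to: script first, then trace a model
# that embeds the scripted module
scripted = torch.jit.script(CastingAffine(8))
traced = torch.jit.trace(nn.Sequential(scripted), torch.randn(2, 8, 4, 4))
```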
21 changes: 6 additions & 15 deletions timm/models/layers/norm_act.py
@@ -6,6 +6,7 @@
from torch import nn as nn
from torch.nn import functional as F

from .trace_utils import _assert
from .create_act import get_act_layer


@@ -29,9 +30,10 @@ def __init__(
else:
self.act = nn.Identity()

def _forward_jit(self, x):
""" A cut & paste of the contents of the PyTorch BatchNorm2d forward function
"""
def forward(self, x):
# cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
_assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')

# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in ONNX graph when this node is exported to ONNX.
@@ -63,7 +65,7 @@ def _forward_jit(self, x):
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
"""
return F.batch_norm(
x = F.batch_norm(
x,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean if not self.training or self.track_running_stats else None,
@@ -74,17 +76,6 @@ def _forward_jit(self, x):
exponential_average_factor,
self.eps,
)

@torch.jit.ignore
def _forward_python(self, x):
return super(BatchNormAct2d, self).forward(x)

def forward(self, x):
# FIXME cannot call parent forward() and maintain jit.script compatibility?
if torch.jit.is_scripting():
x = self._forward_jit(x)
else:
x = self._forward_python(x)
x = self.drop(x)
x = self.act(x)
return x
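For reference, a minimal sketch of the pattern norm_act.py now uses (an assumed simplification, not the actual BatchNormAct2d, which also handles num_batches_tracked and cumulative momentum): inherit nn.BatchNorm2d for its parameters and buffers, but implement forward() directly with F.batch_norm so scripting and fx tracing see a single code path instead of a super().forward() call guarded by torch.jit.is_scripting().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchNormAct(nn.BatchNorm2d):
    """BatchNorm2d + activation with a self-contained forward (simplified sketch;
    the real BatchNormAct2d also updates num_batches_tracked / cumulative momentum)."""

    def __init__(self, num_features: int, act_layer=nn.ReLU, **kwargs):
        super().__init__(num_features, **kwargs)
        self.act = act_layer(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        momentum = 0.0 if self.momentum is None else self.momentum
        x = F.batch_norm(
            x,
            # pass (and update) buffers only when stats are actually tracked
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            self.training or self.running_mean is None,  # use batch statistics?
            momentum,
            self.eps,
        )
        return self.act(x)


# one code path for eager use and torch.jit.script alike
scripted = torch.jit.script(BatchNormAct(64))
```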
