diff --git a/pytorchvideo/layers/swish.py b/pytorchvideo/layers/swish.py index 21bdcec..b5d1592 100644 --- a/pytorchvideo/layers/swish.py +++ b/pytorchvideo/layers/swish.py @@ -10,6 +10,12 @@ class Swish(nn.Module): """ def forward(self, x): + if not self.training: + """ + SwishFunction is not traceable by torchscript due to use of torch.autograd + Apply model.eval() before tracing using torch.jit.trace + """ + return x * torch.sigmoid(x) return SwishFunction.apply(x)