From b435021f533812c470ac895f5942d03f4e8446df Mon Sep 17 00:00:00 2001 From: Shahnawaz Alam Date: Mon, 29 May 2023 17:58:43 -0400 Subject: [PATCH] Updated Swish function to enable model export --- pytorchvideo/layers/swish.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorchvideo/layers/swish.py b/pytorchvideo/layers/swish.py index 21bdcece..b5d15928 100644 --- a/pytorchvideo/layers/swish.py +++ b/pytorchvideo/layers/swish.py @@ -10,6 +10,12 @@ class Swish(nn.Module): """ def forward(self, x): + if not self.training: + """ + SwishFunction is not traceable by torchscript due to use of torch.autograd + Apply model.eval() before tracing using torch.jit.trace + """ + return x * torch.sigmoid(x) return SwishFunction.apply(x)