diff --git a/hrdae/models/losses/perceptual.py b/hrdae/models/losses/perceptual.py index 92e3bea..fdb93bc 100644 --- a/hrdae/models/losses/perceptual.py +++ b/hrdae/models/losses/perceptual.py @@ -44,6 +44,7 @@ def __init__( if torch.cuda.is_available(): self.network.to("cuda:0") self.network = nn.DataParallel(self.network) + self.network.eval() def forward(self, input: Tensor, target: Tensor) -> Tensor: b, t = input.size()[:2]