From 49ef9b14633e374b5b7aba99a21ebe882597b192 Mon Sep 17 00:00:00 2001 From: nnaakkaaii Date: Mon, 1 Jul 2024 10:00:23 +0900 Subject: [PATCH] fix perceptual loss --- hrdae/conf | 2 +- hrdae/models/losses/perceptual.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/hrdae/conf b/hrdae/conf index ac23481..103ba58 160000 --- a/hrdae/conf +++ b/hrdae/conf @@ -1 +1 @@ -Subproject commit ac23481521a427c7a2a2adbf43a39113dc7c56aa +Subproject commit 103ba5803b845099449e250300b2e3c4cbb9d923 diff --git a/hrdae/models/losses/perceptual.py b/hrdae/models/losses/perceptual.py index 92e3bea..fdb93bc 100644 --- a/hrdae/models/losses/perceptual.py +++ b/hrdae/models/losses/perceptual.py @@ -44,6 +44,7 @@ def __init__( if torch.cuda.is_available(): self.network.to("cuda:0") self.network = nn.DataParallel(self.network) + self.network.eval() def forward(self, input: Tensor, target: Tensor) -> Tensor: b, t = input.size()[:2]