From 1704b1ea557f70360ff9be0538176fddcdd644b6 Mon Sep 17 00:00:00 2001 From: nnaakkaaii Date: Mon, 8 Jul 2024 15:38:36 +0900 Subject: [PATCH] update model logic --- hrdae/conf | 2 +- hrdae/models/basic_model.py | 3 --- hrdae/models/vr_model.py | 5 ++--- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/hrdae/conf b/hrdae/conf index f011d2c..a8fe296 160000 --- a/hrdae/conf +++ b/hrdae/conf @@ -1 +1 @@ -Subproject commit f011d2c4a8ba1d9ddc4e4f3855923e42c8bbd356 +Subproject commit a8fe296ab46c0dbc134c6084e529910df6fccdc5 diff --git a/hrdae/models/basic_model.py b/hrdae/models/basic_model.py index 5468f0e..3b09270 100644 --- a/hrdae/models/basic_model.py +++ b/hrdae/models/basic_model.py @@ -181,9 +181,6 @@ def train( f"epoch_{epoch}", ) - with open(result_dir / "training_history.json", "w") as f: - json.dump(training_history, f) - return least_val_loss diff --git a/hrdae/models/vr_model.py b/hrdae/models/vr_model.py index f6e4b96..b7b3ba5 100644 --- a/hrdae/models/vr_model.py +++ b/hrdae/models/vr_model.py @@ -144,6 +144,8 @@ def train( "val_loss": float(avg_val_loss), } ) + with open(result_dir / "training_history.json", "w") as f: + json.dump(training_history, f) if epoch % 10 == 0: data = next(iter(val_loader)) @@ -166,9 +168,6 @@ def train( result_dir / "weights" / f"model_{epoch}.pth", ) - with open(result_dir / "training_history.json", "w") as f: - json.dump(training_history, f) - return least_val_loss