Skip to content

Commit

Permalink
update model logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jul 8, 2024
1 parent 1a2a778 commit 1704b1e
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 7 deletions.
3 changes: 0 additions & 3 deletions hrdae/models/basic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,6 @@ def train(
f"epoch_{epoch}",
)

with open(result_dir / "training_history.json", "w") as f:
json.dump(training_history, f)

return least_val_loss


Expand Down
5 changes: 2 additions & 3 deletions hrdae/models/vr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def train(
"val_loss": float(avg_val_loss),
}
)
with open(result_dir / "training_history.json", "w") as f:
json.dump(training_history, f)

if epoch % 10 == 0:
data = next(iter(val_loader))
Expand All @@ -166,9 +168,6 @@ def train(
result_dir / "weights" / f"model_{epoch}.pth",
)

with open(result_dir / "training_history.json", "w") as f:
json.dump(training_history, f)

return least_val_loss


Expand Down

0 comments on commit 1704b1e

Please sign in to comment.