Skip to content

Commit

Permalink
update basic_model
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 26, 2024
1 parent 34741ca commit b16d154
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
17 changes: 17 additions & 0 deletions hrdae/models/basic_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from dataclasses import dataclass
from pathlib import Path

Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(
if torch.cuda.is_available():
print("GPU is enabled")
self.device = torch.device("cuda:0")
self.network = nn.DataParallel(network).to(self.device)
else:
print("GPU is not enabled")
self.device = torch.device("cpu")
Expand All @@ -64,6 +66,7 @@ def train(
self.network.to(self.device)

least_val_loss = float("inf")
training_history: dict[str, list[dict[str, int | float]]] = {"history": []}

for epoch in range(n_epoch):
self.network.train()
Expand Down Expand Up @@ -143,6 +146,17 @@ def train(
"best",
)

training_history["history"].append(
{
"epoch": int(epoch + 1),
"train_loss": float(running_loss),
"val_loss": float(avg_val_loss),
}
)

with open(result_dir / "training_history.json", "w") as f:
json.dump(training_history, f)

if epoch % 10 == 0:
data = next(iter(val_loader))

Expand All @@ -169,6 +183,9 @@ def train(
f"epoch_{epoch}",
)

with open(result_dir / "training_history.json", "w") as f:
json.dump(training_history, f)

return least_val_loss


Expand Down

0 comments on commit b16d154

Please sign in to comment.