Skip to content

Commit

Permalink
impl tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 15, 2024
1 parent 6011af7 commit 2c49f44
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 23 deletions.
25 changes: 4 additions & 21 deletions hrdae/models/networks/r_dae.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# Recurrent Disentangled AutoEncoder (R-DAE)

from dataclasses import dataclass, field
from dataclasses import dataclass

from omegaconf import MISSING
from torch import Tensor, nn

from .functions import aggregate
Expand All @@ -15,38 +14,22 @@
)
from .motion_encoder import (
MotionEncoder1d,
MotionEncoder1dOption,
MotionEncoder2d,
MotionEncoder2dOption,
create_motion_encoder1d,
create_motion_encoder2d,
check_in_channels,
)
from .option import NetworkOption
from .r_ae import RAE2dOption, RAE3dOption


@dataclass
class RDAE2dOption(NetworkOption):
latent_dim: int = 64
conv_params: list[dict[str, list[int]]] = field(
default_factory=lambda: [{"kernel_size": [3], "stride": [2], "padding": [1]}]
* 3,
)
motion_encoder: MotionEncoder1dOption = MISSING
class RDAE2dOption(RAE2dOption):
aggregation_method: str = "concat"
debug_show_dim: bool = False


@dataclass
class RDAE3dOption(NetworkOption):
latent_dim: int = 64
conv_params: list[dict[str, list[int]]] = field(
default_factory=lambda: [{"kernel_size": [3], "stride": [2], "padding": [1]}]
* 3,
)
motion_encoder: MotionEncoder2dOption = MISSING
class RDAE3dOption(RAE3dOption):
aggregation_method: str = "concat"
debug_show_dim: bool = False


def create_rdae2d(in_channels: int, out_channels: int, opt: RDAE2dOption) -> nn.Module:
Expand Down
22 changes: 20 additions & 2 deletions hrdae/models/vr_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from dataclasses import dataclass
from pathlib import Path

Expand Down Expand Up @@ -28,6 +29,11 @@ class VRModelOption(ModelOption):
pred_diff: bool = False


def save_model(model: nn.Module, filepath: Path):
model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
torch.save(model_to_save.state_dict(), filepath)


class VRModel(Model):
def __init__(
self,
Expand All @@ -49,6 +55,7 @@ def __init__(
if torch.cuda.is_available():
print("GPU is enabled")
self.device = torch.device("cuda:0")
self.network = nn.DataParallel(network).to(self.device)
else:
print("GPU is not enabled")
self.device = torch.device("cpu")
Expand All @@ -65,9 +72,8 @@ def train(
if debug:
max_iter = 5

self.network.to(self.device)

least_val_loss = float("inf")
training_history: dict[str, list[dict[str, int | float]]] = {"history": []}

for epoch in range(n_epoch):
self.network.train()
Expand Down Expand Up @@ -130,6 +136,15 @@ def train(

if avg_val_loss < least_val_loss:
least_val_loss = avg_val_loss
save_model(self.network, result_dir / "best_model.pth")

training_history["history"].append(
{
"epoch": int(epoch + 1),
"train_loss": float(running_loss),
"val_loss": float(avg_val_loss),
}
)

if epoch % 10 == 0:
data = next(iter(val_loader))
Expand All @@ -152,6 +167,9 @@ def train(
result_dir / "logs" / "reconstructed",
)

with open(result_dir / "training_history.json", "w") as f:
json.dump(training_history, f)

return least_val_loss


Expand Down
18 changes: 18 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
alembic==1.13.1
antlr4-python3-runtime==4.9.3
attrs==23.2.0
autopage==0.5.2
black==24.4.2
click==8.1.7
cliff==4.7.0
cmaes==0.10.0
cmd2==2.4.3
colorlog==6.8.2
contourpy==1.2.1
cycler==0.12.1
exceptiongroup==1.2.1
filelock==3.14.0
flake8==7.0.0
fonttools==4.53.0
fsspec==2024.5.0
greenlet==3.0.3
hydra-core==1.3.2
hydra-optuna-sweeper==1.2.0
iniconfig==2.0.0
isort==5.13.2
Jinja2==3.1.4
kiwisolver==1.4.5
Mako==1.3.5
MarkupSafe==2.1.5
matplotlib==3.9.0
mccabe==0.7.0
Expand All @@ -22,22 +32,30 @@ mypy-extensions==1.0.0
networkx==3.3
numpy==1.26.4
omegaconf==2.3.0
optuna==2.10.1
packaging==24.0
pathspec==0.12.1
pbr==6.0.0
pillow==10.3.0
platformdirs==4.2.2
pluggy==1.5.0
prettytable==3.10.0
pycodestyle==2.11.1
pyflakes==3.2.0
pyparsing==3.1.2
pyperclip==1.8.2
pytest==8.2.2
python-dateutil==2.9.0.post0
pytorch-tcn==1.1.0
PyYAML==6.0.1
scipy==1.13.1
six==1.16.0
SQLAlchemy==2.0.30
stevedore==5.2.0
sympy==1.12.1
tomli==2.0.1
torch==2.2.2
torchvision==0.17.2
tqdm==4.66.4
typing_extensions==4.12.1
wcwidth==0.2.13
Loading

0 comments on commit 2c49f44

Please sign in to comment.