update tuning
nnaakkaaii committed Jun 15, 2024
1 parent 14d9624 commit 976f235
Showing 5 changed files with 41 additions and 11 deletions.
2 changes: 1 addition & 1 deletion hrdae/conf/experiment/dataloader/basic.yaml
@@ -1,4 +1,4 @@
-batch_size: 96
+batch_size: 64
train_val_ratio: 0.8
# # mnist
# transform_order_train:
2 changes: 1 addition & 1 deletion hrdae/conf/experiment/model/vr.yaml
@@ -1,7 +1,7 @@
loss_coef:
wmse: 1
phase: "0"
-pred_diff: true
+pred_diff: false
defaults:
- /config/experiment/model/vr@_here_
- [email protected]: wmse
4 changes: 4 additions & 0 deletions mypy.ini
@@ -15,6 +15,10 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-omegaconf.*]
ignore_missing_imports = True
+[mypy-optuna]
+ignore_missing_imports = True
+[mypy-optuna.*]
+ignore_missing_imports = True
[mypy-pytorch_tcn]
ignore_missing_imports = True
[mypy-pytorch_tcn.*]
14 changes: 14 additions & 0 deletions requirements.gpu.txt
@@ -1,15 +1,23 @@
alembic==1.13.1
antlr4-python3-runtime==4.9.3
black==24.4.2
click==8.1.7
colorlog==6.8.2
contourpy==1.2.1
cycler==0.12.1
filelock==3.14.0
fonttools==4.53.0
fsspec==2024.5.0
greenlet==3.0.3
hydra-core==1.3.2
isort==5.13.2
Jinja2==3.1.4
kiwisolver==1.4.5
Mako==1.3.5
MarkupSafe==2.1.5
matplotlib==3.9.0
mpmath==1.3.0
mypy-extensions==1.0.0
networkx==3.3
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
@@ -25,14 +33,20 @@ nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.40
nvidia-nvtx-cu12==12.1.105
omegaconf==2.3.0
optuna==3.6.1
packaging==24.0
pathspec==0.12.1
pillow==10.3.0
platformdirs==4.2.2
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytorch-tcn==1.1.0
PyYAML==6.0.1
six==1.16.0
SQLAlchemy==2.0.30
sympy==1.12.1
torch==2.3.0
torchaudio==2.3.0
torchvision==0.18.0
tqdm==4.66.4
typing_extensions==4.12.1
30 changes: 21 additions & 9 deletions tuning.py
@@ -1,6 +1,7 @@
import json
from dataclasses import asdict
from pathlib import Path
+from typing import Any

from hrdae.option import TrainExpOption
from hrdae.dataloaders.transforms import MinMaxNormalizationOption
@@ -35,6 +36,14 @@ def interleave_arrays(
return result


+def default(item: Any):
+    match item:
+        case Path():
+            return str(item)
+        case _:
+            raise TypeError(type(item))


def objective(trial):
dataset_option = MovingMNISTDatasetOption(
root="data",
@@ -63,8 +72,10 @@ def objective(trial):
phase = trial.suggest_categorical("phase", ["0", "t", "all"])
network_name = trial.suggest_categorical("network", ["hrdae2d", "rae2d", "rdae2d"])
if phase == "all":
motion_encoder_name = trial.suggest_categorical("motion_encoder", ["conv2d", "normal1d", "rnn1d"])
pred_diff = False
motion_encoder_name = trial.suggest_categorical("motion_encoder_all", ["conv2d", "normal1d", "rnn1d"])
else:
+pred_diff = trial.suggest_categorical("pred_diff", [True, False])
motion_encoder_name = trial.suggest_categorical("motion_encoder", ["conv2d", "guided1d", "normal1d", "rnn1d", "tsn1d"])
motion_encoder_num_layers = trial.suggest_int("motion_encoder_num_layers", 0, 3)
if motion_encoder_name == "rnn1d":
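Note on the hunk above: when `phase == "all"` the encoder choice list shrinks, and the categorical is renamed to `motion_encoder_all`, presumably because Optuna requires a given parameter name to keep the same choice set across all trials of a study. A minimal sketch of that pattern (choice lists copied from the diff, dummy objective added so it runs):

```python
import optuna


def objective(trial: optuna.Trial) -> float:
    # Branch-dependent choices get distinct parameter names; reusing one
    # name with a different choice set is rejected by Optuna.
    phase = trial.suggest_categorical("phase", ["0", "t", "all"])
    if phase == "all":
        encoder = trial.suggest_categorical("motion_encoder_all", ["conv2d", "normal1d", "rnn1d"])
    else:
        encoder = trial.suggest_categorical("motion_encoder", ["conv2d", "guided1d", "normal1d", "rnn1d", "tsn1d"])
    return float(len(encoder))  # dummy score so the sketch runs end to end


study = optuna.create_study()
study.optimize(objective, n_trials=5)
```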
@@ -130,7 +141,7 @@ def objective(trial):
)
else:
raise RuntimeError("unreachable")
-latent_dim = int(trial.suggest_discrete_uniform("latent_dim", 16, 64, 8))
+latent_dim = trial.suggest_int("latent_dim", 16, 64, step=8)
content_encoder_num_layers = trial.suggest_int("content_encoder_encoder_num_layers", 0, 3)
if network_name == "rae2d":
network_option = RAE2dOption(
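The change above swaps the deprecated `suggest_discrete_uniform` (which returns a float, hence the old `int(...)` wrapper) for `suggest_int` with a `step`; the `lr`/`max_lr` changes in the next hunk do the same with `suggest_float(..., log=True)` replacing `suggest_loguniform`. A small sketch of the equivalences, assuming only Optuna itself:

```python
import optuna


def objective(trial: optuna.Trial) -> float:
    # Old: int(trial.suggest_discrete_uniform("latent_dim", 16, 64, 8))
    # New: an int drawn directly from {16, 24, ..., 64}
    latent_dim = trial.suggest_int("latent_dim", 16, 64, step=8)

    # Old: trial.suggest_loguniform("lr", 1e-5, 1e-2)
    # New: a float sampled log-uniformly on [1e-5, 1e-2]
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)

    return latent_dim * lr  # dummy objective so the sketch is runnable


study = optuna.create_study()
study.optimize(objective, n_trials=3)
```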
@@ -168,33 +179,33 @@ def objective(trial):
raise RuntimeError("unreachable")

optimizer_option = AdamOptimizerOption(
-lr=trial.suggest_loguniform("lr", 1e-5, 1e-2),
+lr=trial.suggest_float("lr", 1e-5, 1e-2, log=True),
)

scheduler_option = OneCycleLRSchedulerOption(
-max_lr=trial.suggest_loguniform("max_lr", 1e-3, 1e-2),
+max_lr=trial.suggest_float("max_lr", 1e-3, 1e-2, log=True),
)

model_option = VRModelOption(
loss_coef={"wmse": 1.0},
phase=phase,
-pred_diff=trial.suggest_categorical("pred_diff", [True, False]),
+pred_diff=pred_diff,
loss=loss_option,
network=network_option,
optimizer=optimizer_option,
scheduler=scheduler_option,
)

result_dir = Path(f"result/tuning/mmnist/{network_name}/{motion_encoder_name}/{trial.number}")
result_dir = Path(f"results/tuning/mmnist/{network_name}/{motion_encoder_name}/{trial.number}")
train_option = TrainExpOption(
result_dir=result_dir,
dataloader=dataloader_option,
model=model_option,
-n_epoch=50,
+n_epoch=1,
)
result_dir.mkdir(parents=True, exist_ok=True)
with open(result_dir / "config.json", "w") as f:
-json.dump(asdict(train_option), f, indent=2)
+json.dump(asdict(train_option), f, indent=2, default=default)

train_loader, val_loader = create_dataloader(
dataloader_option,
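The `default=default` argument above is needed because `asdict(train_option)` still contains `Path` values (e.g. `result_dir`), which `json` cannot encode on its own; the new `default()` helper converts them to strings and raises `TypeError` for anything else, as the `default` hook contract expects. A standalone sketch with a hypothetical stand-in dataclass:

```python
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any


@dataclass
class DummyOption:  # hypothetical stand-in for TrainExpOption
    result_dir: Path
    n_epoch: int


def default(item: Any):
    # json calls this only for objects it cannot encode itself.
    match item:
        case Path():
            return str(item)
        case _:
            raise TypeError(type(item))


print(json.dumps(asdict(DummyOption(Path("results/demo"), 1)), indent=2, default=default))
```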
@@ -210,6 +221,7 @@ def objective(trial):
val_loader,
n_epoch=train_option.n_epoch,
result_dir=result_dir,
+debug=False,
)


@@ -218,7 +230,7 @@ def objective(trial):

study = optuna.create_study(
study_name="tuning",
storage="sqlite:///result/tuning.db",
storage="sqlite:///results/tuning.db",
load_if_exists=True,
)
study.optimize(objective, n_trials=100)
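With the SQLite URL moved under `results/`, trials persist across runs and `load_if_exists=True` reopens the same study. A short sketch of reopening that persisted study to inspect it; the explicit directory creation is an assumption about setup (SQLite creates the `.db` file but not missing folders), not something this commit adds:

```python
from pathlib import Path

import optuna

# Ensure the parent directory exists before the engine opens the database file.
Path("results").mkdir(parents=True, exist_ok=True)

study = optuna.create_study(
    study_name="tuning",
    storage="sqlite:///results/tuning.db",
    load_if_exists=True,
)
print(f"{len(study.trials)} trials recorded so far")
```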
