Skip to content

Commit

Permalink
fix tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 15, 2024
1 parent fca6473 commit 77a7bf8
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tuning/mmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import asdict
from pathlib import Path
from typing import Any
from uuid import uuid4

from hrdae.option import TrainExpOption
from hrdae.dataloaders.transforms import MinMaxNormalizationOption
Expand Down Expand Up @@ -75,6 +76,10 @@ def objective(trial):
else:
network_name = trial.suggest_categorical("network", ["hrdae2d", "rae2d", "rdae2d"])
if hasattr(args, "motion_encoder_name") and args.motion_encoder_name in ["conv2d", "guided1d", "normal1d", "rnn1d", "tsn1d"]:
if phase == "all":
pred_diff = False
else:
pred_diff = trial.suggest_categorical("pred_diff", [True, False])
motion_encoder_name = args.motion_encoder_name
elif phase == "all":
pred_diff = False
Expand Down Expand Up @@ -204,7 +209,7 @@ def objective(trial):
scheduler=scheduler_option,
)

result_dir = Path(f"results/tuning/mmnist/{network_name}/{motion_encoder_name}/{trial.number}")
result_dir = Path(f"results/tuning/mmnist/{network_name}/{motion_encoder_name}/{trial.number}-{uuid4()[:8]}")
train_option = TrainExpOption(
result_dir=result_dir,
dataloader=dataloader_option,
Expand Down

0 comments on commit 77a7bf8

Please sign in to comment.