Skip to content

Commit

Permalink
modify tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 15, 2024
1 parent 976f235 commit fca6473
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
2 changes: 1 addition & 1 deletion hrdae/conf/experiment/model/vr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ pred_diff: false
defaults:
- /config/experiment/model/vr@_here_
- [email protected]: wmse
- network: hrdae2d
- network: hrdae3d
- optimizer: adam
- scheduler: onecyclelr
30 changes: 23 additions & 7 deletions tuning.py → tuning/mmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,24 @@ def objective(trial):
}

phase = trial.suggest_categorical("phase", ["0", "t", "all"])
network_name = trial.suggest_categorical("network", ["hrdae2d", "rae2d", "rdae2d"])
if phase == "all":
if hasattr(args, "network_name") and args.network_name in ["hrdae2d", "rae2d", "rdae2d"]:
network_name = args.network_name
else:
network_name = trial.suggest_categorical("network", ["hrdae2d", "rae2d", "rdae2d"])
if hasattr(args, "motion_encoder_name") and args.motion_encoder_name in ["conv2d", "guided1d", "normal1d", "rnn1d", "tsn1d"]:
motion_encoder_name = args.motion_encoder_name
elif phase == "all":
pred_diff = False
motion_encoder_name = trial.suggest_categorical("motion_encoder_all", ["conv2d", "normal1d", "rnn1d"])
else:
pred_diff = trial.suggest_categorical("pred_diff", [True, False])
motion_encoder_name = trial.suggest_categorical("motion_encoder", ["conv2d", "guided1d", "normal1d", "rnn1d", "tsn1d"])
motion_encoder_num_layers = trial.suggest_int("motion_encoder_num_layers", 0, 3)
if motion_encoder_name == "rnn1d":
rnn_name = trial.suggest_categorical("rnn", ["conv_lstm1d", "gru1d", "tcn1d"])
if hasattr(args, "rnn_name") and args.rnn_name in ["conv_lstm1d", "gru1d", "tcn1d"]:
rnn_name = args.rnn_name
else:
rnn_name = trial.suggest_categorical("rnn", ["conv_lstm1d", "gru1d", "tcn1d"])
motion_encoder_name = f"{motion_encoder_name}/{rnn_name}"
if rnn_name == "conv_lstm1d":
rnn_option = ConvLSTM1dOption(
Expand Down Expand Up @@ -201,7 +209,7 @@ def objective(trial):
result_dir=result_dir,
dataloader=dataloader_option,
model=model_option,
n_epoch=1,
n_epoch=50,
)
result_dir.mkdir(parents=True, exist_ok=True)
with open(result_dir / "config.json", "w") as f:
Expand All @@ -226,15 +234,23 @@ def objective(trial):


if __name__ == "__main__":
import argparse

import optuna

parser = argparse.ArgumentParser()
parser.add_argument("--network_name", type=str)
parser.add_argument("--motion_encoder_name", type=str)
parser.add_argument("--rnn_name", type=str)
args = parser.parse_args()

study = optuna.create_study(
study_name="tuning",
storage="sqlite:///results/tuning.db",
study_name="mmnist",
storage="sqlite:///results/tuning/mmnist/sqlite.db",
load_if_exists=True,
)
study.optimize(objective, n_trials=100)
print(study.best_params)
print(study.best_value)
print(study.best_trial)
study.trials_dataframe().to_csv("result/tuning.csv")
study.trials_dataframe().to_csv("result/tuning/mmnist/trials.csv")

0 comments on commit fca6473

Please sign in to comment.