Skip to content

Commit

Permalink
fix tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 16, 2024
1 parent 4a6cf7f commit 87cfe49
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions tuning/mmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def objective(trial):
}

phase = trial.suggest_categorical("phase", ["0", "t", "all"])
if hasattr(args, "network_name") and args.network_name in [
if args.network_name in [
"hrdae2d",
"rae2d",
"rdae2d",
Expand All @@ -87,7 +87,7 @@ def objective(trial):
network_name = trial.suggest_categorical(
"network", ["hrdae2d", "rae2d", "rdae2d"]
)
if hasattr(args, "motion_encoder_name") and args.motion_encoder_name in [
if args.motion_encoder_name in [
"conv2d",
"guided1d",
"normal1d",
Expand All @@ -111,7 +111,7 @@ def objective(trial):
)
motion_encoder_num_layers = trial.suggest_int("motion_encoder_num_layers", 0, 3)
if motion_encoder_name == "rnn1d":
if hasattr(args, "rnn_name") and args.rnn_name in [
if args.rnn_name in [
"conv_lstm1d",
"gru1d",
"tcn1d",
Expand Down Expand Up @@ -294,11 +294,28 @@ def objective(trial):
args = parser.parse_args()

study_name = "mmnist"
if hasattr(args, "network_name"):
if args.network_name is not None:
assert args.network_name in [
"hrdae2d",
"rae2d",
"rdae2d",
]
study_name += f"_{args.network_name}"
if hasattr(args, "motion_encoder_name"):
if args.motion_encoder_name is not None:
assert args.motion_encoder_name in [
"conv2d",
"guided1d",
"normal1d",
"rnn1d",
"tsn1d",
]
study_name += f"_{args.motion_encoder_name}"
if hasattr(args, "rnn_name"):
if args.rnn_name is not None:
assert args.rnn_name in [
"conv_lstm1d",
"gru1d",
"tcn1d",
]
study_name += f"_{args.rnn_name}"
study = optuna.create_study(
study_name=study_name,
Expand Down

0 comments on commit 87cfe49

Please sign in to comment.