Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 16, 2024
1 parent 3f6cdc0 commit 4a6cf7f
Showing 1 changed file with 67 additions and 23 deletions.
90 changes: 67 additions & 23 deletions tuning/mmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from hrdae.models.schedulers import OneCycleLRSchedulerOption
from hrdae.models.networks import RAE2dOption, RDAE2dOption, HRDAE2dOption
from hrdae.models.networks.rnn import ConvLSTM1dOption, GRU1dOption, TCN1dOption
from hrdae.models.networks.motion_encoder import MotionRNNEncoder1dOption, MotionConv2dEncoder1dOption, MotionGuidedEncoder1dOption, MotionNormalEncoder1dOption, MotionTSNEncoder1dOption
from hrdae.models.networks.motion_encoder import (
MotionRNNEncoder1dOption,
MotionConv2dEncoder1dOption,
MotionGuidedEncoder1dOption,
MotionNormalEncoder1dOption,
MotionTSNEncoder1dOption,
)


def interleave_arrays(
Expand All @@ -30,7 +36,7 @@ def interleave_arrays(
result.append(an[i])
if i < len(an) - 1:
num_elements = per_slot + (1 if extra > 0 else 0)
result.extend(am[am_index:am_index + num_elements])
result.extend(am[am_index : am_index + num_elements])
am_index += num_elements
if extra > 0:
extra -= 1
Expand Down Expand Up @@ -71,28 +77,50 @@ def objective(trial):
}

phase = trial.suggest_categorical("phase", ["0", "t", "all"])
if hasattr(args, "network_name") and args.network_name in ["hrdae2d", "rae2d", "rdae2d"]:
if hasattr(args, "network_name") and args.network_name in [
"hrdae2d",
"rae2d",
"rdae2d",
]:
network_name = args.network_name
else:
network_name = trial.suggest_categorical("network", ["hrdae2d", "rae2d", "rdae2d"])
if hasattr(args, "motion_encoder_name") and args.motion_encoder_name in ["conv2d", "guided1d", "normal1d", "rnn1d", "tsn1d"]:
network_name = trial.suggest_categorical(
"network", ["hrdae2d", "rae2d", "rdae2d"]
)
if hasattr(args, "motion_encoder_name") and args.motion_encoder_name in [
"conv2d",
"guided1d",
"normal1d",
"rnn1d",
"tsn1d",
]:
if phase == "all":
pred_diff = False
else:
pred_diff = trial.suggest_categorical("pred_diff", [True, False])
motion_encoder_name = args.motion_encoder_name
elif phase == "all":
pred_diff = False
motion_encoder_name = trial.suggest_categorical("motion_encoder_all", ["conv2d", "normal1d", "rnn1d"])
motion_encoder_name = trial.suggest_categorical(
"motion_encoder_all", ["conv2d", "normal1d", "rnn1d"]
)
else:
pred_diff = trial.suggest_categorical("pred_diff", [True, False])
motion_encoder_name = trial.suggest_categorical("motion_encoder", ["conv2d", "guided1d", "normal1d", "rnn1d", "tsn1d"])
motion_encoder_name = trial.suggest_categorical(
"motion_encoder", ["conv2d", "guided1d", "normal1d", "rnn1d", "tsn1d"]
)
motion_encoder_num_layers = trial.suggest_int("motion_encoder_num_layers", 0, 3)
if motion_encoder_name == "rnn1d":
if hasattr(args, "rnn_name") and args.rnn_name in ["conv_lstm1d", "gru1d", "tcn1d"]:
if hasattr(args, "rnn_name") and args.rnn_name in [
"conv_lstm1d",
"gru1d",
"tcn1d",
]:
rnn_name = args.rnn_name
else:
rnn_name = trial.suggest_categorical("rnn", ["conv_lstm1d", "gru1d", "tcn1d"])
rnn_name = trial.suggest_categorical(
"rnn", ["conv_lstm1d", "gru1d", "tcn1d"]
)
motion_encoder_name = f"{motion_encoder_name}/{rnn_name}"
if rnn_name == "conv_lstm1d":
rnn_option = ConvLSTM1dOption(
Expand All @@ -116,7 +144,8 @@ def objective(trial):
in_channels=7,
conv_params=interleave_arrays(
[{"kernel_size": [3], "stride": [2], "padding": [1]}] * 3,
[{"kernel_size": [3], "stride": [1], "padding": [1]}] * motion_encoder_num_layers,
[{"kernel_size": [3], "stride": [1], "padding": [1]}]
* motion_encoder_num_layers,
),
rnn=rnn_option,
)
Expand All @@ -125,43 +154,50 @@ def objective(trial):
in_channels=7,
conv_params=interleave_arrays(
[{"kernel_size": [3], "stride": [1, 2], "padding": [1]}] * 3,
[{"kernel_size": [3], "stride": [1], "padding": [1]}] * motion_encoder_num_layers,
[{"kernel_size": [3], "stride": [1], "padding": [1]}]
* motion_encoder_num_layers,
),
)
elif motion_encoder_name == "guided1d":
motion_encoder_option = MotionGuidedEncoder1dOption(
in_channels=7,
conv_params=interleave_arrays(
[{"kernel_size": [3], "stride": [2], "padding": [1]}] * 3,
[{"kernel_size": [3], "stride": [1], "padding": [1]}] * motion_encoder_num_layers,
[{"kernel_size": [3], "stride": [1], "padding": [1]}]
* motion_encoder_num_layers,
),
)
elif motion_encoder_name == "normal1d":
motion_encoder_option = MotionNormalEncoder1dOption(
in_channels=7,
conv_params=interleave_arrays(
[{"kernel_size": [3], "stride": [2], "padding": [1]}] * 3,
[{"kernel_size": [3], "stride": [1], "padding": [1]}] * motion_encoder_num_layers,
),
[{"kernel_size": [3], "stride": [1], "padding": [1]}]
* motion_encoder_num_layers,
),
)
elif motion_encoder_name == "tsn1d":
motion_encoder_option = MotionTSNEncoder1dOption(
in_channels=7,
conv_params=interleave_arrays(
[{"kernel_size": [3], "stride": [2], "padding": [1]}] * 3,
[{"kernel_size": [3], "stride": [1], "padding": [1]}] * motion_encoder_num_layers,
[{"kernel_size": [3], "stride": [1], "padding": [1]}]
* motion_encoder_num_layers,
),
)
else:
raise RuntimeError("unreachable")
latent_dim = trial.suggest_int("latent_dim", 16, 64, step=8)
content_encoder_num_layers = trial.suggest_int("content_encoder_encoder_num_layers", 0, 3)
content_encoder_num_layers = trial.suggest_int(
"content_encoder_encoder_num_layers", 0, 3
)
if network_name == "rae2d":
network_option = RAE2dOption(
latent_dim=latent_dim,
conv_params=interleave_arrays(
[{"kernel_size": [3], "stride": [2], "padding": [1]}] * 3,
[{"kernel_size": [3], "stride": [1], "padding": [1]}] * content_encoder_num_layers,
[{"kernel_size": [3], "stride": [1], "padding": [1]}]
* content_encoder_num_layers,
),
motion_encoder=motion_encoder_option,
upsample_size=[8, 8],
Expand All @@ -171,22 +207,28 @@ def objective(trial):
latent_dim=latent_dim,
conv_params=interleave_arrays(
[{"kernel_size": [3], "stride": [2], "padding": [1]}] * 3,
[{"kernel_size": [3], "stride": [1], "padding": [1]}] * content_encoder_num_layers,
[{"kernel_size": [3], "stride": [1], "padding": [1]}]
* content_encoder_num_layers,
),
motion_encoder=motion_encoder_option,
upsample_size=[8, 8],
aggregation_method=trial.suggest_categorical("aggregation_method", ["concat", "sum"]),
aggregation_method=trial.suggest_categorical(
"aggregation_method", ["concat", "sum"]
),
)
elif network_name == "hrdae2d":
network_option = HRDAE2dOption(
latent_dim=latent_dim,
conv_params=interleave_arrays(
[{"kernel_size": [3], "stride": [2], "padding": [1]}] * 3,
[{"kernel_size": [3], "stride": [1], "padding": [1]}] * content_encoder_num_layers,
),
[{"kernel_size": [3], "stride": [1], "padding": [1]}]
* content_encoder_num_layers,
),
motion_encoder=motion_encoder_option,
upsample_size=[8, 8],
aggregation_method=trial.suggest_categorical("aggregation_method", ["concat", "sum"]),
aggregation_method=trial.suggest_categorical(
"aggregation_method", ["concat", "sum"]
),
)
else:
raise RuntimeError("unreachable")
Expand All @@ -209,7 +251,9 @@ def objective(trial):
scheduler=scheduler_option,
)

result_dir = Path(f"results/tuning/mmnist/{network_name}/{motion_encoder_name}/{trial.number}-{uuid4()[:8]}")
result_dir = Path(
f"results/tuning/mmnist/{network_name}/{motion_encoder_name}/{trial.number}-{str(uuid4())[:8]}"
)
train_option = TrainExpOption(
result_dir=result_dir,
dataloader=dataloader_option,
Expand Down

0 comments on commit 4a6cf7f

Please sign in to comment.