Skip to content

Commit

Permalink
tune tuning script
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 16, 2024
1 parent 92510b5 commit deabae5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ outputs/
results/
data/
playground/
logs/
scripts/

*.out
*.d000*

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
37 changes: 19 additions & 18 deletions tuning/mmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def objective(trial):
}

dataloader_option = BasicDataLoaderOption(
batch_size=96,
batch_size=256,
train_val_ratio=0.8,
dataset=dataset_option,
transform_order_train=["min_max_normalization"],
Expand Down Expand Up @@ -94,22 +94,22 @@ def objective(trial):
"rnn1d",
"tsn1d",
]:
if phase == "all":
pred_diff = False
else:
pred_diff = trial.suggest_categorical("pred_diff", [True, False])
# if phase == "all":
# pred_diff = False
# else:
# pred_diff = trial.suggest_categorical("pred_diff", [True, False])
motion_encoder_name = args.motion_encoder_name
elif phase == "all":
pred_diff = False
# pred_diff = False
motion_encoder_name = trial.suggest_categorical(
"motion_encoder_all", ["conv2d", "normal1d", "rnn1d"]
)
else:
pred_diff = trial.suggest_categorical("pred_diff", [True, False])
# pred_diff = trial.suggest_categorical("pred_diff", [True, False])
motion_encoder_name = trial.suggest_categorical(
"motion_encoder", ["conv2d", "guided1d", "normal1d", "rnn1d", "tsn1d"]
)
motion_encoder_num_layers = trial.suggest_int("motion_encoder_num_layers", 0, 3)
motion_encoder_num_layers = trial.suggest_int("motion_encoder_num_layers", 0, 6)
if motion_encoder_name == "rnn1d":
if args.rnn_name in [
"conv_lstm1d",
Expand All @@ -122,18 +122,19 @@ def objective(trial):
"rnn", ["conv_lstm1d", "gru1d", "tcn1d"]
)
motion_encoder_name = f"{motion_encoder_name}/{rnn_name}"
rnn_num_layers = trial.suggest_int("rnn_num_layers", 1, 5)
if rnn_name == "conv_lstm1d":
rnn_option = ConvLSTM1dOption(
num_layers=trial.suggest_int("rnn_num_layers", 2, 4),
num_layers=rnn_num_layers,
)
elif rnn_name == "gru1d":
rnn_option = GRU1dOption(
num_layers=trial.suggest_int("rnn_num_layers", 2, 4),
num_layers=rnn_num_layers,
image_size=8,
)
elif rnn_name == "tcn1d":
rnn_option = TCN1dOption(
num_layers=trial.suggest_int("rnn_num_layers", 2, 4),
num_layers=rnn_num_layers,
image_size=8,
kernel_size=3,
dropout=0.1,
Expand Down Expand Up @@ -187,9 +188,9 @@ def objective(trial):
)
else:
raise RuntimeError("unreachable")
latent_dim = trial.suggest_int("latent_dim", 16, 64, step=8)
latent_dim = trial.suggest_int("latent_dim", 16, 96, step=8)
content_encoder_num_layers = trial.suggest_int(
"content_encoder_encoder_num_layers", 0, 3
"content_encoder_num_layers", 0, 4
)
if network_name == "rae2d":
network_option = RAE2dOption(
Expand Down Expand Up @@ -234,17 +235,17 @@ def objective(trial):
raise RuntimeError("unreachable")

optimizer_option = AdamOptimizerOption(
lr=trial.suggest_float("lr", 1e-5, 1e-2, log=True),
lr=trial.suggest_float("lr", 1e-5, 5e-2, log=True),
)

scheduler_option = OneCycleLRSchedulerOption(
max_lr=trial.suggest_float("max_lr", 1e-3, 1e-2, log=True),
max_lr=trial.suggest_float("max_lr", 1e-3, 5e-2, log=True),
)

model_option = VRModelOption(
loss_coef={"wmse": 1.0},
phase=phase,
pred_diff=pred_diff,
pred_diff=False,
loss=loss_option,
network=network_option,
optimizer=optimizer_option,
Expand All @@ -258,7 +259,7 @@ def objective(trial):
result_dir=result_dir,
dataloader=dataloader_option,
model=model_option,
n_epoch=50,
n_epoch=150,
)
result_dir.mkdir(parents=True, exist_ok=True)
with open(result_dir / "config.json", "w") as f:
Expand Down Expand Up @@ -288,7 +289,7 @@ def objective(trial):
import optuna

parser = argparse.ArgumentParser()
parser.add_argument("--network_name", type=str)
parser.add_argument("--network_name", type=str, default="hrdae2d")
parser.add_argument("--motion_encoder_name", type=str)
parser.add_argument("--rnn_name", type=str)
args = parser.parse_args()
Expand Down

0 comments on commit deabae5

Please sign in to comment.