Skip to content

Commit

Permalink
mmnist -> ct
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 16, 2024
1 parent bb219d7 commit 964809c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tuning/ct.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def objective(trial):
)

result_dir = Path(
f"results/tuning/mmnist/{network_name}/{motion_encoder_name}/{trial.number}-{str(uuid4())[:8]}"
f"results/tuning/ct/{network_name}/{motion_encoder_name}/{trial.number}-{str(uuid4())[:8]}"
)
train_option = TrainExpOption(
result_dir=result_dir,
Expand Down Expand Up @@ -319,7 +319,7 @@ def objective(trial):
parser.add_argument("--rnn_name", type=str)
args = parser.parse_args()

study_name = "mmnist"
study_name = "ct"
if args.network_name is not None:
assert args.network_name in [
"hrdae3d",
Expand All @@ -345,11 +345,11 @@ def objective(trial):
study_name += f"_{args.rnn_name}"
study = optuna.create_study(
study_name=study_name,
storage="sqlite:///results/tuning/mmnist/sqlite.db",
storage="sqlite:///results/tuning/ct/sqlite.db",
load_if_exists=True,
)
study.optimize(objective, n_trials=500)
print(study.best_params)
print(study.best_value)
print(study.best_trial)
study.trials_dataframe().to_csv("results/tuning/mmnist/trials.csv")
study.trials_dataframe().to_csv("results/tuning/ct/trials.csv")

0 comments on commit 964809c

Please sign in to comment.