-
Notifications
You must be signed in to change notification settings - Fork 0
/
fairtabddpm_opt.py
101 lines (85 loc) · 2.97 KB
/
fairtabddpm_opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import optuna
import shutil
import warnings
import argparse
import subprocess
import lib
from constant import EXPS_PATH
import numpy as np
warnings.filterwarnings('ignore')
def main():
    """Search FairTabDDPM hyperparameters with Optuna, then rerun the best trial.

    Command-line flags:
        --dataset   Dataset name; selects ``./args/<dataset>/fairtabddpm/config.toml``
                    as the base config (default ``adult``).
        --n_trials  Number of Optuna trials to run (default 20).
        --gpu_id    CUDA device index written into every trial config (default 0).

    For each trial the base config is reloaded, the sampled ``lr``,
    ``n_epochs`` and ``n_timesteps`` are written into it, and
    ``fairtabddpm_run.py`` is launched on the result. The objective that is
    maximized is the mean of ``report['CatBoost']['AUC']['Train']`` read from
    ``metric.json`` in the trial directory.

    After the search, the best trial's config and the parameter importances
    are stored under ``EXPS_PATH/<dataset>/fairtabddpm/best``, the run script
    is launched once more on that config, and the shared ``many-exps``
    scratch directory is deleted.

    Raises:
        subprocess.CalledProcessError: if a launched run exits with a
            non-zero status (``check=True``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset', type=str, default='adult')
    cli.add_argument('--n_trials', type=int, default=20)
    cli.add_argument('--gpu_id', type=int, default=0)
    args = cli.parse_args()

    dataset = args.dataset
    trial_budget = args.n_trials

    # Fixed sampler seed so the sequence of suggested values is repeatable.
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=0),
    )

    config_template_path = f'./args/{dataset}/fairtabddpm/config.toml'

    def objective(trial):
        """Run one training job with sampled hyperparameters and return its score."""
        # Suggestion order is kept stable: lr, n_epochs, n_timesteps.
        sampled_lr = trial.suggest_float('lr', 1e-5, 0.003, log=True)
        sampled_epochs = trial.suggest_categorical('n_epochs', [100, 500, 1000])
        sampled_timesteps = trial.suggest_categorical('n_timesteps', [100, 1000])

        # Reload the base config on every trial so each one starts from the file.
        config = lib.load_config(config_template_path)

        # Every trial reuses the same scratch experiment name/directory.
        run_name = 'many-exps'
        trial_dir = os.path.join(
            config['exp']['home'],
            config['data']['name'],
            config['exp']['method'],
            run_name,
        )
        os.makedirs(trial_dir, exist_ok=True)

        config['train']['lr'] = sampled_lr
        config['train']['n_epochs'] = sampled_epochs
        config['model']['n_timesteps'] = sampled_timesteps
        config['exp']['device'] = f'cuda:{args.gpu_id}'

        # Keep the full config on the trial so the best one can be restored later.
        trial.set_user_attr('config', config)

        trial_config_path = f'{trial_dir}/config.toml'
        lib.write_config(config, trial_config_path)

        launch_cmd = [
            'python3.10',
            'fairtabddpm_run.py',
            '--config',
            trial_config_path,
            '--exp_name',
            run_name,
        ]
        subprocess.run(launch_cmd, check=True)

        # NOTE(review): presumably fairtabddpm_run.py writes metric.json into
        # trial_dir -- confirm against the run script.
        report = lib.load_json(f'{trial_dir}/metric.json')
        return np.mean(report['CatBoost']['AUC']['Train'])

    study.optimize(objective, n_trials=trial_budget, show_progress_bar=True)

    # Persist the winning config and the parameter importances.
    best_dir = os.path.join(EXPS_PATH, dataset, 'fairtabddpm', 'best')
    os.makedirs(best_dir, exist_ok=True)
    best_config_path = os.path.join(best_dir, 'config.toml')
    winning_config = study.best_trial.user_attrs['config']
    lib.write_config(winning_config, best_config_path)

    importances = optuna.importance.get_param_importances(study)
    lib.write_json(importances, os.path.join(best_dir, 'importance.json'))

    # Final run with the best hyperparameters under the 'best' experiment name.
    final_cmd = [
        'python3.10',
        'fairtabddpm_run.py',
        '--exp_name',
        'best',
        '--config',
        best_config_path,
    ]
    subprocess.run(final_cmd, check=True)

    # NOTE(review): cleanup path is built from EXPS_PATH/dataset/'fairtabddpm',
    # while trials write under the config's exp.home/data.name/exp.method --
    # this assumes both resolve to the same directory; confirm.
    scratch_dir = os.path.join(EXPS_PATH, dataset, 'fairtabddpm', 'many-exps')
    shutil.rmtree(scratch_dir)
# Script entry point: start the hyperparameter search only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    main()