-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path99_tuning.py
67 lines (57 loc) · 2.18 KB
/
99_tuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import logging
import optuna
import pandas as pd
import xgboost as xgb
import yaml
from sklearn.metrics import roc_auc_score
class Tuning:
    """Search XGBoost classifier hyperparameters with Optuna.

    Every trial fits an ``xgb.XGBClassifier`` on the training split, passes the
    validation split as ``eval_set`` (needed by ``early_stopping_rounds``), and
    is scored by the ROC AUC of ``predict_proba(...)[:, 1]`` on that same
    validation split. The study maximizes this score.

    Args:
        n_trials: Number of Optuna trials run by :meth:`compute`.
        X_train: Training features handed to ``XGBClassifier.fit``.
        y_train: Training labels handed to ``XGBClassifier.fit``.
        X_val: Validation features (early-stopping set and scoring set).
        y_val: Validation labels (early-stopping set and scoring set).
        random_state: Seed shared by the classifier and the Optuna sampler so
            that repeated runs explore the same trials. Defaults to 42, the
            value that was previously hard-coded for the classifier. Pass
            ``None`` to leave both unseeded.
    """

    def __init__(self, n_trials, X_train, y_train, X_val, y_val, *, random_state=42):
        self.n_trials = n_trials
        self.X_train = X_train
        self.y_train = y_train
        self.X_val = X_val
        self.y_val = y_val
        self.random_state = random_state

    def objective(self, trial) -> float:
        """Train one candidate model and return its validation ROC AUC.

        Args:
            trial: Optuna trial used to sample the hyperparameters.

        Returns:
            ROC AUC of the second ``predict_proba`` column on the validation
            split, as a plain ``float``.
        """
        params = {
            "n_estimators": trial.suggest_int("n_estimators", 30, 200),
            "early_stopping_rounds": trial.suggest_int("early_stopping_rounds", 5, 20),
            "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.1, log=True),
            "max_depth": trial.suggest_int("max_depth", 1, 10),
            "subsample": trial.suggest_float("subsample", 0.05, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.05, 1.0),
            "min_child_weight": trial.suggest_int("min_child_weight", 1, 20),
        }
        model = xgb.XGBClassifier(random_state=self.random_state, **params)
        # NOTE(review): no eval_metric is set, so early stopping presumably
        # follows XGBoost's default metric rather than ROC AUC - confirm that
        # this mismatch with the study objective is intended.
        # NOTE(review): the validation split drives both early stopping and
        # the reported score, so the score may be optimistic.
        model.fit(
            self.X_train,
            self.y_train,
            eval_set=[(self.X_val, self.y_val)],
            verbose=False,
        )
        # Column 1 holds the probability of the second class.
        y_score = model.predict_proba(self.X_val)[:, 1]
        return float(roc_auc_score(self.y_val, y_score))

    def compute(self) -> dict:
        """Run the study and return the best trial's hyperparameters.

        Returns:
            ``study.best_params`` after ``n_trials`` trials of
            :meth:`objective`.
        """
        # TPE is Optuna's default sampler; building it explicitly only adds
        # the seed, which makes the sequence of sampled trials repeatable.
        sampler = optuna.samplers.TPESampler(seed=self.random_state)
        study = optuna.create_study(direction="maximize", sampler=sampler)
        study.optimize(self.objective, n_trials=self.n_trials)
        return study.best_params
def main() -> None:
    """Tune the classifier on the stored dataset splits and save the result.

    Reads the four parquet splits under ``data/dataset/``, runs a 50-trial
    ``Tuning`` search, adds ``random_state`` 42 to the winning parameters and
    writes them as YAML to ``params.yaml`` in the working directory.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    logging.info("Start")

    # Listed in the positional order expected by Tuning.
    split_paths = (
        "data/dataset/X_train.parquet",
        "data/dataset/y_train.parquet",
        "data/dataset/X_val.parquet",
        "data/dataset/y_val.parquet",
    )
    X_train, y_train, X_val, y_val = map(pd.read_parquet, split_paths)

    logging.info("Tuning")
    best_params = Tuning(50, X_train, y_train, X_val, y_val).compute()
    best_params["random_state"] = 42
    logging.info(f"{best_params=}")

    with open("params.yaml", "w") as params_file:
        yaml.dump(best_params, params_file)
    logging.info("End")


if __name__ == "__main__":
    main()