-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathdo_xgboost.py
executable file
·70 lines (56 loc) · 1.65 KB
/
do_xgboost.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python
"""
Train a gradient boosting classifier on the airline dataset using
XGBoost's Python API.
"""
import argparse
import numpy as np
import pandas as pd
from IPython import embed
import pickle
import scipy
import xgboost as xgb
from xgboost import XGBClassifier
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier
from matplotlib import pyplot as plt
import seaborn as sns
# Parsed command-line options. Stays None until the __main__ block assigns
# the argparse namespace; train_and_predict reads its depth, learning_rate
# and num_trees attributes.
FLAGS = None
def train_and_predict(X_train, y_train, X_test, y_test, **kwargs):
    """Train an XGBoost classifier and save its test-set probabilities.

    Hyperparameters (maximum depth, learning rate, number of trees) are read
    from the module-level ``FLAGS`` namespace, so ``FLAGS`` must be populated
    before this function is called.

    Args:
        X_train: Training features, handed unchanged to ``XGBClassifier.fit``.
        y_train: Training labels, handed unchanged to ``XGBClassifier.fit``.
        X_test: Features scored with ``predict_proba``.
        y_test: Accepted so a whole dataset mapping can be unpacked into this
            call; it is not used here.
        **kwargs: Any further entries of the dataset mapping; ignored.

    Returns:
        None. The result is written to disk instead.

    Side effects:
        Saves column 1 of ``predict_proba(X_test)`` to
        ``outputs/pred_xgb_t<num_trees>_d<depth>.npy``. The ``outputs``
        directory is created if it does not exist.
    """
    # Function-scope import keeps this block self-contained; the file's
    # import section does not bring in ``os``.
    import os

    # Make sure the destination exists *before* training: np.save does not
    # create missing directories, and failing after the fit would throw away
    # all of the training work.
    os.makedirs('outputs', exist_ok=True)
    out_path = 'outputs/pred_xgb_t{:03d}_d{:02d}.npy'.format(
        FLAGS.num_trees, FLAGS.depth)

    bst = XGBClassifier(
        max_depth=FLAGS.depth,
        learning_rate=FLAGS.learning_rate,
        n_estimators=FLAGS.num_trees,
        silent=False,
        objective='binary:logistic',
        nthread=-1,
        seed=42,
    )
    bst.fit(X_train, y_train)
    # pickle.dump(bst, open('xgboost.pickle', 'wb'))

    # Keep only column 1 of the per-class probability matrix.
    y_pred = bst.predict_proba(X_test)[:, 1]

    # Save predictions
    np.save(out_path, y_pred)
if __name__ == '__main__':
    # Script entry point: parse the hyperparameters, publish them through the
    # module-level FLAGS (read by train_and_predict), then train on the
    # arrays stored in airlines_data.npz.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_trees",
        type=int,
        default=10,
        help="Number of trees to grow before stopping.")
    parser.add_argument(
        "--depth",
        type=int,
        default=6,
        help="Maximum depth of weak learners.")
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=0.1,
        help="Learning rate (shrinkage weight) with which each new tree is added.")
    FLAGS = parser.parse_args()

    # np.load on an .npz archive returns a lazily-read NpzFile that holds the
    # underlying file open; the with-block guarantees it is closed once the
    # run finishes (or fails). Every array in the archive is passed as a
    # keyword argument named after its key.
    with np.load('airlines_data.npz') as data:
        train_and_predict(**data)