Commit ad49896

add accuracy eval_metric (#362)
1 parent 8f6d7c1 commit ad49896

8 files changed: +100 additions, -10 deletions

supervised/algorithms/catboost.py

Lines changed: 5 additions & 1 deletion
@@ -35,10 +35,12 @@ def catboost_eval_metric(ml_task, eval_metric):
             "logloss": "Logloss",
             "f1": "F1",
             "average_precision": "average_precision",
+            "accuracy": "Accuracy",
         },
         MULTICLASS_CLASSIFICATION: {
             "logloss": "MultiClass",
             "f1": "TotalF1:average=Micro",
+            "accuracy": "Accuracy",
         },
         REGRESSION: {
             "rmse": "RMSE",
@@ -245,7 +247,7 @@ def fit(
                 model_init.evals_result_["validation"].get(self.log_metric_name)
                 + validation_scores
             )
-            iteration = None
+        iteration = None
         if train_scores is not None:
             iteration = range(len(validation_scores))
         elif validation_scores is not None:
@@ -314,6 +316,8 @@ def get_metric_name(self):
             return "mape"
         elif metric in ["F1", "TotalF1:average=Micro"]:
             return "f1"
+        elif metric == "Accuracy":
+            return "accuracy"
         return metric
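CatBoost ships a native "Accuracy" eval metric, so the mapping here is a straight rename, with get_metric_name translating it back for logging. With it in place, the new option can be requested end to end; a minimal usage sketch (illustrative toy data, assuming the package's public AutoML entry point):

    import numpy as np
    from supervised import AutoML

    # toy binary-classification data, purely illustrative
    X = np.random.uniform(size=(100, 5))
    y = (X[:, 0] > 0.5).astype(int)

    automl = AutoML(eval_metric="accuracy")  # accepted after this commit
    automl.fit(X, y)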
supervised/algorithms/lightgbm.py

Lines changed: 17 additions & 2 deletions
@@ -20,6 +20,7 @@
     lightgbm_eval_metric_pearson,
     lightgbm_eval_metric_f1,
     lightgbm_eval_metric_average_precision,
+    lightgbm_eval_metric_accuracy,
 )
 from supervised.utils.config import LOG_LEVEL

@@ -45,8 +46,13 @@ def lightgbm_eval_metric(ml_task, automl_eval_metric):
             "logloss": "binary_logloss",
             "f1": "custom",
             "average_precision": "custom",
+            "accuracy": "custom",
+        },
+        MULTICLASS_CLASSIFICATION: {
+            "logloss": "multi_logloss",
+            "f1": "custom",
+            "accuracy": "custom",
         },
-        MULTICLASS_CLASSIFICATION: {"logloss": "multi_logloss", "f1": "custom"},
         REGRESSION: {
             "rmse": "rmse",
             "mae": "mae",
@@ -60,7 +66,14 @@ def lightgbm_eval_metric(ml_task, automl_eval_metric):
     metric = metric_name_mapping[ml_task][automl_eval_metric]
     custom_eval_metric = None

-    if automl_eval_metric in ["r2", "spearman", "pearson", "f1", "average_precision"]:
+    if automl_eval_metric in [
+        "r2",
+        "spearman",
+        "pearson",
+        "f1",
+        "average_precision",
+        "accuracy",
+    ]:
         custom_eval_metric = automl_eval_metric

     return metric, custom_eval_metric
@@ -133,6 +146,8 @@ def __init__(self, params):
             self.custom_eval_metric = lightgbm_eval_metric_f1
         elif self.params["custom_eval_metric_name"] == "average_precision":
             self.custom_eval_metric = lightgbm_eval_metric_average_precision
+        elif self.params["custom_eval_metric_name"] == "accuracy":
+            self.custom_eval_metric = lightgbm_eval_metric_accuracy

         logger.debug("LightgbmLearner __init__")
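Unlike CatBoost, LightGBM has no built-in accuracy metric, so the mapping returns "custom" and the learner hands LightGBM a custom eval callback (feval). A standalone sketch of that contract, separate from the project's wiring (toy data; binary objective assumed):

    import lightgbm as lgb
    import numpy as np
    from sklearn.metrics import accuracy_score

    def accuracy_feval(preds, dtrain):
        # binary case: preds are positive-class probabilities
        y_true = dtrain.get_label()
        y_pred = (preds > 0.5).astype(int)
        # a LightGBM feval returns (name, value, is_higher_better)
        return "accuracy", accuracy_score(y_true, y_pred), True

    X = np.random.uniform(size=(200, 4))
    y = (X[:, 0] > 0.5).astype(int)
    train = lgb.Dataset(X[:150], label=y[:150])
    valid = lgb.Dataset(X[150:], label=y[150:], reference=train)

    booster = lgb.train(
        {"objective": "binary", "metric": "custom"},
        train,
        valid_sets=[valid],
        feval=accuracy_feval,
    )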
supervised/algorithms/xgboost.py

Lines changed: 11 additions & 1 deletion
@@ -19,6 +19,7 @@
     xgboost_eval_metric_pearson,
     xgboost_eval_metric_f1,
     xgboost_eval_metric_average_precision,
+    xgboost_eval_metric_accuracy,
 )
 from supervised.utils.config import LOG_LEVEL

@@ -116,6 +117,8 @@ def __init__(self, params):
             self.custom_eval_metric = xgboost_eval_metric_f1
         elif self.params.get("eval_metric", "") == "average_precision":
             self.custom_eval_metric = xgboost_eval_metric_average_precision
+        elif self.params.get("eval_metric", "") == "accuracy":
+            self.custom_eval_metric = xgboost_eval_metric_accuracy

         self.best_ntree_limit = 0
         logger.debug("XgbLearner __init__")
@@ -217,7 +220,14 @@ def fit(
         # it is a custom metric
         # that is always minimized
         # we need to revert it
-        if metric_name in ["r2", "spearman", "pearson", "f1", "average_precision"]:
+        if metric_name in [
+            "r2",
+            "spearman",
+            "pearson",
+            "f1",
+            "average_precision",
+            "accuracy",
+        ]:
             result["train"] *= -1.0
             result["validation"] *= -1.0
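The sign flip at the end of fit exists because custom XGBoost eval metrics are treated as minimized (which is what early stopping keys on), so accuracy is reported to the booster negated and the logged train/validation curves are multiplied by -1 afterwards. The same pattern in isolation (toy data; legacy feval-style callback, as used by this codebase):

    import numpy as np
    import xgboost as xgb
    from sklearn.metrics import accuracy_score

    def neg_accuracy_feval(preds, dtrain):
        # report -accuracy so that minimizing it maximizes accuracy
        y_true = dtrain.get_label()
        y_pred = (preds > 0.5).astype(int)
        return "accuracy", -accuracy_score(y_true, y_pred)

    X = np.random.uniform(size=(200, 4))
    y = (X[:, 0] > 0.5).astype(int)
    dtrain = xgb.DMatrix(X[:150], label=y[:150])
    dvalid = xgb.DMatrix(X[150:], label=y[150:])

    evals_result = {}
    xgb.train(
        {"objective": "binary:logistic"},
        dtrain,
        evals=[(dvalid, "validation")],
        feval=neg_accuracy_feval,
        evals_result=evals_result,
        early_stopping_rounds=10,
    )
    # flip the sign back for display, as the learner does with result["validation"]
    accuracy_curve = -np.array(evals_result["validation"]["accuracy"])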
supervised/base_automl.py

Lines changed: 4 additions & 3 deletions
@@ -1780,18 +1780,19 @@ def _validate_eval_metric(self):
             "auc",
             "f1",
             "average_precision",
+            "accuracy"
         ]:
             raise ValueError(
                 f"Metric {self.eval_metric} is not allowed in ML task: {self._get_ml_task()}. \
-                Use 'logloss', 'auc', 'f1', or 'average_precision'"
+                Use 'logloss', 'auc', 'f1', 'average_precision', or 'accuracy'"
             )

         elif (
             self._get_ml_task() == MULTICLASS_CLASSIFICATION
-        ) and self.eval_metric not in ["logloss", "f1"]:
+        ) and self.eval_metric not in ["logloss", "f1", "accuracy"]:
             raise ValueError(
                 f"Metric {self.eval_metric} is not allowed in ML task: {self._get_ml_task()}. \
-                Use 'logloss', or 'f1'"
+                Use 'logloss', 'f1', or 'accuracy'"
             )

         elif self._get_ml_task() == REGRESSION and self.eval_metric not in [
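With the validator updated, "accuracy" passes for both classification tasks and the error messages list it among the allowed choices. Roughly (hypothetical calls, assuming validation fires when training parameters are checked):

    from supervised import AutoML

    AutoML(eval_metric="accuracy")  # now allowed for binary and multiclass tasks
    AutoML(eval_metric="mcc")       # still rejected: _validate_eval_metric raises ValueError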

supervised/tuner/optuna/lightgbm.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@
     lightgbm_eval_metric_pearson,
     lightgbm_eval_metric_f1,
     lightgbm_eval_metric_average_precision,
+    lightgbm_eval_metric_accuracy
 )
 from supervised.algorithms.registry import BINARY_CLASSIFICATION
 from supervised.algorithms.registry import MULTICLASS_CLASSIFICATION
@@ -82,6 +83,8 @@ def __init__(
             self.custom_eval_metric = lightgbm_eval_metric_f1
         elif self.eval_metric.name == "average_precision":
             self.custom_eval_metric = lightgbm_eval_metric_average_precision
+        elif self.eval_metric.name == "accuracy":
+            self.custom_eval_metric = lightgbm_eval_metric_accuracy

         self.num_class = (
             len(np.unique(y_train)) if ml_task == MULTICLASS_CLASSIFICATION else None

supervised/tuner/optuna/tuner.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ def __init__(
             "pearson",
             "f1",
             "average_precision",
+            "accuracy"
         ]:
             raise AutoMLException(f"Metric {eval_metric.name} is not supported")

supervised/tuner/optuna/xgboost.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@
     xgboost_eval_metric_pearson,
     xgboost_eval_metric_f1,
     xgboost_eval_metric_average_precision,
+    xgboost_eval_metric_accuracy
 )
 from supervised.algorithms.registry import BINARY_CLASSIFICATION
 from supervised.algorithms.registry import MULTICLASS_CLASSIFICATION
@@ -67,6 +68,8 @@ def __init__(
             self.custom_eval_metric = xgboost_eval_metric_f1
         elif self.eval_metric_name == "average_precision":
             self.custom_eval_metric = xgboost_eval_metric_average_precision
+        elif self.eval_metric_name == "accuracy":
+            self.custom_eval_metric = xgboost_eval_metric_accuracy

     def __call__(self, trial):
         param = {

supervised/utils/metric.py

Lines changed: 56 additions & 3 deletions
@@ -15,6 +15,7 @@
 from sklearn.metrics import mean_squared_log_error
 from sklearn.metrics import f1_score
 from sklearn.metrics import average_precision_score
+from sklearn.metrics import accuracy_score


 def logloss(y_true, y_predicted, sample_weight=None):
@@ -52,8 +53,11 @@ def negative_f1(y_true, y_predicted, sample_weight=None):
     if isinstance(y_predicted, pd.DataFrame):
         y_predicted = np.array(y_predicted)

+    if len(y_predicted.shape) == 2 and y_predicted.shape[1] == 1:
+        y_predicted = y_predicted.ravel()
+
     average = None
-    if len(y_predicted.shape) == 1:
+    if len(y_predicted.shape) == 1 or (len(y_predicted.shape) == 2 and y_predicted.shape[1] == 1):
         y_predicted = (y_predicted > 0.5).astype(int)
         average = "binary"
     else:
@@ -65,6 +69,26 @@ def negative_f1(y_true, y_predicted, sample_weight=None):
     return -val


+def negative_accuracy(y_true, y_predicted, sample_weight=None):
+
+    if isinstance(y_true, pd.DataFrame):
+        y_true = np.array(y_true)
+    if isinstance(y_predicted, pd.DataFrame):
+        y_predicted = np.array(y_predicted)
+
+    if len(y_predicted.shape) == 2 and y_predicted.shape[1] == 1:
+        y_predicted = y_predicted.ravel()
+
+    if len(y_predicted.shape) == 1:
+        y_predicted = (y_predicted > 0.5).astype(int)
+    else:
+        y_predicted = np.argmax(y_predicted, axis=1)
+
+    val = accuracy_score(y_true, y_predicted, sample_weight=sample_weight)
+
+    return -val
+
+
 def negative_average_precision(y_true, y_predicted, sample_weight=None):

     if isinstance(y_true, pd.DataFrame):
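negative_accuracy mirrors the input handling of negative_f1: a single prediction column is raveled, 1-D probabilities are thresholded at 0.5, a 2-D probability matrix is argmax-ed per row, and the score is negated because the framework minimizes every metric internally. A small worked check (made-up values):

    import numpy as np
    from supervised.utils.metric import negative_accuracy

    # binary: probabilities thresholded at 0.5 -> 2 of 3 predictions correct
    assert np.isclose(
        negative_accuracy(np.array([0, 1, 1]), np.array([0.2, 0.8, 0.4])), -2 / 3
    )

    # multiclass: row-wise argmax picks the class -> both rows correct
    y_prob = np.array([[0.9, 0.05, 0.05],
                       [0.1, 0.2, 0.7]])
    assert negative_accuracy(np.array([0, 2]), y_prob) == -1.0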
@@ -137,6 +161,15 @@ def xgboost_eval_metric_average_precision(preds, dtrain):
     return "average_precision", negative_average_precision(target, preds, weight)


+def xgboost_eval_metric_accuracy(preds, dtrain):
+    # Xgboost needs to minimize eval_metric
+    target = dtrain.get_label()
+    weight = dtrain.get_weight()
+    if len(weight) == 0:
+        weight = None
+    return "accuracy", negative_accuracy(target, preds, weight)
+
+
 def lightgbm_eval_metric_r2(preds, dtrain):
     target = dtrain.get_label()
     weight = dtrain.get_weight()
@@ -159,8 +192,10 @@ def lightgbm_eval_metric_f1(preds, dtrain):

     unique_targets = np.unique(target)
     if len(unique_targets) > 2:
-        preds = preds.reshape(-1, len(unique_targets))
-
+        cols = len(unique_targets)
+        rows = int(preds.shape[0] / len(unique_targets))
+        preds = np.reshape(preds, (rows, cols), order="F")
+
     return "f1", -negative_f1(target, preds, weight), True


@@ -171,6 +206,19 @@ def lightgbm_eval_metric_average_precision(preds, dtrain):
     return "average_precision", -negative_average_precision(target, preds, weight), True


+def lightgbm_eval_metric_accuracy(preds, dtrain):
+    target = dtrain.get_label()
+    weight = dtrain.get_weight()
+
+    unique_targets = np.unique(target)
+    if len(unique_targets) > 2:
+        cols = len(unique_targets)
+        rows = int(preds.shape[0] / len(unique_targets))
+        preds = np.reshape(preds, (rows, cols), order="F")
+
+    return "accuracy", -negative_accuracy(target, preds, weight), True
+
+
 class CatBoostEvalMetricSpearman(object):
     def get_final_error(self, error, weight):
         return error
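The order="F" reshape in both LightGBM callbacks (also fixing the earlier row-major reshape(-1, n) in the f1 metric) matches how LightGBM hands multiclass raw scores to a feval in the versions this code targets: one flat array grouped by class, all rows for class 0 first, then all rows for class 1, and so on. Column-major reshaping recovers the (rows, classes) matrix; a tiny illustration:

    import numpy as np

    # flat multiclass preds for 2 rows x 3 classes, grouped by class:
    # [row0_c0, row1_c0, row0_c1, row1_c1, row0_c2, row1_c2]
    preds = np.array([0.7, 0.1, 0.2, 0.8, 0.1, 0.1])
    rows, cols = 2, 3

    wrong = preds.reshape(rows, cols)                   # row-major scrambles the rows
    right = np.reshape(preds, (rows, cols), order="F")  # column-major is correct

    print(np.argmax(right, axis=1))  # [0 1] -- the intended per-row classes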
@@ -244,6 +292,7 @@ def __init__(self, params):
             "pearson",  # negative
             "f1",  # negative
             "average_precision",  # negative
+            "accuracy",  # negative
         ]
         if self.name == "logloss":
             self.metric = logloss
@@ -269,6 +318,8 @@ def __init__(self, params):
             self.metric = negative_f1
         elif self.name == "average_precision":
             self.metric = negative_average_precision
+        elif self.name == "accuracy":
+            self.metric = negative_accuracy
         # elif self.name == "rmsle":  # need to update target preprocessing
         #     self.metric = rmsle  # to assure that target is not negative ...
         else:
@@ -304,6 +355,7 @@ def is_negative(self):
             "pearson",
             "f1",
             "average_precision",
+            "accuracy",
         ]

     @staticmethod
@@ -315,4 +367,5 @@ def optimize_negative(metric_name):
             "pearson",
             "f1",
             "average_precision",
+            "accuracy",
         ]

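With those registrations, accuracy behaves like the other negative metrics throughout the Metric wrapper: computed values come back negated, is_negative and optimize_negative report it, and consumers flip the sign for display. A hedged sketch of the expected behavior (assuming a Metric instance is callable on (y_true, y_predicted), as the module's other metrics are used):

    import numpy as np
    from supervised.utils.metric import Metric

    m = Metric({"name": "accuracy"})
    y_true = np.array([0, 1, 1, 0])
    y_prob = np.array([0.1, 0.9, 0.8, 0.4])

    print(m(y_true, y_prob))                     # -1.0: all four thresholded predictions correct
    print(Metric.optimize_negative("accuracy"))  # True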