diff --git a/supervised/utils/metric.py b/supervised/utils/metric.py index 0d1d3baa..e3ca22fd 100644 --- a/supervised/utils/metric.py +++ b/supervised/utils/metric.py @@ -57,7 +57,7 @@ def negative_f1(y_true, y_predicted, sample_weight=None): y_predicted = y_predicted.ravel() average = None - if len(y_predicted.shape) == 1 or (len(y_predicted.shape) == 2 and y_predicted.shape[1] == 1): + if len(y_predicted.shape) == 1: y_predicted = (y_predicted > 0.5).astype(int) average = "binary" else: