Skip to content

Commit

Permalink
Fix handling single class in chunk for CBPE (#384)
Browse files Browse the repository at this point in the history
* Fix handling single class in CBPE fitting

The `confusion_matrix` function used in various CBPE metrics returns
values for each class/label present in the input. For binary
classification this means we expect 4 values (TN, FP, FN, TP). However,
if only one class is represented in the input, the function will only
return one value, so unpacking the result into four variables fails.

This commit addresses that failure case by explicitly providing the
expected labels to the `confusion_matrix` function. Currently these
values are hard-coded for binary classification, but we may want to
derive them from the input later on if we were to support string-based
pass/fail classes.

* Add test case for single class CBPE fitting

* Fix F1 sampling error when no positive cases
  • Loading branch information
michael-nml authored May 2, 2024
1 parent 13ace29 commit d064916
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 6 deletions.
29 changes: 23 additions & 6 deletions nannyml/performance_estimation/confidence_based/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,9 @@ def __init__(
# sampling error
self._sampling_error_components: Tuple = ()

# Set labels expected in y_true/y_pred. Currently hard-coded to 0, 1 for binary classification
self._labels = [0, 1]

def _fit(self, reference_data: pd.DataFrame):
self._sampling_error_components = bse.specificity_sampling_error_components(
y_true_reference=reference_data[self.y_true],
Expand Down Expand Up @@ -1039,7 +1042,7 @@ def _realized_performance(self, data: pd.DataFrame) -> float:
warnings.warn(f"Not enough data to compute estimated {self.display_name}.")
return np.NaN
y_pred, y_true = _dat
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=self._labels).ravel()
denominator = tn + fp
if denominator == 0:
return np.NaN
Expand Down Expand Up @@ -1213,6 +1216,9 @@ def __init__(
self.false_negative_lower_threshold: Optional[float] = None
self.false_negative_upper_threshold: Optional[float] = None

# Set labels expected in y_true/y_pred. Currently hard-coded to 0, 1 for binary classification
self._labels = [0, 1]

def fit(self, reference_data: pd.DataFrame): # override the superclass fit method
"""Fits a Metric on reference data.
Expand Down Expand Up @@ -1348,7 +1354,9 @@ def _true_positive_realized_performance(self, data: pd.DataFrame) -> float:
return np.NaN
y_pred, y_true = _dat

_, _, _, tp = confusion_matrix(y_true, y_pred, normalize=self.normalize_confusion_matrix).ravel()
_, _, _, tp = confusion_matrix(
y_true, y_pred, labels=self._labels, normalize=self.normalize_confusion_matrix
).ravel()
return tp

def _true_negative_realized_performance(self, data: pd.DataFrame) -> float:
Expand All @@ -1368,7 +1376,9 @@ def _true_negative_realized_performance(self, data: pd.DataFrame) -> float:

y_pred, y_true = _dat

tn, _, _, _ = confusion_matrix(y_true, y_pred, normalize=self.normalize_confusion_matrix).ravel()
tn, _, _, _ = confusion_matrix(
y_true, y_pred, labels=self._labels, normalize=self.normalize_confusion_matrix
).ravel()
return tn

def _false_positive_realized_performance(self, data: pd.DataFrame) -> float:
Expand All @@ -1387,7 +1397,9 @@ def _false_positive_realized_performance(self, data: pd.DataFrame) -> float:
return np.NaN
y_pred, y_true = _dat

_, fp, _, _ = confusion_matrix(y_true, y_pred, normalize=self.normalize_confusion_matrix).ravel()
_, fp, _, _ = confusion_matrix(
y_true, y_pred, labels=self._labels, normalize=self.normalize_confusion_matrix
).ravel()
return fp

def _false_negative_realized_performance(self, data: pd.DataFrame) -> float:
Expand All @@ -1406,7 +1418,9 @@ def _false_negative_realized_performance(self, data: pd.DataFrame) -> float:
return np.NaN
y_pred, y_true = _dat

_, _, fn, _ = confusion_matrix(y_true, y_pred, normalize=self.normalize_confusion_matrix).ravel()
_, _, fn, _ = confusion_matrix(
y_true, y_pred, labels=self._labels, normalize=self.normalize_confusion_matrix
).ravel()
return fn

def get_true_positive_estimate(self, chunk_data: pd.DataFrame) -> float:
Expand Down Expand Up @@ -1907,6 +1921,9 @@ def __init__(
self.business_value_matrix = business_value_matrix
self.normalize_business_value: Optional[str] = normalize_business_value

# Set labels expected in y_true/y_pred. Currently hard-coded to 0, 1 for binary classification
self._labels = [0, 1]

# self.lower_threshold: Optional[float] = 0
# self.upper_threshold: Optional[float] = 1

Expand Down Expand Up @@ -1940,7 +1957,7 @@ def _realized_performance(self, data: pd.DataFrame) -> float:
fn_value = self.business_value_matrix[1, 0]
bv_array = np.array([[tn_value, fp_value], [fn_value, tp_value]])

cm = confusion_matrix(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred, labels=self._labels)
if self.normalize_business_value == 'per_prediction':
with np.errstate(all="ignore"):
cm = cm / cm.sum(axis=0, keepdims=True)
Expand Down
4 changes: 4 additions & 0 deletions nannyml/sampling_error/binary_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ def f1_sampling_error_components(y_true_reference: pd.Series, y_pred_reference:

tp_fp_fn = np.concatenate([TP, FN, FP])

# If there are no true positives, false negatives or false positives, sampling error is NaN
if tp_fp_fn.size == 0:
return np.nan, 0

correcting_factor = len(tp_fp_fn) / ((len(FN) + len(FP)) * 0.5 + len(TP))
obs_level_f1 = tp_fp_fn * correcting_factor
fraction_of_relevant = len(tp_fp_fn) / len(y_pred_reference)
Expand Down
28 changes: 28 additions & 0 deletions tests/performance_estimation/CBPE/test_cbpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,3 +721,31 @@ def test_cbpe_without_predictions():
_ = cbpe.estimate(ana_df)
except Exception as exc:
pytest.fail(f'unexpected exception: {exc}')


@pytest.mark.filterwarnings("ignore:Too few unique values", "ignore:'y_true' contains a single class")
def test_cbpe_fitting_does_not_generate_error_when_single_class_present():
    """Fit CBPE on reference data that contains only class 0.

    Every row has ``y_true == 0`` and ``y_pred == 0``, so the reference set
    represents a single class. The test passes as long as ``fit`` completes
    without raising; no assertions are made on the fitted state.
    """
    n_rows = 1000
    single_class_reference = pd.DataFrame(
        {
            'y_true': [0] * n_rows,
            'y_pred': [0] * n_rows,
            'y_pred_proba': [0.5] * n_rows,
        }
    )

    # Exercise every binary-classification metric requested by the original test.
    requested_metrics = [
        'roc_auc',
        'f1',
        'precision',
        'recall',
        'specificity',
        'accuracy',
        'confusion_matrix',
        'business_value',
    ]

    estimator = CBPE(
        y_true='y_true',
        y_pred='y_pred',
        y_pred_proba='y_pred_proba',
        problem_type='classification_binary',
        metrics=requested_metrics,
        chunk_size=100,
        business_value_matrix=[[1, -1], [-1, 1]],
    )
    estimator.fit(single_class_reference)

0 comments on commit d064916

Please sign in to comment.