Make prediction column optional for performance estimation / calculation. (#380)

* Make y_pred column optional for estimated and realized performance + tests

* Update data requirements docs

* Fix _list_missing issues + add "run with prediction" tests

* Fix flake8 issues
nnansters authored Apr 29, 2024
1 parent e98cab9 commit a52c06d
Showing 8 changed files with 197 additions and 13 deletions.
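
The headline change: for binary classification, the prediction column (y_pred) can now be omitted as long as only metrics that work purely on predicted probabilities are requested. A minimal usage sketch of what this enables, mirroring the new test_cbpe_without_predictions test added below (column names come from NannyML's synthetic binary classification dataset):

    import nannyml as nml

    # Reference and analysis sets from the bundled synthetic dataset.
    reference, analysis, _ = nml.load_synthetic_binary_classification_dataset()

    # Note: no y_pred argument is passed at all.
    estimator = nml.CBPE(
        y_pred_proba='y_pred_proba',
        y_true='work_home_actual',
        problem_type='classification_binary',
        metrics=['roc_auc', 'average_precision'],
        timestamp_column_name='timestamp',
        chunk_period='M',
    ).fit(reference)

    results = estimator.estimate(analysis)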
5 changes: 3 additions & 2 deletions docs/tutorials/data_requirements.rst
@@ -170,7 +170,7 @@ The :term:`predicted label<Predicted labels>`, retrieved by interpreting (thresh
In the sample data this is the **y_pred** column.

Required for running :ref:`performance estimation<performance-estimation>` or :ref:`performance calculation<performance-calculation>` on binary classification, multiclass, and regression models.

For binary classification models it is not required for calculating the **AUROC** and **average precision** metrics.

NannyML Functionality Requirements
----------------------------------
@@ -190,7 +190,8 @@ You can see those requirements in the table below:
| y_pred_proba | Required (reference and analysis) | | | | | | Required (reference and analysis) |
+--------------+-------------------------------------+-------------------------------------+-------------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+
| y_pred | | Required (reference and analysis) | Required (reference and analysis) | | Required (reference and analysis) | | | | Required (reference and analysis) |
| | | Not needed for ROC_AUC metric | | | Not needed for ROC_AUC metric | | | | |
| | | Not needed for ROC_AUC or | | | Not needed for ROC_AUC or | | | | |
| | | average precision metrics | | | average precision metrics | | | | |
+--------------+-------------------------------------+-------------------------------------+-------------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+
| y_true | Required (reference only) | Required (reference only) | Required (reference and analysis) | | | Required (reference and analysis) | |
+--------------+-------------------------------------+-------------------------------------+-------------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+
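
To make the updated requirement concrete for realized performance: a sketch of a calculation that omits y_pred entirely, following the test_binary_classification_calculate_without_prediction_column test added in this commit (targets are joined back onto the analysis set before calculating):

    import nannyml as nml

    reference, analysis, analysis_targets = nml.load_synthetic_binary_classification_dataset()

    calc = nml.PerformanceCalculator(
        y_true='work_home_actual',
        y_pred_proba='y_pred_proba',
        problem_type='classification_binary',
        metrics=['roc_auc', 'average_precision'],  # the only metrics allowed without y_pred
        timestamp_column_name='timestamp',
        chunk_period='M',
    ).fit(reference)

    results = calc.calculate(analysis.merge(analysis_targets, on='id'))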
28 changes: 25 additions & 3 deletions nannyml/performance_calculation/calculator.py
@@ -84,8 +84,8 @@ def __init__(
self,
metrics: Union[str, List[str]],
y_true: str,
y_pred: str,
problem_type: Union[str, ProblemType],
y_pred: Optional[str] = None,
y_pred_proba: Optional[ModelOutputsType] = None,
timestamp_column_name: Optional[str] = None,
thresholds: Optional[Dict[str, Threshold]] = None,
@@ -105,8 +105,10 @@ def __init__(
A metric or list of metrics to calculate.
y_true: str
The name of the column containing target values.
y_pred: str
y_pred: Optional[str], default=None
The name of the column containing your model predictions.
This parameter is optional for binary classification cases.
When it is not given, only the ROC AUC and Average Precision metrics are supported.
problem_type: Union[str, ProblemType]
Determines which method to use. Allowed values are:
@@ -211,7 +213,12 @@ def __init__(
self.problem_type = problem_type

if self.problem_type is not ProblemType.REGRESSION and y_pred_proba is None:
raise InvalidArgumentsException(f"'y_pred_proba' can not be 'None' for problem type {ProblemType.value}")
raise InvalidArgumentsException(
f"'y_pred_proba' can not be 'None' for problem type {self.problem_type.value}"
)

if self.problem_type is not ProblemType.CLASSIFICATION_BINARY and y_pred is None:
raise InvalidArgumentsException(f"'y_pred' can not be 'None' for problem type {self.problem_type.value}")

self.thresholds = DEFAULT_THRESHOLDS
if thresholds:
@@ -236,6 +243,8 @@ def __init__(
if metric not in SUPPORTED_METRIC_VALUES:
raise InvalidArgumentsException(f"Metric '{metric}' is not supported.")

raise_if_metrics_require_y_pred(metrics, y_pred)

self.metrics: List[Metric] = [
MetricFactory.create(
m,
@@ -387,3 +396,16 @@ def _create_multilevel_index(metric_names: List[str]):
tuples = chunk_tuples + reconstruction_tuples

return MultiIndex.from_tuples(tuples)


def raise_if_metrics_require_y_pred(metrics: List[str], y_pred: Optional[str]):
"""Raise an exception if metrics require y_pred and y_pred is not set.
The only metrics that do not require 'y_pred' are:
- roc_auc
- average_precision
"""
metrics_that_need_y_pred = [m for m in metrics if m not in ['roc_auc', 'average_precision']]

if len(metrics_that_need_y_pred) > 0 and y_pred is None:
raise InvalidArgumentsException(f"Metrics '{metrics_that_need_y_pred}' require 'y_pred' to be set.")
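
A quick sketch of how this guard plays out at construction time; the exception type is NannyML's InvalidArgumentsException (assumed importable from nannyml.exceptions, as elsewhere in the codebase):

    import nannyml as nml
    from nannyml.exceptions import InvalidArgumentsException

    # Allowed: probability-only metrics do not need predicted labels.
    nml.PerformanceCalculator(
        y_true='y_true',
        y_pred_proba='y_pred_proba',
        problem_type='classification_binary',
        metrics=['roc_auc', 'average_precision'],
    )

    # Rejected: 'f1' needs predicted labels, so omitting y_pred fails fast.
    try:
        nml.PerformanceCalculator(
            y_true='y_true',
            y_pred_proba='y_pred_proba',
            problem_type='classification_binary',
            metrics=['roc_auc', 'f1'],
        )
    except InvalidArgumentsException as exc:
        print(exc)  # Metrics '['f1']' require 'y_pred' to be set.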
2 changes: 1 addition & 1 deletion nannyml/performance_calculation/result.py
@@ -30,7 +30,7 @@ def __init__(
self,
results_data: pd.DataFrame,
problem_type: ProblemType,
y_pred: str,
y_pred: Optional[str],
y_pred_proba: Optional[Union[str, Dict[str, str]]],
y_true: str,
metrics: List[Metric],
34 changes: 29 additions & 5 deletions nannyml/performance_estimation/confidence_based/cbpe.py
@@ -75,10 +75,10 @@ class CBPE(AbstractEstimator):
def __init__(
self,
metrics: Union[str, List[str]],
y_pred: str,
y_pred_proba: ModelOutputsType,
y_true: str,
problem_type: Union[str, ProblemType],
y_pred: Optional[str] = None,
timestamp_column_name: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_number: Optional[int] = None,
@@ -103,8 +103,6 @@ def __init__(
- For binary classification, pass a single string referring to the model output column.
- For multiclass classification, pass a dictionary that maps a class string to the column name containing
model outputs for that class.
y_pred: str
The name of the column containing your model predictions.
timestamp_column_name: str, default=None
The name of the column containing the timestamp of the model prediction.
If not given, plots will not use a time-based x-axis but will use the index of the chunks instead.
@@ -121,6 +119,8 @@ def __init__(
- `accuracy`
- `confusion_matrix` - only for binary classification tasks
- `business_value` - only for binary classification tasks
y_pred: Optional[str], default=None
The name of the column containing your model predictions.
This parameter is optional for binary classification; when it is not given, only the ROC AUC and
Average Precision metrics are supported.
chunk_size: int, default=None
Splits the data into chunks containing `chunks_size` observations.
Only one of `chunk_size`, `chunk_number` or `chunk_period` should be given.
@@ -256,13 +256,18 @@ def __init__(
else:
self.problem_type = problem_type

if self.problem_type is not ProblemType.CLASSIFICATION_BINARY and y_pred is None:
raise InvalidArgumentsException(f"'y_pred' can not be 'None' for problem type {self.problem_type.value}")

self.thresholds = DEFAULT_THRESHOLDS
if thresholds:
self.thresholds.update(**thresholds)

if isinstance(metrics, str):
metrics = [metrics]

raise_if_metrics_require_y_pred(metrics, y_pred)

self.metrics = []
for metric in metrics:
if metric not in SUPPORTED_METRIC_VALUES:
@@ -341,7 +346,10 @@ def _estimate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
raise InvalidArgumentsException('data contains no rows. Please provide a valid data set.')

if self.problem_type == ProblemType.CLASSIFICATION_BINARY:
_list_missing([self.y_pred, self.y_pred_proba], data)
required_cols = [self.y_pred_proba]
if self.y_pred is not None:
required_cols.append(self.y_pred)
_list_missing(required_cols, list(data.columns))

# We need uncalibrated data to calculate the realized performance on.
# https://github.com/NannyML/nannyml/issues/98
@@ -414,7 +422,10 @@ def _fit_binary(self, reference_data: pd.DataFrame) -> CBPE:
if reference_data.empty:
raise InvalidArgumentsException('data contains no rows. Please provide a valid data set.')

_list_missing([self.y_true, self.y_pred_proba, self.y_pred], list(reference_data.columns))
required_cols = [self.y_true, self.y_pred_proba]
if self.y_pred is not None:
required_cols.append(self.y_pred)
_list_missing(required_cols, list(reference_data.columns))

# We need uncalibrated data to calculate the realized performance on.
# We need realized performance in threshold calculations.
@@ -552,3 +563,16 @@ def _calibrate_predicted_probabilities(
calibrated_data[predicted_class_proba_column_names[idx]] = calibrated_probas[:, idx]

return calibrated_data


def raise_if_metrics_require_y_pred(metrics: List[str], y_pred: Optional[str]):
"""Raise an exception if metrics require y_pred and y_pred is not set.
The only metrics that do not require 'y_pred' are:
- roc_auc
- average_precision
"""
metrics_that_need_y_pred = [m for m in metrics if m not in ['roc_auc', 'average_precision']]

if len(metrics_that_need_y_pred) > 0 and y_pred is None:
raise InvalidArgumentsException(f"Metrics '{metrics_that_need_y_pred}' require 'y_pred' to be set.")
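
Called directly, the helper reduces to a simple filter over the requested metric names. A small illustration of which combinations pass or raise (hypothetical standalone calls; in practice the check runs inside CBPE.__init__):

    from nannyml.performance_estimation.confidence_based.cbpe import raise_if_metrics_require_y_pred

    raise_if_metrics_require_y_pred(['roc_auc', 'average_precision'], y_pred=None)  # passes: neither metric needs y_pred
    raise_if_metrics_require_y_pred(['roc_auc', 'f1'], y_pred='y_pred')             # passes: y_pred is provided
    raise_if_metrics_require_y_pred(['roc_auc', 'f1'], y_pred=None)                 # raises InvalidArgumentsException for ['f1']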
2 changes: 1 addition & 1 deletion nannyml/performance_estimation/confidence_based/results.py
@@ -30,7 +30,7 @@ def __init__(
self,
results_data: pd.DataFrame,
metrics: List[Metric],
y_pred: str,
y_pred: Optional[str],
y_pred_proba: ModelOutputsType,
y_true: str,
chunker: Chunker,
2 changes: 1 addition & 1 deletion nannyml/sampling_error/summary_stats.py
@@ -2,12 +2,12 @@
#
# License: Apache Software License 2.0

import warnings
from logging import getLogger
from typing import Tuple

import numpy as np
import pandas as pd
import warnings
from scipy.stats import gaussian_kde, moment

logger = getLogger(__name__)
68 changes: 68 additions & 0 deletions tests/performance_calculation/test_performance_calculator.py
@@ -106,6 +106,58 @@ def test_performance_calculator_create_with_single_or_list_of_metrics(metrics, e
assert [metric.column_name for metric in calc.metrics] == expected


@pytest.mark.parametrize(
'problem',
[
"classification_multiclass",
"regression",
],
)
def test_performance_calculator_create_raises_exception_when_y_pred_not_given_and_problem_type_not_binary_clf(
problem,
):
with pytest.raises(InvalidArgumentsException, match=f"'y_pred' can not be 'None' for problem type {problem}"):
_ = PerformanceCalculator(
timestamp_column_name='timestamp',
y_pred_proba='y_pred_proba',
y_true='y_true',
metrics=['roc_auc', 'f1'],
problem_type=problem,
)


@pytest.mark.parametrize(
'metric, expected',
[
(['roc_auc', 'f1'], "['f1']"),
(['roc_auc', 'f1', 'average_precision', 'precision'], "['f1', 'precision']"),
],
)
def test_performance_calculator_create_without_y_pred_raises_exception_when_metrics_require_it(metric, expected):
with pytest.raises(InvalidArgumentsException, match=expected):
_ = PerformanceCalculator(
timestamp_column_name='timestamp',
y_pred_proba='y_pred_proba',
y_true='y_true',
metrics=metric,
problem_type='classification_binary',
)


@pytest.mark.parametrize('metric', ['roc_auc', 'average_precision'])
def test_performance_calculator_create_without_y_pred_works_when_metrics_dont_require_it(metric):
try:
_ = PerformanceCalculator(
timestamp_column_name='timestamp',
y_pred_proba='y_pred_proba',
y_true='y_true',
metrics=metric,
problem_type='classification_binary',
)
except Exception as exc:
pytest.fail(f'unexpected exception: {exc}')


def test_calculator_fit_should_raise_invalid_args_exception_when_no_target_data_present(data): # noqa: D103, F821
calc = PerformanceCalculator(
timestamp_column_name='timestamp',
@@ -410,3 +462,19 @@ def test_binary_classification_result_plots_raise_no_exceptions(calc_args, plot_
_ = sut.plot(**plot_args)
except Exception as exc:
pytest.fail(f"an unexpected exception occurred: {exc}")


def test_binary_classification_calculate_without_prediction_column():
reference, analysis, analysis_targets = load_synthetic_binary_classification_dataset()
try:
calc = PerformanceCalculator(
y_true='work_home_actual',
y_pred_proba='y_pred_proba',
problem_type=ProblemType.CLASSIFICATION_BINARY,
metrics=['roc_auc', 'average_precision'],
timestamp_column_name='timestamp',
chunk_period='M',
).fit(reference)
_ = calc.calculate(analysis.merge(analysis_targets, on='id'))
except Exception as exc:
pytest.fail(f"an unexpected exception occurred: {exc}")
69 changes: 69 additions & 0 deletions tests/performance_estimation/CBPE/test_cbpe.py
@@ -64,6 +64,56 @@ def test_cbpe_create_with_single_or_list_of_metrics(metrics, expected):
assert [metric.name for metric in sut.metrics] == expected


@pytest.mark.parametrize(
'problem',
[
"classification_multiclass",
"regression",
],
)
def test_cbpe_create_raises_exception_when_y_pred_not_given_and_problem_type_not_binary_classification(problem):
with pytest.raises(InvalidArgumentsException, match=f"'y_pred' can not be 'None' for problem type {problem}"):
_ = CBPE(
timestamp_column_name='timestamp',
y_pred_proba='y_pred_proba',
y_true='y_true',
metrics=['roc_auc', 'f1'],
problem_type=problem,
)


@pytest.mark.parametrize(
'metric, expected',
[
(['roc_auc', 'f1'], "['f1']"),
(['roc_auc', 'f1', 'average_precision', 'precision'], "['f1', 'precision']"),
],
)
def test_cbpe_create_without_y_pred_raises_exception_when_metrics_require_it(metric, expected):
with pytest.raises(InvalidArgumentsException, match=expected):
_ = CBPE(
timestamp_column_name='timestamp',
y_pred_proba='y_pred_proba',
y_true='y_true',
metrics=metric,
problem_type='classification_binary',
)


@pytest.mark.parametrize('metric', ['roc_auc', 'average_precision'])
def test_cbpe_create_without_y_pred_works_when_metrics_dont_require_it(metric):
try:
_ = CBPE(
timestamp_column_name='timestamp',
y_pred_proba='y_pred_proba',
y_true='y_true',
metrics=metric,
problem_type='classification_binary',
)
except Exception as exc:
pytest.fail(f'unexpected exception: {exc}')


def test_cbpe_will_calibrate_scores_when_needed(binary_classification_data): # noqa: D103
ref_df = binary_classification_data[0]

@@ -652,3 +702,22 @@ def test_cbpe_with_default_thresholds():
sut = est.thresholds

assert sut == DEFAULT_THRESHOLDS


def test_cbpe_without_predictions():
ref_df, ana_df, _ = load_synthetic_binary_classification_dataset()
try:
cbpe = CBPE(
y_pred_proba='y_pred_proba',
y_true='work_home_actual',
problem_type='classification_binary',
metrics=[
'roc_auc',
'average_precision',
],
timestamp_column_name='timestamp',
chunk_period='M',
).fit(ref_df)
_ = cbpe.estimate(ana_df)
except Exception as exc:
pytest.fail(f'unexpected exception: {exc}')
