Skip to content

Commit

Permalink
Fix _list_missing issues + add "run with prediction" tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nnansters committed Apr 25, 2024
1 parent cb6e6b3 commit 4eb8cb8
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
10 changes: 8 additions & 2 deletions nannyml/performance_estimation/confidence_based/cbpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@ def _estimate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
raise InvalidArgumentsException('data contains no rows. Please provide a valid data set.')

if self.problem_type == ProblemType.CLASSIFICATION_BINARY:
_list_missing([self.y_pred, self.y_pred_proba], data)
required_cols = [self.y_pred_proba]
if self.y_pred is not None:
required_cols.append(self.y_pred)
_list_missing(required_cols, list(data.columns))

# We need uncalibrated data to calculate the realized performance on.
# https://github.com/NannyML/nannyml/issues/98
Expand Down Expand Up @@ -419,7 +422,10 @@ def _fit_binary(self, reference_data: pd.DataFrame) -> CBPE:
if reference_data.empty:
raise InvalidArgumentsException('data contains no rows. Please provide a valid data set.')

_list_missing([self.y_true, self.y_pred_proba, self.y_pred], list(reference_data.columns))
required_cols = [self.y_true, self.y_pred_proba]
if self.y_pred is not None:
required_cols.append(self.y_pred)
_list_missing(required_cols, list(reference_data.columns))

# We need uncalibrated data to calculate the realized performance on.
# We need realized performance in threshold calculations.
Expand Down
14 changes: 14 additions & 0 deletions tests/performance_calculation/test_performance_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,17 @@ def test_binary_classification_result_plots_raise_no_exceptions(calc_args, plot_
_ = sut.plot(**plot_args)
except Exception as exc:
pytest.fail(f"an unexpected exception occurred: {exc}")


def test_binary_classification_calculate_without_prediction_column():
    """Regression test: the performance calculator must fit and calculate when
    no ``y_pred`` column is configured, using only predicted probabilities.

    Guards the ``_list_missing`` fix that makes the prediction column optional
    for probability-only metrics such as ROC AUC and average precision.
    """
    reference, analysis, analysis_targets = load_synthetic_binary_classification_dataset()
    calc = PerformanceCalculator(
        y_true='work_home_actual',
        y_pred_proba='y_pred_proba',
        problem_type=ProblemType.CLASSIFICATION_BINARY,
        # Only metrics that operate on predicted probabilities, so y_pred can be omitted.
        metrics=['roc_auc', 'average_precision'],
        timestamp_column_name='timestamp',
        chunk_period='M',
    ).fit(reference)
    # Analysis targets are shipped separately; join them back on the record id.
    result = calc.calculate(analysis.merge(analysis_targets, on='id'))
    # The original test dropped the result on the floor; assert we actually
    # produced one so a silent None/failure cannot pass.
    assert result is not None

16 changes: 16 additions & 0 deletions tests/performance_estimation/CBPE/test_cbpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,3 +702,19 @@ def test_cbpe_with_default_thresholds():
sut = est.thresholds

assert sut == DEFAULT_THRESHOLDS


def test_cbpe_without_predictions():
    """Regression test: CBPE must fit and estimate when no ``y_pred`` column is
    configured, using only predicted probabilities.

    Guards the ``_list_missing`` fix in ``CBPE._fit_binary`` / ``CBPE._estimate``
    that makes the prediction column optional for probability-only metrics.
    """
    ref_df, ana_df, _ = load_synthetic_binary_classification_dataset()
    cbpe = CBPE(
        y_pred_proba='y_pred_proba',
        y_true='work_home_actual',
        problem_type='classification_binary',
        # Only metrics that operate on predicted probabilities, so y_pred can be omitted.
        metrics=[
            'roc_auc',
            'average_precision',
        ],
        timestamp_column_name='timestamp',
        chunk_period='M',
    ).fit(ref_df)
    result = cbpe.estimate(ana_df)
    # The original test never inspected the estimate; assert one was produced
    # so a silent None/failure cannot pass.
    assert result is not None

0 comments on commit 4eb8cb8

Please sign in to comment.