Skip to content

Commit

Permalink
Merge pull request #6933 from ales-erjavec/fixes/owpredictions-error-…
Browse files Browse the repository at this point in the history
…magic-2

[FIX] owpredicttions: Remove special magic value 2
  • Loading branch information
janezd authored Nov 29, 2024
2 parents f9e5d71 + 0c47254 commit 9ff3124
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
4 changes: 1 addition & 3 deletions Orange/widgets/evaluate/owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,6 @@ def _add_error_out_columns(self, slot, newmetas, newcolumns, index):
name = f"{slot.predictor.name} (error)"
newmetas.append(ContinuousVariable(name=name))
err = self.predictionsview.model().errorColumn(index)
err[err == 2] = numpy.nan
newcolumns.append(err)

def send_report(self):
Expand Down Expand Up @@ -1383,8 +1382,7 @@ def errorColumn(self, column):
nans = numpy.isnan(actuals)
actuals[nans] = 0
errors = 1 - numpy.choose(actuals.astype(int), self._probs[column].T)
errors[nans] = 2
errors[numpy.isnan(errors)] = 2
errors[nans] = numpy.nan
return errors
else:
actual = self._actual
Expand Down
11 changes: 11 additions & 0 deletions Orange/widgets/evaluate/tests/test_owpredictions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for OWPredictions"""
# pylint: disable=protected-access,too-many-lines,too-many-public-methods
import os
import random
import unittest
from functools import partial
from tempfile import NamedTemporaryFile
Expand Down Expand Up @@ -1766,6 +1767,16 @@ def test_sorting_classification_error(self):
[model.data(model.index(row, 3)) for row in range(5)],
1 - np.array(sorted([80, 5, 20, 60, 50])) / 100)

# Numpy's sort puts nan's at the end, and the widget counts on it
# because we want to show them last. If this test fails, this
# (undocumented) numpy's behavior has changed, and the widget needs
# to be updated.
data = list(range(10)) + [np.nan] * 10
copy = data.copy()
for _ in range(10):
random.shuffle(data)
np.testing.assert_equal(np.sort(data), copy)

def test_sorting_classification_different(self):
model = PredictionsModel(self.values, self.probs, self.actual)

Expand Down

0 comments on commit 9ff3124

Please sign in to comment.