Skip to content

Commit 2487013

Browse files
authored
feat: throw warning when prediction and measure type do not match (#1188)
* feat: throw warning when prediction and measure type do not match * ...
1 parent e22dbe4 commit 2487013

File tree

5 files changed

+24
-4
lines changed

5 files changed

+24
-4
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# mlr3 (development version)
22

3+
* feat: Throw warning when prediction and measure type do not match.
34
* fix: The `mlr_reflections` were broken when an extension package was not loaded on the workers.
45
Extension packages must now register themselves in the `mlr_reflections$loaded_packages` field.
56

R/Measure.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ Measure = R6Class("Measure",
194194
#'
195195
#' @return `numeric(1)`.
196196
score = function(prediction, task = NULL, learner = NULL, train_set = NULL) {
197-
assert_measure(self, task = task, learner = learner)
197+
assert_measure(self, task = task, learner = learner, prediction = prediction)
198198
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties)
199199

200200
if ("requires_task" %in% self$properties && is.null(task)) {

R/assertions.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,9 @@ assert_predictable = function(task, learner) {
194194

195195
#' @export
196196
#' @param measure ([Measure]).
197+
#' @param prediction ([Prediction]).
197198
#' @rdname mlr_assertions
198-
assert_measure = function(measure, task = NULL, learner = NULL, .var.name = vname(measure)) {
199+
assert_measure = function(measure, task = NULL, learner = NULL, prediction = NULL, .var.name = vname(measure)) {
199200
assert_class(measure, "Measure", .var.name = .var.name)
200201

201202
if (!is.null(task)) {
@@ -236,6 +237,13 @@ assert_measure = function(measure, task = NULL, learner = NULL, .var.name = vnam
236237
}
237238
}
238239

240+
if (!is.null(prediction)) {
241+
# same as above but works without learner e.g. measure$score(prediction)
242+
if (measure$check_prerequisites != "ignore" && measure$predict_type %nin% prediction$predict_types) {
243+
warningf("Measure '%s' is missing predict type '%s' of prediction", measure$id, measure$predict_type)
244+
}
245+
}
246+
239247
invisible(measure)
240248
}
241249

man/mlr_assertions.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_Measure.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,13 @@ test_that("checks on predict_sets", {
188188
expect_error({m$predict_sets = "imaginary"}, "Must be a subset")
189189
})
190190

191+
test_that("measure and prediction type is checked", {
192+
learner = lrn("classif.rpart")
193+
task = tsk("pima")
194+
learner$train(task)
195+
pred = learner$predict(task)
196+
197+
measure = msr("classif.logloss")
198+
expect_warning(measure$score(pred), "is missing predict type")
199+
})
200+

0 commit comments

Comments
 (0)