Skip to content

Commit b285c51

Browse files
be-marcsebffischer
andauthored
fix: remove task prototype when resample (#981)
* refactor: remove task prototype when resample * refactor: add option to store prototype * fix: braket * refactor: null * fix: browser * keep prototypes in state when store_models is TRUE * refactor: only store data_prototype when train * refactor: feature names --------- Co-authored-by: Sebastian Fischer <[email protected]>
1 parent 959c9d6 commit b285c51

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

R/Learner.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,11 @@ Learner = R6Class("Learner",
244244

245245
learner_train(learner, task, train_row_ids = train_row_ids, test_row_ids = test_row_ids, mode = mode)
246246

247+
# store data prototype
248+
proto = task$data(rows = integer())
249+
self$state$data_prototype = proto
250+
self$state$task_prototype = proto
251+
247252
# store the task w/o the data
248253
self$state$train_task = task_rm_backend(task$clone(deep = TRUE))
249254

R/worker.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,13 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
6868
log = append_log(NULL, "train", result$log$class, result$log$msg)
6969
train_time = result$elapsed
7070

71-
proto = task$data(rows = integer())
7271
learner$state = insert_named(learner$state, list(
7372
model = result$result,
7473
log = log,
7574
train_time = train_time,
7675
param_vals = learner$param_set$values,
7776
task_hash = task$hash,
78-
data_prototype = proto,
79-
task_prototype = proto,
77+
feature_names = task$feature_names,
8078
mlr3_version = mlr_reflections$package_version
8179
))
8280

0 commit comments

Comments
 (0)