Skip to content

Commit 959c9d6

Browse files
authored
fix: reduce number of data table threads when running with future (#979)
* fix: reduce number of data table threads when running with future
* refactor: pass is sequential to worker
* test: data table threads are not changed in main session
* fix: use default
* feat: reduce blas threads to 1
* docs: namespace
* docs: namespace
* fix: roxygen
* chore: update news
1 parent 318748b commit 959c9d6

File tree

7 files changed

+36
-2
lines changed

7 files changed

+36
-2
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Imports:
5656
parallelly,
5757
palmerpenguins,
5858
paradox (>= 0.10.0),
59+
RhpcBLASctl,
5960
uuid
6061
Suggests:
6162
Matrix,
@@ -73,7 +74,7 @@ Config/testthat/edition: 3
7374
Config/testthat/parallel: false
7475
NeedsCompilation: no
7576
Roxygen: list(markdown = TRUE, r6 = TRUE)
76-
RoxygenNote: 7.2.3.9000
77+
RoxygenNote: 7.2.3
7778
Collate:
7879
'mlr_reflections.R'
7980
'BenchmarkResult.R'

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ import(palmerpenguins)
234234
import(paradox)
235235
importFrom(R6,R6Class)
236236
importFrom(R6,is.R6)
237+
importFrom(RhpcBLASctl,blas_get_num_procs)
238+
importFrom(RhpcBLASctl,blas_set_num_threads)
237239
importFrom(data.table,as.data.table)
238240
importFrom(data.table,data.table)
239241
importFrom(future,nbrOfWorkers)

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# mlr3 (development version)
22

3+
* Reduce number of threads used by `data.table` and BLAS to 1 when running `resample()` or `benchmark()` in parallel.
34
* Optimize runtime of `resample()` and `benchmark()` by reducing the number of hashing operations.
45

56
# mlr3 0.17.0

R/helper_exec.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ future_map = function(n, FUN, ..., MoreArgs = list()) {
3636
}
3737
stdout = if (is_sequential) NA else TRUE
3838

39+
MoreArgs = c(MoreArgs, list(is_sequential = is_sequential))
40+
3941
lg$debug("Running resample() via future with %i iterations", n)
4042
future.apply::future_mapply(
4143
FUN, ..., MoreArgs = MoreArgs, SIMPLIFY = FALSE, USE.NAMES = FALSE,

R/worker.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,19 @@ learner_predict = function(learner, task, row_ids = NULL) {
219219
}
220220

221221

222-
workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train") {
222+
workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train", is_sequential = TRUE) {
223223
if (!is.null(pb)) {
224224
pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration))
225225
}
226226

227+
# reduce data.table and blas threads to 1
228+
if (!is_sequential) {
229+
setDTthreads(1, restore_after_fork = TRUE)
230+
old_blas_threads = blas_get_num_procs()
231+
on.exit(blas_set_num_threads(old_blas_threads), add = TRUE)
232+
blas_set_num_threads(1)
233+
}
234+
227235
# restore logger thresholds
228236
for (package in names(lgr_threshold)) {
229237
logger = lgr::get_logger(package)

R/zzz.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#' @importFrom uuid UUIDgenerate
1111
#' @importFrom parallelly availableCores
1212
#' @importFrom future nbrOfWorkers plan
13+
#' @importFrom RhpcBLASctl blas_set_num_threads blas_get_num_procs
1314
#'
1415
#' @section Learn mlr3:
1516
#' * Book on mlr3: \url{https://mlr3book.mlr-org.com}

tests/testthat/test_parallel.R

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,22 @@ test_that("parallel seed", {
8585
})
8686
expect_equal(rr1$prediction()$prob, rr2$prediction()$prob)
8787
})
88+
89+
test_that("data table threads are not changed in main session", {
90+
old_dt_threads = getDTthreads()
91+
on.exit({
92+
setDTthreads(old_dt_threads)
93+
}, add = TRUE)
94+
setDTthreads(2L)
95+
96+
task = tsk("sonar")
97+
learner = lrn("classif.debug", predict_type = "prob")
98+
resampling = rsmp("cv", folds = 3L)
99+
measure = msr("classif.auc")
100+
101+
rr1 = with_seed(123, with_future(future::sequential, resample(task, learner, resampling)))
102+
expect_equal(getDTthreads(), 2L)
103+
104+
rr2 = with_seed(123, with_future(future::multisession, resample(task, learner, resampling)))
105+
expect_equal(getDTthreads(), 2L)
106+
})

0 commit comments

Comments
 (0)