Skip to content

Commit 1de3f41

Browse files
authored
[R] Enable vector-valued parameters (#9849)
1 parent 0716c64 commit 1de3f41

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

R-package/R/utils.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,14 @@ check.booster.params <- function(params, ...) {
9393
interaction_constraints <- sapply(params[['interaction_constraints']], function(x) paste0('[', paste(x, collapse = ','), ']'))
9494
params[['interaction_constraints']] <- paste0('[', paste(interaction_constraints, collapse = ','), ']')
9595
}
96+
97+
# for evaluation metrics, should generate multiple entries per metric
98+
if (NROW(params[['eval_metric']]) > 1) {
99+
eval_metrics <- as.list(params[["eval_metric"]])
100+
names(eval_metrics) <- rep("eval_metric", length(eval_metrics))
101+
params_without_ev_metrics <- within(params, rm("eval_metric"))
102+
params <- c(params_without_ev_metrics, eval_metrics)
103+
}
96104
return(params)
97105
}
98106

R-package/R/xgb.Booster.R

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,13 @@ xgb.config <- function(object) {
697697
stop("parameter names cannot be empty strings")
698698
}
699699
names(p) <- gsub(".", "_", names(p), fixed = TRUE)
700-
p <- lapply(p, function(x) as.character(x)[1])
700+
p <- lapply(p, function(x) {
701+
if (is.vector(x) && length(x) == 1) {
702+
return(as.character(x)[1])
703+
} else {
704+
return(jsonlite::toJSON(x, auto_unbox = TRUE))
705+
}
706+
})
701707
handle <- xgb.get.handle(object)
702708
for (i in seq_along(p)) {
703709
.Call(XGBoosterSetParam_R, handle, names(p[i]), p[[i]])

R-package/tests/testthat/test_basic.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,33 @@ test_that("'predict' accepts CSR data", {
566566
expect_equal(p_csc, p_spv)
567567
})
568568

569+
test_that("Quantile regression accepts multiple quantiles", {
570+
data(mtcars)
571+
y <- mtcars[, 1]
572+
x <- as.matrix(mtcars[, -1])
573+
dm <- xgb.DMatrix(data = x, label = y)
574+
model <- xgb.train(
575+
data = dm,
576+
params = list(
577+
objective = "reg:quantileerror",
578+
tree_method = "exact",
579+
quantile_alpha = c(0.05, 0.5, 0.95),
580+
nthread = n_threads
581+
),
582+
nrounds = 15
583+
)
584+
pred <- predict(model, x, reshape = TRUE)
585+
586+
expect_equal(dim(pred)[1], nrow(x))
587+
expect_equal(dim(pred)[2], 3)
588+
expect_true(all(pred[, 1] <= pred[, 3]))
589+
590+
cors <- cor(y, pred)
591+
expect_true(cors[2] > cors[1])
592+
expect_true(cors[2] > cors[3])
593+
expect_true(cors[2] > 0.85)
594+
})
595+
569596
test_that("Can use multi-output labels with built-in objectives", {
570597
data("mtcars")
571598
y <- mtcars$mpg

0 commit comments

Comments
 (0)