default to reshape=TRUE
david-cortes committed May 25, 2024
1 parent 9def441 commit 61d6e6a
Showing 4 changed files with 45 additions and 13 deletions.
25 changes: 20 additions & 5 deletions R-package/R/xgb.Booster.R
@@ -113,8 +113,19 @@ xgb.get.handle <- function(object) {
 #' @param approxcontrib Whether to use a fast approximation for feature contributions (see Details).
 #' @param predinteraction Whether to return contributions of feature interactions to individual predictions (see Details).
 #' @param reshape Whether to reshape the vector of predictions to matrix form when there are several
-#' prediction outputs per case. No effect if `predleaf`, `predcontrib`,
-#' or `predinteraction` is `TRUE`.
+#' prediction outputs per case.
+#'
+#' If passing `reshape=FALSE` and there are multiple predictions per row, they will be returned
+#' as a single vector of dimensions `[nrows, ntargets]`, which will be \bold{in row-major order}
+#' (meaning: if calling `matrix(.)` on it, one should either swap the number of rows and columns,
+#' or pass argument `byrow=TRUE`).
+#'
+#' Producing non-reshaped predictions is faster, since XGBoost internally uses row-major order
+#' while the reshaped outputs are R matrices which follow column-major order and thus need transposing.
+#'
+#' Will be ignored when using any of `predleaf`, `predcontrib`, or `predinteraction` (they return a
+#' multi-dimensional array with the rows being the first or last dimension according to parameter
+#' `strict_shape`), and will be overridden when passing `strict_shape=TRUE`.
 #' @param training Whether the prediction result is used for training. For dart booster,
 #' training predicting will perform dropout.
 #' @param iterationrange Sequence of rounds/iterations from the model to use for prediction, specified by passing
@@ -128,8 +139,12 @@ xgb.get.handle <- function(object) {
 #' of the iterations (rounds) otherwise.
 #'
 #' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
-#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
-#' type and shape of predictions are invariant to the model type.
+#' @param strict_shape Whether to make the predictions output type and shape invariant to the model type,
+#' by always returning a matrix where the rows are the second dimension.
+#'
+#' Note that, if there is more than one prediction output per row, passing `strict_shape=TRUE` will
+#' override the `reshape` argument, and unlike the output from `reshape=TRUE`, the output with
+#' `strict_shape=TRUE` will have dimensions `[ntargets, nrows]`.
 #' @param base_margin Base margin used for boosting from existing model.
 #'
 #' Note that, if `newdata` is an `xgb.DMatrix` object, this argument will
@@ -311,7 +326,7 @@ xgb.get.handle <- function(object) {
 #' @export
 predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
 predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
-reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
+reshape = TRUE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
 validate_features = FALSE, base_margin = NULL, ...) {
 if (validate_features) {
 newdata <- validate.features(object, newdata)
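
The row-major note in the `reshape` and `strict_shape` documentation can be checked without XGBoost. The following is a minimal base-R sketch, assuming a made-up vector `flat` that stands in for the output of `predict(..., reshape = FALSE)` with 2 rows and 3 targets; the variable names and values are illustrative and not part of this commit.

```r
# Hypothetical stand-in for predict(..., reshape = FALSE):
# 2 rows, 3 targets, stored row by row (row-major order).
nrows <- 2
ntargets <- 3
flat <- c(11, 12, 13,   # row 1: targets 1..3
          21, 22, 23)   # row 2: targets 1..3

# Option 1: tell matrix() that the data is in row-major order.
m_byrow <- matrix(flat, nrow = nrows, ncol = ntargets, byrow = TRUE)

# Option 2: swap the number of rows and columns, then transpose.
m_swapped <- t(matrix(flat, nrow = ntargets, ncol = nrows))

# Filling with R's default column-major order scrambles the values.
m_wrong <- matrix(flat, nrow = nrows, ncol = ntargets)

# A [ntargets, nrows] matrix (the layout documented for strict_shape = TRUE)
# is the same memory order read column by column, so no transpose is needed.
m_strict_like <- matrix(flat, nrow = ntargets, ncol = nrows)

stopifnot(identical(m_byrow, m_swapped))
stopifnot(!identical(m_byrow, m_wrong))
stopifnot(identical(t(m_strict_like), m_byrow))
```

The last assertion is consistent with the documented speed difference: the `[ntargets, nrows]` layout reuses the row-major buffer as-is, whereas an `[nrows, ntargets]` R matrix requires a transpose.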
27 changes: 22 additions & 5 deletions R-package/man/predict.xgb.Booster.Rd

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion R-package/tests/testthat/test_basic.R
@@ -159,7 +159,7 @@ test_that("train and predict softprob", {
 expect_false(is.null(attributes(bst)$evaluation_log))
 expect_lt(attributes(bst)$evaluation_log[, min(train_merror)], 0.025)
 expect_equal(xgb.get.num.boosted.rounds(bst), 5)
-pred <- predict(bst, as.matrix(iris[, -5]))
+pred <- predict(bst, as.matrix(iris[, -5]), reshape = FALSE)
 expect_length(pred, nrow(iris) * 3)
 # row sums add up to total probability of 1:
 expect_equal(rowSums(matrix(pred, ncol = 3, byrow = TRUE)), rep(1, nrow(iris)), tolerance = 1e-7)
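
The row-sum check in this test depends on `byrow = TRUE`: each consecutive group of three values in the flat vector belongs to one observation. A small base-R sketch of why the default fill order would fail the same check, using a hand-written probability table `probs` as a hypothetical stand-in for softprob output (no model is trained here):

```r
# Hypothetical class probabilities: 4 observations, 3 classes,
# each row sums to 1.
probs <- rbind(
  c(0.70, 0.20, 0.10),
  c(0.10, 0.80, 0.10),
  c(0.30, 0.30, 0.40),
  c(0.05, 0.05, 0.90)
)
n <- nrow(probs)
k <- ncol(probs)

# Flatten in row-major order, mimicking the documented reshape = FALSE layout.
flat <- as.vector(t(probs))

# Correct reconstruction: rows are observations again.
good <- matrix(flat, ncol = k, byrow = TRUE)

# Default column-major fill mixes values from different observations.
bad <- matrix(flat, ncol = k)

stopifnot(isTRUE(all.equal(good, probs)))
stopifnot(isTRUE(all.equal(rowSums(good), rep(1, n))))
stopifnot(!isTRUE(all.equal(rowSums(bad), rep(1, n))))
```

Both `good` and `bad` have the same dimensions, so a length or shape check alone would not catch the wrong fill order; the row sums do.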
4 changes: 2 additions & 2 deletions R-package/tests/testthat/test_interactions.R
@@ -129,13 +129,13 @@ test_that("multiclass feature interactions work", {
 b <- xgb.train(param, dm, 40)
 pred <- t(
 array(
-data = predict(b, dm, outputmargin = TRUE),
+data = predict(b, dm, outputmargin = TRUE, reshape = FALSE),
 dim = c(3, 150)
 )
 )

 # SHAP contributions:
-cont <- predict(b, dm, predcontrib = TRUE)
+cont <- predict(b, dm, predcontrib = TRUE, reshape = FALSE)
 expect_length(cont, 3)
 # rewrap them as a 3d array
 cont <- array(
