Skip to content

Commit

Permalink
[R] Error out on multidimensional arrays (#9852)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Dec 6, 2023
1 parent 62571b7 commit 0716c64
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ SEXP SafeMkChar(const char *c_str, SEXP continuation_token) {

[[nodiscard]] std::string MakeArrayInterfaceFromRMat(SEXP R_mat) {
SEXP mat_dims = Rf_getAttrib(R_mat, R_DimSymbol);
if (Rf_xlength(mat_dims) > 2) {
LOG(FATAL) << "Passed input array with more than two dimensions, which is not supported.";
}
const int *ptr_mat_dims = INTEGER(mat_dims);

// Lambda for type dispatch.
Expand Down
8 changes: 8 additions & 0 deletions R-package/tests/testthat/test_dmatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,11 @@ test_that("xgb.DMatrix: Inf as missing", {
file.remove("inf.dmatrix")
file.remove("nan.dmatrix")
})

test_that("xgb.DMatrix: error on three-dimensional array", {
set.seed(123)
x <- matrix(rnorm(500), nrow = 50)
y <- rnorm(400)
dim(y) <- c(50, 4, 2)
expect_error(xgb.DMatrix(data = x, label = y))
})

0 comments on commit 0716c64

Please sign in to comment.