Skip to content

Commit 0716c64

Browse files
authored
[R] Error out on multidimensional arrays (#9852)
1 parent 62571b7 commit 0716c64

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

R-package/src/xgboost_R.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ SEXP SafeMkChar(const char *c_str, SEXP continuation_token) {
5050

5151
[[nodiscard]] std::string MakeArrayInterfaceFromRMat(SEXP R_mat) {
5252
SEXP mat_dims = Rf_getAttrib(R_mat, R_DimSymbol);
53+
if (Rf_xlength(mat_dims) > 2) {
54+
LOG(FATAL) << "Passed input array with more than two dimensions, which is not supported.";
55+
}
5356
const int *ptr_mat_dims = INTEGER(mat_dims);
5457

5558
// Lambda for type dispatch.

R-package/tests/testthat/test_dmatrix.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,11 @@ test_that("xgb.DMatrix: Inf as missing", {
297297
file.remove("inf.dmatrix")
298298
file.remove("nan.dmatrix")
299299
})
300+
301+
test_that("xgb.DMatrix: error on three-dimensional array", {
302+
set.seed(123)
303+
x <- matrix(rnorm(500), nrow = 50)
304+
y <- rnorm(400)
305+
dim(y) <- c(50, 4, 2)
306+
expect_error(xgb.DMatrix(data = x, label = y))
307+
})

0 commit comments

Comments
 (0)