From 0716c64ef75d62384437f587eab8fa3925763e38 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Wed, 6 Dec 2023 10:43:51 +0100 Subject: [PATCH] [R] Error out on multidimensional arrays (#9852) --- R-package/src/xgboost_R.cc | 3 +++ R-package/tests/testthat/test_dmatrix.R | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 1ab3e764180f..214ca6cb4acc 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -50,6 +50,9 @@ SEXP SafeMkChar(const char *c_str, SEXP continuation_token) { [[nodiscard]] std::string MakeArrayInterfaceFromRMat(SEXP R_mat) { SEXP mat_dims = Rf_getAttrib(R_mat, R_DimSymbol); + if (Rf_xlength(mat_dims) > 2) { + LOG(FATAL) << "Passed input array with more than two dimensions, which is not supported."; + } const int *ptr_mat_dims = INTEGER(mat_dims); // Lambda for type dispatch. diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R index a0cf90088704..4db7aad08d45 100644 --- a/R-package/tests/testthat/test_dmatrix.R +++ b/R-package/tests/testthat/test_dmatrix.R @@ -297,3 +297,11 @@ test_that("xgb.DMatrix: Inf as missing", { file.remove("inf.dmatrix") file.remove("nan.dmatrix") }) + +test_that("xgb.DMatrix: error on three-dimensional array", { + set.seed(123) + x <- matrix(rnorm(500), nrow = 50) + y <- rnorm(400) + dim(y) <- c(50, 4, 2) + expect_error(xgb.DMatrix(data = x, label = y)) +})