diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 2fc2321cad31..6f1a1b4ec8c3 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -160,23 +160,24 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) { ) gpair <- obj(pred, dtrain) n_samples <- dim(dtrain)[1] - - msg <- paste( - "Since 2.1.0, the shape of the gradient and hessian is required to be ", - "(n_samples, n_targets) or (n_samples, n_classes).", - sep = "" - ) - if (is.matrix(gpair$grad) && dim(gpair$grad)[1] != n_samples) { - warning(msg) - } - if (is.numeric(gpair$grad) && length(gpair$grad) != n_samples) { - warning(msg) + grad <- gpair$grad + hess <- gpair$hess + + if ((is.matrix(grad) && dim(grad)[1] != n_samples) || + (is.vector(grad) && length(grad) != n_samples) || + (is.vector(grad) != is.vector(hess))) { + warning(paste( + "Since 2.1.0, the shape of the gradient and hessian is required to be ", + "(n_samples, n_targets) or (n_samples, n_classes). Will reshape assuming ", + "column-major order.", + sep = "" + )) + grad <- matrix(grad, nrow = n_samples) + hess <- matrix(hess, nrow = n_samples) } - gpair$grad <- matrix(gpair$grad, nrow = n_samples) - gpair$hess <- matrix(gpair$hess, nrow = n_samples) .Call( - XGBoosterBoostOneIter_R, booster_handle, dtrain, iter, gpair$grad, gpair$hess + XGBoosterTrainOneIter_R, booster_handle, dtrain, iter, grad, hess ) } return(TRUE) diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index 8e19e87b078f..3a1b6afc8d2a 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -243,6 +243,9 @@ getinfo.xgb.DMatrix <- function(object, name, ...) { ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name) } else if (name != "nrow") { ret <- .Call(XGDMatrixGetInfo_R, object, name) + if (length(ret) > nrow(object)) { + ret <- matrix(ret, nrow = nrow(object), byrow = TRUE) + } } else { ret <- nrow(object) } @@ -286,9 +289,9 @@ setinfo <- function(object, ...) UseMethod("setinfo") #' @export setinfo.xgb.DMatrix <- function(object, name, info, ...) { if (name == "label") { - if (length(info) != nrow(object)) + if (NROW(info) != nrow(object)) stop("The length of labels must equal to the number of rows in the input data") - .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) + .Call(XGDMatrixSetInfo_R, object, name, info) return(TRUE) } if (name == "label_lower_bound") { diff --git a/R-package/src/init.c b/R-package/src/init.c index 5c8e179d6c69..afac524e380e 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -52,7 +52,7 @@ extern SEXP XGBGetGlobalConfig_R(void); extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP); static const R_CallMethodDef CallEntries[] = { - {"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5}, + {"XGBoosterTrainOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5}, {"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1}, {"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2}, {"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4}, diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index b267d7da62fe..2f5e62158650 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -346,9 +346,11 @@ XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) { XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) { R_API_BEGIN(); SEXP field_ = PROTECT(Rf_asChar(field)); + SEXP arr_dim = Rf_getAttrib(array, R_DimSymbol); int res_code; { - const std::string array_str = MakeArrayInterfaceFromRVector(array); + const std::string array_str = Rf_isNull(arr_dim)? + MakeArrayInterfaceFromRVector(array) : MakeArrayInterfaceFromRMat(array); res_code = XGDMatrixSetInfoFromInterface( R_ExternalPtrAddr(handle), CHAR(field_), array_str.c_str()); } @@ -516,20 +518,14 @@ XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP g R_API_BEGIN(); CHECK_EQ(Rf_xlength(grad), Rf_xlength(hess)) << "gradient and hess must have same length."; SEXP gdim = getAttrib(grad, R_DimSymbol); - auto n_samples = static_cast(INTEGER(gdim)[0]); - auto n_targets = static_cast(INTEGER(gdim)[1]); - SEXP hdim = getAttrib(hess, R_DimSymbol); - CHECK_EQ(INTEGER(hdim)[0], n_samples) << "mismatched size between gradient and hessian"; - CHECK_EQ(INTEGER(hdim)[1], n_targets) << "mismatched size between gradient and hessian"; - double const *d_grad = REAL(grad); - double const *d_hess = REAL(hess); int res_code; { - auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle)); - auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface( - ctx, d_grad, d_hess, xgboost::linalg::kF, n_samples, n_targets); + const std::string s_grad = Rf_isNull(gdim)? + MakeArrayInterfaceFromRVector(grad) : MakeArrayInterfaceFromRMat(grad); + const std::string s_hess = Rf_isNull(hdim)? + MakeArrayInterfaceFromRVector(hess) : MakeArrayInterfaceFromRMat(hess); res_code = XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain), asInteger(iter), s_grad.c_str(), s_hess.c_str()); } diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index b7e8197383eb..97c1353dc43a 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -565,3 +565,54 @@ test_that("'predict' accepts CSR data", { expect_equal(p_csc, p_csr) expect_equal(p_csc, p_spv) }) + +test_that("Can use multi-output labels with built-in objectives", { + data("mtcars") + y <- mtcars$mpg + x <- as.matrix(mtcars[, -1]) + y_mirrored <- cbind(y, -y) + dm <- xgb.DMatrix(x, label = y_mirrored, nthread = n_threads) + model <- xgb.train( + params = list( + tree_method = "hist", + multi_strategy = "multi_output_tree", + objective = "reg:squarederror", + nthread = n_threads + ), + data = dm, + nrounds = 5 + ) + pred <- predict(model, x, reshape = TRUE) + expect_equal(pred[, 1], -pred[, 2]) + expect_true(cor(y, pred[, 1]) > 0.9) + expect_true(cor(y, pred[, 2]) < -0.9) +}) + +test_that("Can use multi-output labels with custom objectives", { + data("mtcars") + y <- mtcars$mpg + x <- as.matrix(mtcars[, -1]) + y_mirrored <- cbind(y, -y) + dm <- xgb.DMatrix(x, label = y_mirrored, nthread = n_threads) + model <- xgb.train( + params = list( + tree_method = "hist", + multi_strategy = "multi_output_tree", + base_score = 0, + objective = function(pred, dtrain) { + y <- getinfo(dtrain, "label") + grad <- pred - y + hess <- rep(1, nrow(grad) * ncol(grad)) + hess <- matrix(hess, nrow = nrow(grad)) + return(list(grad = grad, hess = hess)) + }, + nthread = n_threads + ), + data = dm, + nrounds = 5 + ) + pred <- predict(model, x, reshape = TRUE) + expect_equal(pred[, 1], -pred[, 2]) + expect_true(cor(y, pred[, 1]) > 0.9) + expect_true(cor(y, pred[, 2]) < -0.9) +})