[R] Make `xgb.cv` work with `xgb.DMatrix` only, adding support for survival and ranking fields #10031
@@ -1,6 +1,6 @@
 #' Cross Validation
 #'
-#' The cross validation function of xgboost
+#' The cross validation function of xgboost.
 #'
 #' @param params the list of parameters. The complete list of parameters is
 #' available in the \href{http://xgboost.readthedocs.io/en/latest/parameter.html}{online documentation}. Below
@@ -19,13 +19,17 @@
 #'
 #' See \code{\link{xgb.train}} for further details.
 #' See also demo/ for walkthrough example in R.
-#' @param data takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.
+#'
+#' Note that, while `params` accepts a `seed` entry and will use such parameter for model training if
+#' supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG
+#' system - thus, for reproducible results, one needs to call the `set.seed` function beforehand.
+#' @param data An `xgb.DMatrix` object, with corresponding fields like `label` or bounds as required
+#' for model training by the objective.
+#'
+#' Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix`
+#' or `xgb.ExternalDMatrix` are not supported here.
 #' @param nrounds the max number of iterations
 #' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
-#' @param label vector of response values. Should be provided only when data is an R-matrix.
-#' @param missing is only used when input is a dense matrix. By default is set to NA, which means
-#' that NA values should be considered as 'missing' by the algorithm.
-#' Sometimes, 0 or other extreme value might be used to represent missing values.
 #' @param prediction A logical value indicating whether to return the test fold predictions
 #' from each CV model. This parameter engages the \code{\link{cb.cv.predict}} callback.
 #' @param showsd \code{boolean}, whether to show standard deviation of cross validation
@@ -47,13 +51,30 @@
 #' @param feval customized evaluation function. Returns
 #' \code{list(metric='metric-name', value='metric-value')} with given
 #' prediction and dtrain.
-#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
-#' by the values of outcome labels.
+#' @param stratified A \code{boolean} indicating whether sampling of folds should be stratified
+#' by the values of outcome labels. For real-valued labels in regression objectives,
+#' stratification will be done by discretizing the labels into up to 5 buckets beforehand.
+#'
+#' If passing "auto", will be set to `TRUE` if the objective in `params` is a classification
+#' objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to
+#' `FALSE` otherwise.
+#'
+#' This parameter is ignored when `data` has a `group` field - in such case, the splitting
+#' will be based on whole groups (note that this might make the folds have different sizes).
+#'
+#' Value `TRUE` here is \bold{not} supported for custom objectives.
 #' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds
 #' (each element must be a vector of test fold's indices). When folds are supplied,
 #' the \code{nfold} and \code{stratified} parameters are ignored.
+#'
+#' If `data` has a `group` field and the objective requires this field, each fold (list element)
+#' must additionally have two attributes (retrievable through \link{attributes}) named `group_test`
+#' and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to
+#' the resulting DMatrices.
 #' @param train_folds \code{list} list specifying which indicies to use for training. If \code{NULL}
 #' (the default) all indices not specified in \code{folds} will be used for training.
+#'
+#' This is not supported when `data` has `group` field.
 #' @param verbose \code{boolean}, print the statistics during the process
 #' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}.
 #' Default is 1 which means all messages are printed. This parameter is passed to the
@@ -118,13 +139,14 @@
 #' print(cv, verbose=TRUE)
 #'
 #' @export
-xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing = NA,
+xgb.cv <- function(params = list(), data, nrounds, nfold,
                    prediction = FALSE, showsd = TRUE, metrics = list(),
-                   obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, train_folds = NULL,
+                   obj = NULL, feval = NULL, stratified = "auto", folds = NULL, train_folds = NULL,
                    verbose = TRUE, print_every_n = 1L,
                    early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {
 
   check.deprecation(...)
+  stopifnot(inherits(data, "xgb.DMatrix"))
   if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
     stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
   }
@@ -137,16 +159,19 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
   check.custom.obj()
   check.custom.eval()
 
-  # Check the labels
-  if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) ||
-      (!inherits(data, 'xgb.DMatrix') && is.null(label))) {
-    stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
-  } else if (inherits(data, 'xgb.DMatrix')) {
-    if (!is.null(label))
-      warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix")
-    cv_label <- getinfo(data, 'label')
-  } else {
-    cv_label <- label
+  if (stratified == "auto") {
+    if (is.character(params$objective)) {
+      stratified <- params$objective %in% .CLASSIFICATION_OBJECTIVES()
+    } else {
+      stratified <- FALSE
+    }
+  }
+
+  # Check the labels and groups
+  cv_label <- getinfo(data, "label")
+  cv_group <- getinfo(data, "group")
+  if (!is.null(train_folds) && NROW(cv_group)) {
+    stop("'train_folds' is not supported for DMatrix object with 'group' field.")
   }
 
   # CV folds
Review comment on the `.CLASSIFICATION_OBJECTIVES()` line: Let's update
Reply: I added a new
@@ -157,7 +182,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
   } else {
     if (nfold <= 1)
       stop("'nfold' must be > 1")
-    folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params)
+    folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, cv_group, params)
   }
 
   # verbosity & evaluation printing callback:
@@ -189,20 +214,18 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
 
   # create the booster-folds
   # train_folds
-  dall <- xgb.get.DMatrix(
-    data = data,
-    label = label,
-    missing = missing,
-    weight = NULL,
-    nthread = params$nthread
-  )
+  dall <- data
   bst_folds <- lapply(seq_along(folds), function(k) {
-    dtest <- xgb.slice.DMatrix(dall, folds[[k]])
+    dtest <- xgb.slice.DMatrix(dall, folds[[k]], allow_groups = TRUE)
     # code originally contributed by @RolandASc on stackoverflow
     if (is.null(train_folds))
-      dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]))
+      dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]), allow_groups = TRUE)
     else
-      dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]])
+      dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]], allow_groups = TRUE)
+    if (!is.null(attributes(folds[[k]])$group_test)) {
+      setinfo(dtest, "group", attributes(folds[[k]])$group_test)
+      setinfo(dtrain, "group", attributes(folds[[k]])$group_train)
+    }
     bst <- xgb.Booster(
       params = params,
       cachelist = list(dtrain, dtest),
@@ -285,8 +308,8 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
 #' @examples
 #' data(agaricus.train, package='xgboost')
 #' train <- agaricus.train
-#' cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2,
-#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
+#' cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2,
+#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
 #' print(cv)
 #' print(cv, verbose=TRUE)
 #'
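The new `stratified` documentation in the diff says that real-valued labels are discretized into up to 5 buckets before stratified sampling. The following base-R sketch illustrates that idea only; it is not xgboost's `generate.cv.folds`. The helper name `stratified_folds`, its `max_buckets` argument, and the choice of quantile-based buckets are all assumptions made for illustration.

```r
# Hypothetical helper (not xgboost's implementation): stratified folds for numeric labels.
stratified_folds <- function(y, nfold, max_buckets = 5L) {
  # Quantile-based break points; unique() merges breaks that coincide because of ties.
  probs <- seq(0, 1, length.out = max_buckets + 1L)
  breaks <- unique(quantile(y, probs = probs, names = FALSE))
  if (length(breaks) < 2L) {
    bucket <- rep(1L, length(y))  # constant labels: a single bucket
  } else {
    bucket <- cut(y, breaks = breaks, include.lowest = TRUE, labels = FALSE)
  }

  fold_id <- integer(length(y))
  for (b in unique(bucket)) {
    idx <- which(bucket == b)
    # Shuffle rows within the bucket, then spread them evenly across folds.
    fold_id[idx[sample.int(length(idx))]] <- rep_len(seq_len(nfold), length(idx))
  }
  # One integer vector of test-row indices per fold.
  unname(split(seq_along(y), fold_id))
}

set.seed(123)  # fold creation relies on R's RNG, as the documentation note says
y <- rnorm(100)
folds <- stratified_folds(y, nfold = 5)
lengths(folds)
```

A list built this way has the form described for the `folds` argument (each element a vector of test-fold indices), so each fold sees a similar spread of label values.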
It appears that we don't support stratified CV when `group` is non-empty. Can we throw a warning about the use of un-stratified splitting when `group` is non-empty?
Changed it to throw an error rather than a warning.
Changed it to a warning now.
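
As background for this thread: the documentation in the diff states that when `data` has a `group` field, splitting is based on whole groups, and that user-supplied folds must carry `group_test` and `group_train` attributes. The base-R sketch below shows one way such folds could be built by hand. `make_group_folds` is a hypothetical helper, not a function from xgboost, and it assumes that `group` holds the sizes of consecutive row blocks and that training rows are taken as `unlist(folds[-k])`, as in the diff.

```r
# Hypothetical helper (not part of xgboost): assign whole groups to CV folds.
# `group_sizes` is assumed to hold the sizes of consecutive row blocks.
make_group_folds <- function(group_sizes, nfold) {
  n_groups <- length(group_sizes)
  stopifnot(nfold > 1, n_groups >= nfold)

  # First and last row index of each group.
  ends <- cumsum(group_sizes)
  starts <- ends - group_sizes + 1

  # Shuffle the groups and deal them out round-robin, so every fold gets at least one.
  fold_of_group <- integer(n_groups)
  fold_of_group[sample.int(n_groups)] <- rep_len(seq_len(nfold), n_groups)
  groups_by_fold <- split(seq_len(n_groups), fold_of_group)

  rows_of <- function(groups) {
    unlist(lapply(groups, function(g) seq.int(starts[g], ends[g])))
  }

  lapply(seq_len(nfold), function(k) {
    test_groups <- groups_by_fold[[k]]
    # Same order as the rows produced by unlist(folds[-k]).
    train_groups <- unlist(groups_by_fold[-k], use.names = FALSE)
    idx <- rows_of(test_groups)
    attr(idx, "group_test") <- group_sizes[test_groups]
    attr(idx, "group_train") <- group_sizes[train_groups]
    idx
  })
}

set.seed(123)  # fold assignment uses R's RNG
folds <- make_group_folds(group_sizes = c(3, 2, 4, 1, 5), nfold = 3)
lengths(folds)  # fold sizes can differ because groups differ in size
```

Because groups are never split, labels cannot also be balanced across folds in general, which is why stratification is not applied in this case.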