From 4b213adab79022270e75376fe596585c397a9812 Mon Sep 17 00:00:00 2001 From: Frank Weber <55132727+fweber144@users.noreply.github.com> Date: Thu, 3 Mar 2022 06:13:22 +0100 Subject: [PATCH] Fix `contrasts` issue #283 and minor other changes (#284) --- NEWS.md | 1 + R/cv_varsel.R | 2 +- R/data.R | 2 +- R/divergence_minimizers.R | 101 +++++++++++++++-------------- R/methods.R | 1 + R/misc.R | 10 +-- R/projpred-package.R | 6 +- R/refmodel.R | 12 +++- man/cv_varsel.Rd | 2 +- man/mesquite.Rd | 2 +- man/projpred-package.Rd | 6 +- man/refmodel-init-get.Rd | 3 +- tests/testthat/test_as_matrix.R | 4 +- tests/testthat/test_datafit.R | 8 +-- tests/testthat/test_methods_vsel.R | 4 +- tests/testthat/test_parallel.R | 4 +- tests/testthat/test_proj_pred.R | 12 ++-- tests/testthat/test_refmodel.R | 2 +- vignettes/projpred.Rmd | 2 +- 19 files changed, 100 insertions(+), 84 deletions(-) diff --git a/NEWS.md b/NEWS.md index 309b2e0fd..bd0879139 100644 --- a/NEWS.md +++ b/NEWS.md @@ -46,6 +46,7 @@ * Argument `fit` of `init_refmodel()`'s argument `proj_predfun` was renamed to `fits`. This is a non-breaking change since all calls to `proj_predfun` in **projpred** have that argument unnamed. However, this cannot be guaranteed in the future, so we strongly encourage users with a custom `proj_predfun` to rename argument `fit` to `fits`. (GitHub: #263) * `init_refmodel()` has gained argument `cvrefbuilder` which may be a custom function for constructing the K reference models in a K-fold CV. (GitHub: #271) * Allow arguments to be passed from `project()`, `varsel()`, and `cv_varsel()` to the divergence minimizer. (GitHub: #278) +* In `init_refmodel()`, any `contrasts` attributes of the dataset's columns are silently removed. (GitHub: #284) ## Bug fixes diff --git a/R/cv_varsel.R b/R/cv_varsel.R index 771538421..7de3eccbe 100644 --- a/R/cv_varsel.R +++ b/R/cv_varsel.R @@ -55,7 +55,7 @@ #' #' Vehtari, A., Gelman, A., and Gabry, J. (2017). 
Practical Bayesian model #' evaluation using leave-one-out cross-validation and WAIC. *Statistics and -#' Computing*, **27**(5), 1413-1432. DOI: \doi{10.1007/s11222-016-9696-4}. +#' Computing*, **27**(5), 1413-1432. \doi{10.1007/s11222-016-9696-4}. #' #' Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2021). Pareto #' smoothed importance sampling. *arXiv:1507.02646*. URL: diff --git a/R/data.R b/R/data.R index 279e783fd..2f416424e 100644 --- a/R/data.R +++ b/R/data.R @@ -38,7 +38,7 @@ #' group).} #' } #' @references Gelman, A. and Hill, J. (2006). *Data Analysis Using Regression -#' and Multilevel/Hierarchical Models*. Cambridge University Press. DOI: +#' and Multilevel/Hierarchical Models*. Cambridge University Press. #' \doi{10.1017/CBO9780511790942}. #' @source "mesquite" diff --git a/R/divergence_minimizers.R b/R/divergence_minimizers.R index 097b04263..b9079146b 100644 --- a/R/divergence_minimizers.R +++ b/R/divergence_minimizers.R @@ -85,9 +85,12 @@ divmin <- function(formula, projpred_var, ...) { } } +# Use projpred's own implementation to fit non-multilevel non-additive +# submodels: fit_glm_ridge_callback <- function(formula, data, projpred_var = matrix(nrow = nrow(data)), projpred_regul = 1e-4, ...) { + # Preparations: fr <- model.frame(delete.intercept(formula), data = data) contrasts_arg <- get_contrasts_arg_list(formula, data = data) x <- model.matrix(fr, data = data, contrasts.arg = contrasts_arg) @@ -98,10 +101,12 @@ fit_glm_ridge_callback <- function(formula, data, names(dot_args), methods::formalArgs(glm_ridge) )] + # Call the submodel fitter: fit <- do.call(glm_ridge, c( list(x = x, y = y, lambda = projpred_regul, obsvar = projpred_var), dot_args )) + # Post-processing: rownames(fit$beta) <- colnames(x) sub <- nlist( alpha = fit$beta0, @@ -118,37 +123,34 @@ fit_glm_ridge_callback <- function(formula, data, # `projpred.glm_fitter`): fit_glm_callback <- function(formula, family, projpred_var, projpred_regul, ...) 
{ - tryCatch({ - if (family$family == "gaussian" && family$link == "identity") { - # Exclude arguments from `...` which cannot be passed to stats::lm(): - dot_args <- list(...) - dot_args <- dot_args[intersect( - names(dot_args), - union(methods::formalArgs(stats::lm), - union(methods::formalArgs(stats::lm.fit), - methods::formalArgs(stats::lm.wfit))) - )] - return(suppressMessages(suppressWarnings(do.call(stats::lm, c( - list(formula = formula), - dot_args - ))))) - } else { - # Exclude arguments from `...` which cannot be passed to stats::glm(): - dot_args <- list(...) - dot_args <- dot_args[intersect( - names(dot_args), - union(methods::formalArgs(stats::glm), - methods::formalArgs(stats::glm.control)) - )] - return(suppressMessages(suppressWarnings(do.call(stats::glm, c( - list(formula = formula, family = family), - dot_args - ))))) - } - }, error = function(e) { - # May be used to handle errors. - stop(e) - }) + if (family$family == "gaussian" && family$link == "identity") { + # Exclude arguments from `...` which cannot be passed to stats::lm(): + dot_args <- list(...) + dot_args <- dot_args[intersect( + names(dot_args), + c(methods::formalArgs(stats::lm), + methods::formalArgs(stats::lm.fit), + methods::formalArgs(stats::lm.wfit)) + )] + # Call the submodel fitter: + return(suppressMessages(suppressWarnings(do.call(stats::lm, c( + list(formula = formula), + dot_args + ))))) + } else { + # Exclude arguments from `...` which cannot be passed to stats::glm(): + dot_args <- list(...) + dot_args <- dot_args[intersect( + names(dot_args), + c(methods::formalArgs(stats::glm), + methods::formalArgs(stats::glm.control)) + )] + # Call the submodel fitter: + return(suppressMessages(suppressWarnings(do.call(stats::glm, c( + list(formula = formula, family = family), + dot_args + ))))) + } } # Use package "mgcv" to fit additive non-multilevel submodels: @@ -158,9 +160,10 @@ fit_gam_callback <- function(formula, ...) { dot_args <- list(...) 
dot_args <- dot_args[intersect( names(dot_args), - union(methods::formalArgs(gam), - methods::formalArgs(mgcv::gam.fit)) + c(methods::formalArgs(gam), + methods::formalArgs(mgcv::gam.fit)) )] + # Call the submodel fitter: return(suppressMessages(suppressWarnings(do.call(gam, c( list(formula = formula), dot_args @@ -176,10 +179,11 @@ fit_gamm_callback <- function(formula, projpred_formula_no_random, dot_args <- list(...) dot_args <- dot_args[intersect( names(dot_args), - union(union(methods::formalArgs(gamm4), - methods::formalArgs(lme4::lFormula)), - methods::formalArgs(lme4::glFormula)) + c(methods::formalArgs(gamm4), + methods::formalArgs(lme4::lFormula), + methods::formalArgs(lme4::glFormula)) )] + # Call the submodel fitter: fit <- tryCatch({ suppressMessages(suppressWarnings(do.call(gamm4, c( list(formula = projpred_formula_no_random, random = projpred_random, @@ -211,9 +215,7 @@ fit_gamm_callback <- function(formula, projpred_formula_no_random, return(fit) } -# Use package "lme4" to fit submodels for multilevel reference models (with a -# fallback to "projpred"'s own implementation for fitting non-multilevel (and -# non-additive) submodels): +# Use package "lme4" to fit multilevel submodels: fit_glmer_callback <- function(formula, family, control = control_callback(family), ...) 
{ tryCatch({ @@ -224,6 +226,7 @@ fit_glmer_callback <- function(formula, family, names(dot_args), methods::formalArgs(lme4::lmer) )] + # Call the submodel fitter: return(suppressMessages(suppressWarnings(do.call(lme4::lmer, c( list(formula = formula, control = control), dot_args @@ -235,6 +238,7 @@ fit_glmer_callback <- function(formula, family, names(dot_args), methods::formalArgs(lme4::glmer) )] + # Call the submodel fitter: return(suppressMessages(suppressWarnings(do.call(lme4::glmer, c( list(formula = formula, family = family, control = control), @@ -413,20 +417,19 @@ check_conv <- function(fit) { # Prediction functions for submodels -------------------------------------- subprd <- function(fits, newdata) { - return(do.call(cbind, lapply(fits, function(fit) { - # Only pass argument `allow.new.levels` to the predict() generic if the fit - # is multilevel: - has_grp <- inherits(fit, c("lmerMod", "glmerMod")) - has_add <- inherits(fit, c("gam", "gamm4")) - if (has_add && !is.null(newdata)) { + prd_list <- lapply(fits, function(fit) { + is_glmm <- inherits(fit, c("lmerMod", "glmerMod")) + is_gam_gamm <- inherits(fit, c("gam", "gamm4")) + if (is_gam_gamm && !is.null(newdata)) { newdata <- cbind(`(Intercept)` = rep(1, NROW(newdata)), newdata) } - if (!has_grp) { - return(predict(fit, newdata = newdata)) - } else { + if (is_glmm) { return(predict(fit, newdata = newdata, allow.new.levels = TRUE)) + } else { + return(predict(fit, newdata = newdata)) } - }))) + }) + return(do.call(cbind, prd_list)) } ## FIXME: find a way that allows us to remove this diff --git a/R/methods.R b/R/methods.R index a5016dc27..8b5536d5c 100644 --- a/R/methods.R +++ b/R/methods.R @@ -626,6 +626,7 @@ print.vselsummary <- function(x, digits = 1, ...) 
{ cat(paste0("Suggested Projection Size: ", x$suggested_size, "\n")) cat("\n") cat("Selection Summary:\n") + where <- "tidyselect" %:::% "where" print( x$selection %>% dplyr::mutate(dplyr::across( where(is.numeric), diff --git a/R/misc.R b/R/misc.R index 56c7f7cf3..8c3a395ca 100644 --- a/R/misc.R +++ b/R/misc.R @@ -430,6 +430,11 @@ deparse_combine <- function(x, max_char = NULL) { #' @export magrittr::`%>%` +# `R CMD check` throws a note when using <package>:::<function>() (for accessing +# <function> which is not exported by its <package>). Of course, usage of +# non-exported functions should be avoided, but sometimes there's no way around +# that. Thus, with the following helper operator, it is possible to redefine +# such functions here in projpred: `%:::%` <- function(pkg, fun) { # Note: `utils::getFromNamespace(fun, pkg)` could probably be used, too (but # its documentation is unclear about the inheritance from parent @@ -437,11 +442,6 @@ magrittr::`%>%` get(fun, envir = asNamespace(pkg), inherits = FALSE) } -# Function where() is not exported by package tidyselect, so redefine it here to -# avoid a note in R CMD check which would occur for usage of -# tidyselect:::where(): -where <- "tidyselect" %:::% "where" - # Helper function to combine separate `list`s into a single `list`: rbind2list <- function(x) { as.list(do.call(rbind, lapply(x, as.data.frame))) diff --git a/R/projpred-package.R b/R/projpred-package.R index c2f252e8c..074965f56 100644 --- a/R/projpred-package.R +++ b/R/projpred-package.R @@ -92,15 +92,15 @@ #' #' Dupuis, J. A. and Robert, C. P. (2003). Variable selection in qualitative #' models via an entropic explanatory power. *Journal of Statistical Planning -#' and Inference*, **111**(1-2):77–94. DOI: \doi{10.1016/S0378-3758(02)00286-0}. +#' and Inference*, **111**(1-2):77–94. \doi{10.1016/S0378-3758(02)00286-0}. #' #' Piironen, J. and Vehtari, A. (2017). Comparison of Bayesian predictive #' methods for model selection. *Statistics and Computing*, **27**(3):711-735. 
-#' DOI: \doi{10.1007/s11222-016-9649-y}. +#' \doi{10.1007/s11222-016-9649-y}. #' #' Piironen, J., Paasiniemi, M., and Vehtari, A. (2020). Projective inference in #' high-dimensional problems: Prediction and feature selection. *Electronic -#' Journal of Statistics*, **14**(1):2155-2197. DOI: \doi{10.1214/20-EJS1711}. +#' Journal of Statistics*, **14**(1):2155-2197. \doi{10.1214/20-EJS1711}. #' #' Catalina, A., Bürkner, P.-C., and Vehtari, A. (2020). Projection predictive #' inference for generalized linear and additive multilevel models. diff --git a/R/refmodel.R b/R/refmodel.R index 0a4915b7e..64233e6d5 100644 --- a/R/refmodel.R +++ b/R/refmodel.R @@ -18,7 +18,8 @@ #' additionally to the properties required for [init_refmodel()]. For #' non-default methods of [get_refmodel()], an object of the corresponding #' class. -#' @param data Data used for fitting the reference model. +#' @param data Data used for fitting the reference model. Any `contrasts` +#' attributes of the dataset's columns are silently removed. #' @param formula Reference model's formula. For general information on formulas #' in \R, see [`formula`]. For multilevel formulas, see also package #' \pkg{lme4} (in particular, functions [lme4::lmer()] and [lme4::glmer()]). 
@@ -664,6 +665,15 @@ init_refmodel <- function(object, data, formula, family, ref_predfun = NULL, offset <- rep(0, NROW(y)) } + # For avoiding the warning "contrasts dropped from factor <factor_name>" when + # predicting for each projected draw, e.g., for submodels fit with lm()/glm(): + has_contr <- sapply(data, function(data_col) { + !is.null(attr(data_col, "contrasts")) + }) + for (idx_col in which(has_contr)) { + attr(data[[idx_col]], "contrasts") <- NULL + } + # Functions --------------------------------------------------------------- if (proper_model) { diff --git a/man/cv_varsel.Rd b/man/cv_varsel.Rd index b38998d17..3a1ce8ea7 100644 --- a/man/cv_varsel.Rd +++ b/man/cv_varsel.Rd @@ -220,7 +220,7 @@ if (requireNamespace("rstanarm", quietly = TRUE)) { \references{ Vehtari, A., Gelman, A., and Gabry, J. (2017). Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. \emph{Statistics and -Computing}, \strong{27}(5), 1413-1432. DOI: \doi{10.1007/s11222-016-9696-4}. +Computing}, \strong{27}(5), 1413-1432. \doi{10.1007/s11222-016-9696-4}. Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2021). Pareto smoothed importance sampling. \emph{arXiv:1507.02646}. URL: diff --git a/man/mesquite.Rd b/man/mesquite.Rd index f5126a12c..cbc629e0d 100644 --- a/man/mesquite.Rd +++ b/man/mesquite.Rd @@ -31,7 +31,7 @@ The mesquite bushes yields dataset from Gelman and Hill (2006) } \references{ Gelman, A. and Hill, J. (2006). \emph{Data Analysis Using Regression -and Multilevel/Hierarchical Models}. Cambridge University Press. DOI: +and Multilevel/Hierarchical Models}. Cambridge University Press. \doi{10.1017/CBO9780511790942}. } \keyword{datasets} diff --git a/man/projpred-package.Rd b/man/projpred-package.Rd index baf3c5ed7..918739013 100644 --- a/man/projpred-package.Rd +++ b/man/projpred-package.Rd @@ -84,15 +84,15 @@ models: A Bayesian approach via Kullback–Leibler projections. \emph{Biometrika Dupuis, J. A. and Robert, C. P. (2003). 
Variable selection in qualitative models via an entropic explanatory power. \emph{Journal of Statistical Planning -and Inference}, \strong{111}(1-2):77–94. DOI: \doi{10.1016/S0378-3758(02)00286-0}. +and Inference}, \strong{111}(1-2):77–94. \doi{10.1016/S0378-3758(02)00286-0}. Piironen, J. and Vehtari, A. (2017). Comparison of Bayesian predictive methods for model selection. \emph{Statistics and Computing}, \strong{27}(3):711-735. -DOI: \doi{10.1007/s11222-016-9649-y}. +\doi{10.1007/s11222-016-9649-y}. Piironen, J., Paasiniemi, M., and Vehtari, A. (2020). Projective inference in high-dimensional problems: Prediction and feature selection. \emph{Electronic -Journal of Statistics}, \strong{14}(1):2155-2197. DOI: \doi{10.1214/20-EJS1711}. +Journal of Statistics}, \strong{14}(1):2155-2197. \doi{10.1214/20-EJS1711}. Catalina, A., Bürkner, P.-C., and Vehtari, A. (2020). Projection predictive inference for generalized linear and additive multilevel models. diff --git a/man/refmodel-init-get.Rd b/man/refmodel-init-get.Rd index 74a04a4f5..06735e9f8 100644 --- a/man/refmodel-init-get.Rd +++ b/man/refmodel-init-get.Rd @@ -61,7 +61,8 @@ function \code{\link[mgcv:gam]{mgcv::gam()}}) and \pkg{gamm4} (in particular, fu \item{family}{A \code{\link{family}} object representing the observational model (i.e., the distributional family for the response).} -\item{data}{Data used for fitting the reference model.} +\item{data}{Data used for fitting the reference model. Any \code{contrasts} +attributes of the dataset's columns are silently removed.} \item{ref_predfun}{Prediction function for the linear predictor of the reference model, including offsets (if existing). 
See also section diff --git a/tests/testthat/test_as_matrix.R b/tests/testthat/test_as_matrix.R index 27be1bbdd..c2ec141cd 100644 --- a/tests/testthat/test_as_matrix.R +++ b/tests/testthat/test_as_matrix.R @@ -174,7 +174,7 @@ test_that("as.matrix.projection() works", { print(tstsetup) print(rlang::hash(m)) # cat(m) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } @@ -276,6 +276,6 @@ if (run_snaps) { } }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } diff --git a/tests/testthat/test_datafit.R b/tests/testthat/test_datafit.R index 6382955ab..211d5a2b9 100644 --- a/tests/testthat/test_datafit.R +++ b/tests/testthat/test_datafit.R @@ -435,7 +435,7 @@ test_that(paste( print(tstsetup) print(rlang::hash(pl_with_args)) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } @@ -476,7 +476,7 @@ test_that(paste( print(tstsetup) print(rlang::hash(pp_with_args)) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } @@ -538,7 +538,7 @@ test_that(paste( print(tstsetup) print(smmry, digits = 6) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } @@ -577,7 +577,7 @@ test_that(paste( print(tstsetup) print(smmry, digits = 6) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } diff --git a/tests/testthat/test_methods_vsel.R b/tests/testthat/test_methods_vsel.R index 2b2b0021f..1cf091af4 100644 --- a/tests/testthat/test_methods_vsel.R +++ b/tests/testthat/test_methods_vsel.R @@ -119,7 +119,7 @@ test_that("`x` of class \"vselsummary\" (based on varsel()) works", { print(tstsetup) print(smmrys_vs[[tstsetup]], digits = 6) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } @@ -141,7 +141,7 @@ test_that("`x` of class \"vselsummary\" 
(based on cv_varsel()) works", { print(tstsetup) print(smmrys_cvvs[[tstsetup]], digits = 6) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } diff --git a/tests/testthat/test_parallel.R b/tests/testthat/test_parallel.R index 10aaad182..47bbd8d43 100644 --- a/tests/testthat/test_parallel.R +++ b/tests/testthat/test_parallel.R @@ -81,12 +81,12 @@ if (run_prll) { doParallel::stopImplicitCluster() } else if (dopar_backend == "doFuture") { future::plan(future::sequential) - options(doFuture.foreach.export = export_default$doFuture.foreach.export) + options(export_default) rm(export_default) } else { stop("Unrecognized `dopar_backend`.") } - options(projpred.prll_prj_trigger = trigger_default$projpred.prll_prj_trigger) + options(trigger_default) rm(trigger_default) } diff --git a/tests/testthat/test_proj_pred.R b/tests/testthat/test_proj_pred.R index 25d232b8f..a73440a8d 100644 --- a/tests/testthat/test_proj_pred.R +++ b/tests/testthat/test_proj_pred.R @@ -16,7 +16,7 @@ test_that("pl: `object` of class \"projection\" works", { print(tstsetup) print(rlang::hash(pls[[tstsetup]])) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } @@ -44,7 +44,7 @@ test_that(paste( print(tstsetup) print(rlang::hash(pls_vs[[tstsetup]])) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } @@ -72,7 +72,7 @@ test_that(paste( print(tstsetup) print(rlang::hash(pls_cvvs[[tstsetup]])) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } @@ -615,7 +615,7 @@ test_that("pp: `object` of class \"projection\" works", { print(tstsetup) print(rlang::hash(pps[[tstsetup]])) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } @@ -643,7 +643,7 @@ test_that(paste( print(tstsetup) print(rlang::hash(pps_vs[[tstsetup]])) }) - options(width = 
width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } @@ -671,7 +671,7 @@ test_that(paste( print(tstsetup) print(rlang::hash(pps_cvvs[[tstsetup]])) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } diff --git a/tests/testthat/test_refmodel.R b/tests/testthat/test_refmodel.R index d679b5e02..9bef1993d 100644 --- a/tests/testthat/test_refmodel.R +++ b/tests/testthat/test_refmodel.R @@ -203,7 +203,7 @@ test_that(paste( print(rlang::hash(predref_ynew_resp)) print(rlang::hash(predref_ynew_link)) }) - options(width = width_orig$width) + options(width_orig) if (testthat_ed_max2) local_edition(2) } } diff --git a/vignettes/projpred.Rmd b/vignettes/projpred.Rmd index 86b158a6a..99cdb6a7c 100755 --- a/vignettes/projpred.Rmd +++ b/vignettes/projpred.Rmd @@ -244,7 +244,7 @@ This PPPC shows that our final projection is able to generate predictions simila - +