Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix contrasts issue #283 and minor other changes #284

Merged
merged 9 commits into from
Mar 3, 2022
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
* Argument `fit` of `init_refmodel()`'s argument `proj_predfun` was renamed to `fits`. This is a non-breaking change since all calls to `proj_predfun` in **projpred** have that argument unnamed. However, this cannot be guaranteed in the future, so we strongly encourage users with a custom `proj_predfun` to rename argument `fit` to `fits`. (GitHub: #263)
* `init_refmodel()` has gained argument `cvrefbuilder` which may be a custom function for constructing the K reference models in a K-fold CV. (GitHub: #271)
* Allow arguments to be passed from `project()`, `varsel()`, and `cv_varsel()` to the divergence minimizer. (GitHub: #278)
* In `init_refmodel()`, any `contrasts` attributes of the dataset's columns are silently removed. (GitHub: #284)

## Bug fixes

Expand Down
2 changes: 1 addition & 1 deletion R/cv_varsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
#'
#' Vehtari, A., Gelman, A., and Gabry, J. (2017). Practical Bayesian model
#' evaluation using leave-one-out cross-validation and WAIC. *Statistics and
#' Computing*, **27**(5), 1413-1432. DOI: \doi{10.1007/s11222-016-9696-4}.
#' Computing*, **27**(5), 1413-1432. \doi{10.1007/s11222-016-9696-4}.
#'
#' Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2021). Pareto
#' smoothed importance sampling. *arXiv:1507.02646*. URL:
Expand Down
2 changes: 1 addition & 1 deletion R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#' group).}
#' }
#' @references Gelman, A. and Hill, J. (2006). *Data Analysis Using Regression
#' and Multilevel/Hierarchical Models*. Cambridge University Press. DOI:
#' and Multilevel/Hierarchical Models*. Cambridge University Press.
#' \doi{10.1017/CBO9780511790942}.
#' @source <http://www.stat.columbia.edu/~gelman/arm/examples/mesquite/mesquite.dat>
"mesquite"
101 changes: 52 additions & 49 deletions R/divergence_minimizers.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,12 @@ divmin <- function(formula, projpred_var, ...) {
}
}

# Use projpred's own implementation to fit non-multilevel non-additive
# submodels:
fit_glm_ridge_callback <- function(formula, data,
projpred_var = matrix(nrow = nrow(data)),
projpred_regul = 1e-4, ...) {
# Preparations:
fr <- model.frame(delete.intercept(formula), data = data)
contrasts_arg <- get_contrasts_arg_list(formula, data = data)
x <- model.matrix(fr, data = data, contrasts.arg = contrasts_arg)
Expand All @@ -98,10 +101,12 @@ fit_glm_ridge_callback <- function(formula, data,
names(dot_args),
methods::formalArgs(glm_ridge)
)]
# Call the submodel fitter:
fit <- do.call(glm_ridge, c(
list(x = x, y = y, lambda = projpred_regul, obsvar = projpred_var),
dot_args
))
# Post-processing:
rownames(fit$beta) <- colnames(x)
sub <- nlist(
alpha = fit$beta0,
Expand All @@ -118,37 +123,34 @@ fit_glm_ridge_callback <- function(formula, data,
# `projpred.glm_fitter`):
fit_glm_callback <- function(formula, family, projpred_var, projpred_regul,
...) {
tryCatch({
if (family$family == "gaussian" && family$link == "identity") {
# Exclude arguments from `...` which cannot be passed to stats::lm():
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
union(methods::formalArgs(stats::lm),
union(methods::formalArgs(stats::lm.fit),
methods::formalArgs(stats::lm.wfit)))
)]
return(suppressMessages(suppressWarnings(do.call(stats::lm, c(
list(formula = formula),
dot_args
)))))
} else {
# Exclude arguments from `...` which cannot be passed to stats::glm():
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
union(methods::formalArgs(stats::glm),
methods::formalArgs(stats::glm.control))
)]
return(suppressMessages(suppressWarnings(do.call(stats::glm, c(
list(formula = formula, family = family),
dot_args
)))))
}
}, error = function(e) {
# May be used to handle errors.
stop(e)
})
if (family$family == "gaussian" && family$link == "identity") {
# Exclude arguments from `...` which cannot be passed to stats::lm():
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
c(methods::formalArgs(stats::lm),
methods::formalArgs(stats::lm.fit),
methods::formalArgs(stats::lm.wfit))
)]
# Call the submodel fitter:
return(suppressMessages(suppressWarnings(do.call(stats::lm, c(
list(formula = formula),
dot_args
)))))
} else {
# Exclude arguments from `...` which cannot be passed to stats::glm():
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
c(methods::formalArgs(stats::glm),
methods::formalArgs(stats::glm.control))
)]
# Call the submodel fitter:
return(suppressMessages(suppressWarnings(do.call(stats::glm, c(
list(formula = formula, family = family),
dot_args
)))))
}
}

# Use package "mgcv" to fit additive non-multilevel submodels:
Expand All @@ -158,9 +160,10 @@ fit_gam_callback <- function(formula, ...) {
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
union(methods::formalArgs(gam),
methods::formalArgs(mgcv::gam.fit))
c(methods::formalArgs(gam),
methods::formalArgs(mgcv::gam.fit))
)]
# Call the submodel fitter:
return(suppressMessages(suppressWarnings(do.call(gam, c(
list(formula = formula),
dot_args
Expand All @@ -176,10 +179,11 @@ fit_gamm_callback <- function(formula, projpred_formula_no_random,
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
union(union(methods::formalArgs(gamm4),
methods::formalArgs(lme4::lFormula)),
methods::formalArgs(lme4::glFormula))
c(methods::formalArgs(gamm4),
methods::formalArgs(lme4::lFormula),
methods::formalArgs(lme4::glFormula))
)]
# Call the submodel fitter:
fit <- tryCatch({
suppressMessages(suppressWarnings(do.call(gamm4, c(
list(formula = projpred_formula_no_random, random = projpred_random,
Expand Down Expand Up @@ -211,9 +215,7 @@ fit_gamm_callback <- function(formula, projpred_formula_no_random,
return(fit)
}

# Use package "lme4" to fit submodels for multilevel reference models (with a
# fallback to "projpred"'s own implementation for fitting non-multilevel (and
# non-additive) submodels):
# Use package "lme4" to fit multilevel submodels:
fit_glmer_callback <- function(formula, family,
control = control_callback(family), ...) {
tryCatch({
Expand All @@ -224,6 +226,7 @@ fit_glmer_callback <- function(formula, family,
names(dot_args),
methods::formalArgs(lme4::lmer)
)]
# Call the submodel fitter:
return(suppressMessages(suppressWarnings(do.call(lme4::lmer, c(
list(formula = formula, control = control),
dot_args
Expand All @@ -235,6 +238,7 @@ fit_glmer_callback <- function(formula, family,
names(dot_args),
methods::formalArgs(lme4::glmer)
)]
# Call the submodel fitter:
return(suppressMessages(suppressWarnings(do.call(lme4::glmer, c(
list(formula = formula, family = family,
control = control),
Expand Down Expand Up @@ -413,20 +417,19 @@ check_conv <- function(fit) {
# Prediction functions for submodels --------------------------------------

subprd <- function(fits, newdata) {
return(do.call(cbind, lapply(fits, function(fit) {
# Only pass argument `allow.new.levels` to the predict() generic if the fit
# is multilevel:
has_grp <- inherits(fit, c("lmerMod", "glmerMod"))
has_add <- inherits(fit, c("gam", "gamm4"))
if (has_add && !is.null(newdata)) {
prd_list <- lapply(fits, function(fit) {
is_glmm <- inherits(fit, c("lmerMod", "glmerMod"))
is_gam_gamm <- inherits(fit, c("gam", "gamm4"))
if (is_gam_gamm && !is.null(newdata)) {
newdata <- cbind(`(Intercept)` = rep(1, NROW(newdata)), newdata)
}
if (!has_grp) {
return(predict(fit, newdata = newdata))
} else {
if (is_glmm) {
return(predict(fit, newdata = newdata, allow.new.levels = TRUE))
} else {
return(predict(fit, newdata = newdata))
}
})))
})
return(do.call(cbind, prd_list))
}

## FIXME: find a way that allows us to remove this
Expand Down
1 change: 1 addition & 0 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ print.vselsummary <- function(x, digits = 1, ...) {
cat(paste0("Suggested Projection Size: ", x$suggested_size, "\n"))
cat("\n")
cat("Selection Summary:\n")
where <- "tidyselect" %:::% "where"
print(
x$selection %>% dplyr::mutate(dplyr::across(
where(is.numeric),
Expand Down
10 changes: 5 additions & 5 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -430,18 +430,18 @@ deparse_combine <- function(x, max_char = NULL) {
#' @export
magrittr::`%>%`

# `R CMD check` throws a note when using <package>:::<function>() (for accessing
# <function> which is not exported by its <package>). Of course, usage of
# non-exported functions should be avoided, but sometimes there's no way around
# that. Thus, with the following helper operator, it is possible to redefine
# such functions here in projpred:
`%:::%` <- function(pkg, fun) {
# Note: `utils::getFromNamespace(fun, pkg)` could probably be used, too (but
# its documentation is unclear about the inheritance from parent
# environments).
get(fun, envir = asNamespace(pkg), inherits = FALSE)
}

# Function where() is not exported by package tidyselect, so redefine it here to
# avoid a note in R CMD check which would occur for usage of
# tidyselect:::where():
where <- "tidyselect" %:::% "where"

# Helper function to combine separate `list`s into a single `list`:
rbind2list <- function(x) {
as.list(do.call(rbind, lapply(x, as.data.frame)))
Expand Down
6 changes: 3 additions & 3 deletions R/projpred-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,15 @@
#'
#' Dupuis, J. A. and Robert, C. P. (2003). Variable selection in qualitative
#' models via an entropic explanatory power. *Journal of Statistical Planning
#' and Inference*, **111**(1-2):77–94. DOI: \doi{10.1016/S0378-3758(02)00286-0}.
#' and Inference*, **111**(1-2):77–94. \doi{10.1016/S0378-3758(02)00286-0}.
#'
#' Piironen, J. and Vehtari, A. (2017). Comparison of Bayesian predictive
#' methods for model selection. *Statistics and Computing*, **27**(3):711-735.
#' DOI: \doi{10.1007/s11222-016-9649-y}.
#' \doi{10.1007/s11222-016-9649-y}.
#'
#' Piironen, J., Paasiniemi, M., and Vehtari, A. (2020). Projective inference in
#' high-dimensional problems: Prediction and feature selection. *Electronic
#' Journal of Statistics*, **14**(1):2155-2197. DOI: \doi{10.1214/20-EJS1711}.
#' Journal of Statistics*, **14**(1):2155-2197. \doi{10.1214/20-EJS1711}.
#'
#' Catalina, A., Bürkner, P.-C., and Vehtari, A. (2020). Projection predictive
#' inference for generalized linear and additive multilevel models.
Expand Down
12 changes: 11 additions & 1 deletion R/refmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
#' additionally to the properties required for [init_refmodel()]. For
#' non-default methods of [get_refmodel()], an object of the corresponding
#' class.
#' @param data Data used for fitting the reference model.
#' @param data Data used for fitting the reference model. Any `contrasts`
#' attributes of the dataset's columns are silently removed.
#' @param formula Reference model's formula. For general information on formulas
#' in \R, see [`formula`]. For multilevel formulas, see also package
#' \pkg{lme4} (in particular, functions [lme4::lmer()] and [lme4::glmer()]).
Expand Down Expand Up @@ -664,6 +665,15 @@ init_refmodel <- function(object, data, formula, family, ref_predfun = NULL,
offset <- rep(0, NROW(y))
}

# For avoiding the warning "contrasts dropped from factor <factor_name>" when
# predicting for each projected draw, e.g., for submodels fit with lm()/glm():
has_contr <- sapply(data, function(data_col) {
!is.null(attr(data_col, "contrasts"))
})
for (idx_col in which(has_contr)) {
attr(data[[idx_col]], "contrasts") <- NULL
}

# Functions ---------------------------------------------------------------

if (proper_model) {
Expand Down
2 changes: 1 addition & 1 deletion man/cv_varsel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mesquite.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/projpred-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/refmodel-init-get.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test_as_matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ test_that("as.matrix.projection() works", {
print(tstsetup)
print(rlang::hash(m)) # cat(m)
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down Expand Up @@ -276,6 +276,6 @@ if (run_snaps) {
}
})

options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
8 changes: 4 additions & 4 deletions tests/testthat/test_datafit.R
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ test_that(paste(
print(tstsetup)
print(rlang::hash(pl_with_args))
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down Expand Up @@ -476,7 +476,7 @@ test_that(paste(
print(tstsetup)
print(rlang::hash(pp_with_args))
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down Expand Up @@ -538,7 +538,7 @@ test_that(paste(
print(tstsetup)
print(smmry, digits = 6)
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down Expand Up @@ -577,7 +577,7 @@ test_that(paste(
print(tstsetup)
print(smmry, digits = 6)
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_methods_vsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ test_that("`x` of class \"vselsummary\" (based on varsel()) works", {
print(tstsetup)
print(smmrys_vs[[tstsetup]], digits = 6)
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand All @@ -141,7 +141,7 @@ test_that("`x` of class \"vselsummary\" (based on cv_varsel()) works", {
print(tstsetup)
print(smmrys_cvvs[[tstsetup]], digits = 6)
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down
Loading