Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collect draw-wise projection warnings and check projection convergence #478

Merged
merged 46 commits into from
Nov 22, 2023

Conversation

fweber144
Copy link
Collaborator

This PR makes projpred catch messages and warnings from the draw-wise divergence minimizers and also check their convergence (as well as possible). Previously, projpred suppressed such messages and warnings and did not check convergence (PRs #259 and #444 started/modified the convergence checker, but it has remained a "hidden"—because unfinished—feature until now).

For deactivating these two features, global options projpred.warn_prj_drawwise and projpred.check_conv have been added (see the NEWS.md entries added here).

In my opinion, especially the convergence checker is a crucial feature, see, e.g., issue #323. The messages and warnings from the draw-wise divergence minimizers are intended as a help for the user to find out what might be going wrong without having to debug.

The convergence checks for additive models are probably still incomplete, even with this PR. I'll open a new issue for this.

Illustration:

# Setup -------------------------------------------------------------------

warn_length_orig <- options(warning.length = 8170)
devtools::load_all()

# glm_ridge(), glm_elnet() as submodel fitters ----------------------------

data("df_binom", package = "projpred")
dat <- data.frame(y = df_binom$y, df_binom$x)
fit_glm <- rstanarm::stan_glm(y ~ X1 + X2 + X3,
                              family = binomial(),
                              data = dat,
                              chains = 1,
                              iter = 500,
                              seed = 1140350788,
                              refresh = 0)

# Warning from glm_ridge():
prj <- project(fit_glm, predictor_terms = c("X1"), nclusters = 1, thresh = 0)

# Warning from glm_ridge() during the refits for performance evaluation:
vs <- varsel(fit_glm, method = "L1", nclusters_pred = 2, qa_updates_max = 2)
# Alternatively (this is a different warning, though):
vs <- varsel(fit_glm, method = "L1", nclusters_pred = 2, thresh_conv = 0)

# Warning from glm_ridge() during the forward search as well as during the
# refits for performance evaluation:
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2, qa_updates_max = 2)
# Alternatively (this is a different warning, though):
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2, thresh_conv = 0)

# Warning from glm_ridge() during the forward search:
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2,
             search_control = list(qa_updates_max = 2))
# Alternatively (this is a different warning, though):
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2,
             search_control = list(thresh_conv = 0))

# Warning from glm_ridge() during the refits for performance evaluation:
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2,
             search_control = list(), qa_updates_max = 2)
# Alternatively (this is a different warning, though):
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2,
             search_control = list(), thresh_conv = 0)

# Warning from glm_elnet() during the L1 search:
vs <- varsel(fit_glm, method = "L1", refit_prj = FALSE,
             search_control = list(thresh = 1e-330, nlambda = 1))

# MASS::polr() as submodel fitter -----------------------------------------

data("inhaler", package = "brms")
inhaler$rating <- as.factor(paste0("rtg", inhaler$rating))

fit_polr <- rstanarm::stan_polr(
  rating ~ period + carry + treat,
  data = inhaler,
  prior = rstanarm::R2(location = 0.5, what = "median"),
  chains = 1,
  iter = 500,
  seed = 1140350788,
  refresh = 0
)

# Non-convergence in MASS::polr():
prj <- project(fit_polr, predictor_terms = c("carry", "treat"), nclusters = 1,
               control = list(maxit = 1))

# Teardown ----------------------------------------------------------------

options(warn_length_orig)

for checking the convergence of a single submodel fit (not of a whole `outdmin`
object).
thrown in case of global option `projpred.warn_submodel_fits` set to `TRUE`.
… fit) to a warning

(to avoid that this causes an error; the code should still run through).
…bmodel_fits` and `projpred.check_conv`

(these local arguments can be passed to top-level functions like `varsel()`,
`cv_varsel()`, and `project()`).
where tuning parameters may be found (which in turn is achieved by mentioning
the class(es) of the submodel fits).
…model to most complex model."

Reason for the revert: For example, `class(<gam_fit>)` yields `c("gam", "glm",
"lm")`, so it's indeed better to start with the most complex type of model.
…ique

`stdout()` output messages as warnings.
(that's why we already needed all those `warn_expected <- "non-integer
tests to `warn_prj_drawwise()` and `check_conv()` as well.
`fit_s$mgcv.conv$fully.converged` may be (or perhaps is always) `NULL`.
(to avoid that such a minor issue as a defective convergence checker prevents
the code from running through).
…ed.check_conv`

(in the general package documentation).
thrown if the draw-wise divergence minimizer threw only informational messages).
@fweber144 fweber144 merged commit 97c5bea into stan-dev:master Nov 22, 2023
@fweber144 fweber144 deleted the check_conv_public branch November 22, 2023 14:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant