Skip to content

Commit

Permalink
Start to address stan-dev#402 by throwing a warning if `proj_predict(…
Browse files Browse the repository at this point in the history
…)` is

used with observation weights that are not all equal to `1`.
  • Loading branch information
fweber144 committed Apr 3, 2023
1 parent 126249b commit f3bb169
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 10 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ If you read this from a place other than <https://mc-stan.org/projpred/news/inde
* Output element `pct_solution_terms_cv` has now also been added to `vsel` objects returned by `varsel()`, but in that case, it is simply `NULL`. This (`pct_solution_terms_cv` being `NULL`) is now also the case if `validate_search = FALSE` was used in `cv_varsel()`.
* Minor enhancements in the documentation.
* Enhancements in the vignettes. In particular, section ["Troubleshooting"](https://mc-stan.org/projpred/articles/projpred.html#troubleshooting) of the main vignette has been revised.
* If `proj_predict()` is used with observation weights that are not all equal to `1`, a warning is now thrown. (GitHub: starts to address #402)

## Bug fixes

Expand Down
10 changes: 1 addition & 9 deletions R/augdat.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,18 +250,10 @@ ppd_cats <- function(mu_arr, margin_draws = 3, wobs = 1, return_vec = FALSE) {
margin_cats <- 2
bind_fun <- rbind
}
### Currently unused:
# if (length(wobs) == 0) {
# wobs <- rep(1, length(y))
# } else if (length(wobs) == 1) {
# wobs <- rep(wobs, length(y))
# } else if (length(wobs) != length(y)) {
# stop("Argument `wobs` needs to be of length 0, 1, or `length(y)`.")
# }
###
n_draws <- dim(mu_arr)[margin_draws]
n_obs <- dim(mu_arr)[margin_obs]
n_cat <- dim(mu_arr)[margin_cats]
wobs <- parse_wobs_ppd(wobs, n_obs = n_obs)
if (return_vec) {
stopifnot(n_draws == 1)
bind_fun <- c
Expand Down
4 changes: 4 additions & 0 deletions R/extend_family.R
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ extend_family_poisson <- function(family) {
-2 * dpois_log_reduced(x = y, lamb = mu, wobs = weights)
}
ppd_poiss <- function(mu, dis, weights = 1) {
weights <- parse_wobs_ppd(weights, n_obs = length(mu))
rpois(length(mu), mu)
}

Expand Down Expand Up @@ -592,6 +593,7 @@ extend_family_gaussian <- function(family) {
-2 * weights * (-0.5 / dis^2 * (y - mu)^2 - log(dis))
}
ppd_gauss <- function(mu, dis, weights = 1) {
weights <- parse_wobs_ppd(weights, n_obs = length(mu))
rnorm(length(mu), mu, dis)
}

Expand Down Expand Up @@ -639,6 +641,7 @@ extend_family_gamma <- function(family) {
## weights*dgamma(y, dis, dis/matrix(mu), log= TRUE)
}
ppd_gamma <- function(mu, dis, weights = 1) {
weights <- parse_wobs_ppd(weights, n_obs = length(mu))
rgamma(length(mu), dis, dis / mu)
}

Expand Down Expand Up @@ -694,6 +697,7 @@ extend_family_student_t <- function(family) {
* log(1 + 1 / family$nu * ((y - mu) / dis)^2) - log(dis)))
}
ppd_student_t <- function(mu, dis, weights = 1) {
weights <- parse_wobs_ppd(weights, n_obs = length(mu))
rt(length(mu), family$nu) * dis + mu
}

Expand Down
1 change: 1 addition & 0 deletions R/latent.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ latent_ll_oscale_poiss <- function(ilpreds, y_oscale,
latent_ppd_oscale_poiss <- function(ilpreds_resamp, wobs, cl_ref,
wdraws_ref = rep(1, length(cl_ref)),
idxs_prjdraws) {
wobs <- parse_wobs_ppd(wobs, n_obs = ncol(ilpreds_resamp))
ppd <- rpois(prod(dim(ilpreds_resamp)), lambda = ilpreds_resamp)
ppd <- matrix(ppd, nrow = nrow(ilpreds_resamp), ncol = ncol(ilpreds_resamp))
return(ppd)
Expand Down
5 changes: 4 additions & 1 deletion R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@
#' @param ... Arguments passed to [project()] if `object` is not already an
#' object returned by [project()].
#'
#' @details In case of the latent projection and `transform = FALSE`:
#' @details Currently, [proj_predict()] ignores observation weights that are not
#' equal to `1`. A corresponding warning is thrown if this is the case.
#'
#' In case of the latent projection and `transform = FALSE`:
#' * Output element `pred` contains the linear predictors without any
#' modifications that may be due to the original response distribution (e.g.,
#' for a [brms::cumulative()] model, the ordered thresholds are not taken into
Expand Down
18 changes: 18 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -579,3 +579,21 @@ verb_out <- function(..., verbose = TRUE) {
cat(..., "\n", sep = "")
}
}

# Parse the argument containing the observation weights (`wobs` or `weights`)
# for the <family_object>$ppd() functions used by proj_predict():
parse_wobs_ppd <- function(wobs, n_obs) {
if (length(wobs) == 0) {
wobs <- rep(1, n_obs)
} else if (length(wobs) == 1) {
wobs <- rep(wobs, n_obs)
} else if (length(wobs) != n_obs) {
stop("Argument `wobs` needs to be of length 0, 1, or the number of ",
"observations.")
}
if (!all(wobs == 1) && getOption("projpred.warn_wobs_ppd", TRUE)) {
warning("Currently, proj_predict() ignores observation weights not equal ",
"to `1`.")
}
return(wobs)
}
3 changes: 3 additions & 0 deletions man/pred-projection.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,9 @@ meth_tst <- list(
# Suppress the warning for interaction terms being selected before all involved
# main effects have been selected (only concerns L1 search):
options(projpred.warn_L1_interactions = FALSE)
# Suppress the warning thrown by proj_predict() in case of observation weights
# that are not all equal to `1`:
options(projpred.warn_wobs_ppd = FALSE)

search_trms_tst <- list(
default_search_trms = list(),
Expand Down

0 comments on commit f3bb169

Please sign in to comment.