Skip to content

Commit

Permalink
Start fixing stan-dev#345.
Browse files Browse the repository at this point in the history
  • Loading branch information
fweber144 committed Sep 15, 2022
1 parent fae19ef commit 9e85dc5
Showing 1 changed file with 32 additions and 17 deletions.
49 changes: 32 additions & 17 deletions R/cv_varsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,23 @@ cv_varsel.refmodel <- function(
## search options
opt <- nlist(lambda_min_ratio, nlambda, thresh, regul)

### TODO:
candidate_terms <- search_terms
# Only a quick-and-dirty solution (perhaps we can solve this in a more elegant
# way, but in any case, we should create a helper function for this):
candidate_terms <- gsub("[[:blank:]]*\\+[[:blank:]]*", " + ",
candidate_terms)
###
candidate_terms <- setdiff(candidate_terms, "1")

if (cv_method == "LOO") {
sel_cv <- loo_varsel(
refmodel = refmodel, method = method, nterms_max = nterms_max,
ndraws = ndraws, nclusters = nclusters, ndraws_pred = ndraws_pred,
nclusters_pred = nclusters_pred, refit_prj = refit_prj, penalty = penalty,
verbose = verbose, opt = opt, nloo = nloo,
validate_search = validate_search, search_terms = search_terms, ...
validate_search = validate_search, search_terms = search_terms,
candidate_terms = candidate_terms, ...
)
} else if (cv_method == "kfold") {
sel_cv <- kfold_varsel(
Expand Down Expand Up @@ -206,10 +216,6 @@ cv_varsel.refmodel <- function(
# paths. For the column names (and therefore the order of the solution terms
# in the columns), the solution path from the full-data search is used. Note
# that the following code assumes that all CV folds have equal weight.
candidate_terms <- split_formula(refmodel$formula,
data = refmodel$fetch_data(),
add_main_effects = FALSE)
candidate_terms <- setdiff(candidate_terms, "1")
solution_terms_cv_chr <- do.call(cbind, lapply(
seq_len(NROW(sel_cv$solution_terms_cv)),
function(i) {
Expand Down Expand Up @@ -310,7 +316,8 @@ parse_args_cv_varsel <- function(refmodel, cv_method, K, validate_search) {
loo_varsel <- function(refmodel, method, nterms_max, ndraws,
nclusters, ndraws_pred, nclusters_pred, refit_prj,
penalty, verbose, opt, nloo = NULL,
validate_search = TRUE, search_terms = NULL, ...) {
validate_search = TRUE, search_terms = NULL,
candidate_terms, ...) {
##
## Perform the validation of the searching process using LOO. validate_search
## indicates whether the selection is performed separately for each fold (for
Expand Down Expand Up @@ -364,7 +371,8 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
inds <- validset$inds

## initialize objects where to store the results
solution_terms_mat <- matrix(nrow = n, ncol = nterms_max - 1)
prv_len_soltrms <- NULL
solution_terms_mat <- matrix(nrow = n, ncol = length(candidate_terms))
loo_sub <- replicate(nterms_max, rep(NA, n), simplify = FALSE)
mu_sub <- replicate(nterms_max, rep(NA, n), simplify = FALSE)

Expand Down Expand Up @@ -441,13 +449,14 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
}
}

candidate_terms <- split_formula(refmodel$formula,
data = refmodel$fetch_data(),
add_main_effects = FALSE)
prv_len_soltrms <- length(search_path$solution_terms)
## with `match` we get the indices of the variables as they enter the
## solution path in `search_path$solution_terms`
solution <- match(search_path$solution_terms,
setdiff(candidate_terms, "1"))
### TODO:
# Need to adapt `search_path$solution_terms` so that the following match()
# call works.
###
solution <- match(search_path$solution_terms, candidate_terms)
for (i in seq_len(n)) {
solution_terms_mat[i, seq_along(solution)] <- solution
}
Expand Down Expand Up @@ -500,13 +509,18 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
mu_sub[[k]][i] <- summaries_sub[[k]]$mu
}

candidate_terms <- split_formula(refmodel$formula,
data = refmodel$fetch_data(),
add_main_effects = FALSE)
if (!is.null(prv_len_soltrms)) {
stopifnot(identical(length(search_path$solution_terms),
prv_len_soltrms))
}
prv_len_soltrms <- length(search_path$solution_terms)
## with `match` we get the indices of the variables as they enter the
## solution path in `search_path$solution_terms`
solution <- match(search_path$solution_terms,
setdiff(candidate_terms, "1"))
### TODO:
# Need to adapt `search_path$solution_terms` so that the following match()
# call works.
###
solution <- match(search_path$solution_terms, candidate_terms)
solution_terms_mat[i, seq_along(solution)] <- solution

if (verbose) {
Expand All @@ -530,6 +544,7 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
d_test <- list(type = "LOO", data = NULL, offset = refmodel$offset,
weights = refmodel$wobs, y = refmodel$y)

solution_terms_mat <- solution_terms_mat[, seq_along(solution), drop = FALSE]
out_list <- nlist(solution_terms_cv = solution_terms_mat, summaries, d_test)
if (!validate_search) {
out_list <- c(out_list, nlist(sel))
Expand Down

0 comments on commit 9e85dc5

Please sign in to comment.