Skip to content

Commit

Permalink
Merge pull request #227 from spsanderson/development
Browse files Browse the repository at this point in the history
Development
  • Loading branch information
spsanderson authored Jan 25, 2024
2 parents 15834c6 + 23a0c34 commit 5371ced
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions R/internals-make-preds-wflw.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ internal_make_wflw_predictions <- function(.model_tbl, .splits_obj){
fitted_wflw = obj |> dplyr::pull(7) |> purrr::pluck(1)

# Get rec_obj
# rec_obj <- workflows::extract_preprocessor(fitted_wflw)

# Create a safe stats::predict
safe_stats_predict <- purrr::safely(
Expand Down Expand Up @@ -133,9 +132,17 @@ internal_make_wflw_predictions <- function(.model_tbl, .splits_obj){
# Get training predictions
train_res <- fitted_wflw |>
broom::augment(new_data = rsample::training(splits_obj$splits)) |>
dplyr::mutate(.data_type = "training") |>
dplyr::select(.data_type, !!pred_col_nm) |>
purrr::set_names(c(".data_type", ".value"))
dplyr::mutate(.data_type = "training")
train_res_nms <- names(train_res)
if (".pred_class" %in% train_res_nms){
train_res <- train_res |>
dplyr::select(.data_type, .pred_class) |>
purrr::set_names(c(".data_type", ".value"))
} else {
train_res <- train_res |>
dplyr::select(.data_type, !!pred_col_nm) |>
purrr::set_names(c(".data_type", ".value"))
}

# Get actual outcome values
pred_y <- names(fitted_wflw[["pre"]][["mold"]][["outcomes"]])
Expand Down
Empty file added R/plot-regression-predictions.R
Empty file.

0 comments on commit 5371ced

Please sign in to comment.