From 6a79e9f70a1aca31ce212d523d84ad4f6a115ba5 Mon Sep 17 00:00:00 2001 From: "Steven Paul Sanderson II, MPH" Date: Wed, 24 Jan 2024 22:35:46 -0500 Subject: [PATCH 1/2] Update internals-make-preds-wflw.R minor fix for when model is classification --- R/internals-make-preds-wflw.R | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/R/internals-make-preds-wflw.R b/R/internals-make-preds-wflw.R index 9b91bd1..f99deb1 100644 --- a/R/internals-make-preds-wflw.R +++ b/R/internals-make-preds-wflw.R @@ -101,7 +101,6 @@ internal_make_wflw_predictions <- function(.model_tbl, .splits_obj){ fitted_wflw = obj |> dplyr::pull(7) |> purrr::pluck(1) # Get rec_obj - # rec_obj <- workflows::extract_preprocessor(fitted_wflw) # Create a safe stats::predict safe_stats_predict <- purrr::safely( @@ -133,9 +132,17 @@ internal_make_wflw_predictions <- function(.model_tbl, .splits_obj){ # Get training predictions train_res <- fitted_wflw |> broom::augment(new_data = rsample::training(splits_obj$splits)) |> - dplyr::mutate(.data_type = "training") |> - dplyr::select(.data_type, !!pred_col_nm) |> - purrr::set_names(c(".data_type", ".value")) + dplyr::mutate(.data_type = "training") + train_res_nms <- names(train_res) + if (".pred_class" %in% train_res_nms){ + train_res <- train_res |> + dplyr::select(.data_type, .pred_class) |> + purrr::set_names(c(".data_type", ".value")) + } else { + train_res <- train_res |> + dplyr::select(.data_type, !!pred_col_nm) |> + purrr::set_names(c(".data_type", ".value")) + } # Get actual outcome values pred_y <- names(fitted_wflw[["pre"]][["mold"]][["outcomes"]]) From 23a0c343646f45f8bee8a8186484274a9a5fd4db Mon Sep 17 00:00:00 2001 From: "Steven Paul Sanderson II, MPH" Date: Wed, 24 Jan 2024 22:35:53 -0500 Subject: [PATCH 2/2] Create plot-regression-predictions.R start file --- R/plot-regression-predictions.R | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 R/plot-regression-predictions.R diff --git a/R/plot-regression-predictions.R b/R/plot-regression-predictions.R new file mode 100644 index 0000000..e69de29