diff --git a/R/internals-make-preds-wflw.R b/R/internals-make-preds-wflw.R index 9b91bd1..f99deb1 100644 --- a/R/internals-make-preds-wflw.R +++ b/R/internals-make-preds-wflw.R @@ -101,7 +101,6 @@ internal_make_wflw_predictions <- function(.model_tbl, .splits_obj){ fitted_wflw = obj |> dplyr::pull(7) |> purrr::pluck(1) # Get rec_obj - # rec_obj <- workflows::extract_preprocessor(fitted_wflw) # Create a safe stats::predict safe_stats_predict <- purrr::safely( @@ -133,9 +132,17 @@ internal_make_wflw_predictions <- function(.model_tbl, .splits_obj){ # Get training predictions train_res <- fitted_wflw |> broom::augment(new_data = rsample::training(splits_obj$splits)) |> - dplyr::mutate(.data_type = "training") |> - dplyr::select(.data_type, !!pred_col_nm) |> - purrr::set_names(c(".data_type", ".value")) + dplyr::mutate(.data_type = "training") + train_res_nms <- names(train_res) + if (".pred_class" %in% train_res_nms){ + train_res <- train_res |> + dplyr::select(.data_type, .pred_class) |> + purrr::set_names(c(".data_type", ".value")) + } else { + train_res <- train_res |> + dplyr::select(.data_type, !!pred_col_nm) |> + purrr::set_names(c(".data_type", ".value")) + } # Get actual outcome values pred_y <- names(fitted_wflw[["pre"]][["mold"]][["outcomes"]]) diff --git a/R/plot-regression-predictions.R b/R/plot-regression-predictions.R new file mode 100644 index 0000000..e69de29