From 08271270c8e973440421f1ced2d4396a867672e4 Mon Sep 17 00:00:00 2001
From: gregmacfarlane
Date: Thu, 18 May 2023 16:12:34 -0600
Subject: [PATCH] Allow augment to be used on newdata for mlogit

---
 R/mlogit-tidiers.R | 35 +++++++++++++++++++++++++++++------
 1 file changed, 29 insertions(+), 6 deletions(-)

diff --git a/R/mlogit-tidiers.R b/R/mlogit-tidiers.R
index 345cb7cc1..e3362f8b5 100644
--- a/R/mlogit-tidiers.R
+++ b/R/mlogit-tidiers.R
@@ -25,6 +25,11 @@
 #' augment(m)
 #' glance(m)
 #'
+#' # augment with newdata
+#' Fish2 <- Fish
+#' Fish2$price <- ifelse(Fish2$income < 3000, Fish2$price * 0.7, Fish2$price)
+#' augment(m, newdata = Fish2)
+#'
 #' @aliases mlogit_tidiers
 #' @export
 #' @family mlogit tidiers
@@ -56,23 +61,37 @@ tidy.mlogit <- function(x, conf.int = FALSE, conf.level = 0.95, ...) {
 #'
 #' @inherit tidy.mlogit params examples
 #' @param data Not currently used
+#' @param newdata Optional `dfidx` data frame to predict on. See details.
 #'
-#' @details At the moment this only works on the estimation dataset. Need to set
-#' it up to predict on another dataset.
+#' @details `newdata` must be a `dfidx` data frame identifying choice IDs and
+#'   alternatives; `.resid` is omitted from the output when it is supplied.
 #'
 #' @export
 #' @seealso [augment()]
 #' @family mlogit tidiers
 #'
 #'
-augment.mlogit <- function(x, data = x$model, ...) {
+augment.mlogit <- function(x, data = x$model, newdata = NULL, ...) {
   check_ellipses("newdata", "augment", "mlogit", ...)
+
+  # mlogit does not implement a model.matrix method, so prediction cannot
+  # simply rebuild a design matrix from new data. Instead, mlogit uses
+  # update() to create a NEW model object on `newdata`: the coefficients
+  # start at their previously estimated values and iterlim = 0 means no
+  # new maximum likelihood estimation is run, so the parameters stay
+  # constrained to those values. This does unfortunately mean that the
+  # data to be predicted has to be in a dfidx format.
+  if (!is.null(newdata)) {
+    x <- stats::update(x, start = stats::coef(x, fixed = TRUE),
+                       data = newdata, iterlim = 0, print.level = 0)
+  }
 
   # the ID variables are really messed up, so we're going to do some
   # retrofitting because this ends up being a pretty important element of
   # what we want to do with the results.
   idx <- x$model$idx
 
+  # augment
   reg <- x$model %>%
     as_augment_tibble() %>%
     dplyr::select(-idx) %>%
@@ -85,11 +104,15 @@ augment.mlogit <- function(x, data = x$model, ...) {
     # reappend the id columns
     dplyr::mutate(
       id = idx$id1,
-      alternative = idx$id2,
-      .resid = as.vector(x$residuals)
+      alternative = idx$id2
     ) %>%
     dplyr::select(id, alternative, chosen, everything())
-
+
+  # residuals are only reported for the estimation data, not for newdata
+  if (is.null(newdata)) {
+    reg$.resid <- as.vector(x$residuals)
+  }
+
   reg
 }
 