Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use {marginaleffects} as default backend #342

Merged
merged 44 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
209c8c4
Use {marginaleffects} as default backend
strengejacke Jan 15, 2025
632fe79
Merge branch 'main' into strengejacke/issue341
strengejacke Jan 15, 2025
5b2a606
re-arrange tests
strengejacke Jan 15, 2025
2ec373b
Merge branch 'main' into strengejacke/issue341
strengejacke Jan 15, 2025
00f9047
update desc
strengejacke Jan 15, 2025
5cb88d0
tests
strengejacke Jan 15, 2025
493b60f
check all tests are also run with marginaleffects
strengejacke Jan 15, 2025
203a2ad
update snapshots
strengejacke Jan 15, 2025
4de6ac3
test
strengejacke Jan 15, 2025
6a194b3
Merge branch 'main' into strengejacke/issue341
strengejacke Jan 16, 2025
496f053
Merge branch 'main' into strengejacke/issue341
strengejacke Jan 16, 2025
9ff890e
Merge branch 'main' into strengejacke/issue341
strengejacke Jan 16, 2025
7ca9673
add tests
strengejacke Jan 16, 2025
9f86c6a
fix
strengejacke Jan 16, 2025
2303837
minor fixes
strengejacke Jan 17, 2025
4d3d17f
docs
strengejacke Jan 17, 2025
5b734c7
minor
strengejacke Jan 17, 2025
67d59b7
fix
strengejacke Jan 17, 2025
b4c2b1b
fixes, lintr
strengejacke Jan 17, 2025
bd97747
Fix #347 (#349)
DominiqueMakowski Jan 17, 2025
a0dc44e
adding test for #347
strengejacke Jan 17, 2025
b9849b3
styler
strengejacke Jan 17, 2025
baeb64d
switch
strengejacke Jan 17, 2025
0636390
typo
strengejacke Jan 17, 2025
192af6f
add test
strengejacke Jan 17, 2025
3133967
fix
strengejacke Jan 17, 2025
ed94e87
Update test-visualisation_recipe.R
strengejacke Jan 17, 2025
63f0f46
Update test-estimate_means.R
strengejacke Jan 17, 2025
f982367
Update test-estimate_contrasts.R
strengejacke Jan 17, 2025
be330c2
Update format.R
strengejacke Jan 17, 2025
1dc336b
Update format.R
strengejacke Jan 17, 2025
1c8d43c
Update estimate_contrasts.md
strengejacke Jan 17, 2025
5161c4b
Update test-estimate_means.R
strengejacke Jan 17, 2025
04fc2f7
Update format.R
strengejacke Jan 17, 2025
6e3301a
Update format.R
strengejacke Jan 17, 2025
008040a
D'oh!
strengejacke Jan 17, 2025
5b18ee4
Update test-glmmTMB.R
strengejacke Jan 17, 2025
8e7094d
Update format.R
strengejacke Jan 17, 2025
1504022
Update test-attributes_estimatefun.R
strengejacke Jan 17, 2025
7ae6a88
Update test-estimate_contrasts.R
strengejacke Jan 17, 2025
99d5ca2
update
strengejacke Jan 17, 2025
e85269f
Test on dev-version of marginaleffects
strengejacke Jan 17, 2025
8f12072
fix
strengejacke Jan 17, 2025
3a61673
finally
strengejacke Jan 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: modelbased
Title: Estimation of Model-Based Predictions, Contrasts and Means
Version: 0.8.9.31
Version: 0.8.9.32
Authors@R:
c(person(given = "Dominique",
family = "Makowski",
Expand Down Expand Up @@ -89,4 +89,4 @@ Config/testthat/parallel: true
Roxygen: list(markdown = TRUE)
Config/Needs/check: stan-dev/cmdstanr
Config/Needs/website: easystats/easystatstemplate
Remotes: easystats/bayestestR
Remotes: easystats/bayestestR, vincentarelbundock/marginaleffects
26 changes: 20 additions & 6 deletions R/estimate_means.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
#' [estimate_slopes()].
#'
#' @param model A statistical model.
#' @param by The predictor variable(s) at which to evaluate the desired effect
#' / mean / contrasts. Other predictors of the model that are not included
#' here will be collapsed and "averaged" over (the effect will be estimated
#' across them).
#' @param by The (focal) predictor variable(s) at which to evaluate the desired
#' effect / mean / contrasts. Other predictors of the model that are not
#' included here will be collapsed and "averaged" over (the effect will be
#' estimated across them). `by` can a character (vector) naming the focal
#' predictors (and optionally, representative values or levels), or a list of named
#' elements. See details in [`insight::get_datagrid()`].
#' @param predict Is passed to the `type` argument in `emmeans::emmeans()` (when
#' `backend = "emmeans"`) or in `marginaleffects::avg_predictions()` (when
#' `backend = "marginaleffects"`). For emmeans, see also
Expand Down Expand Up @@ -77,8 +79,20 @@
#' as default backend.
#' @param transform Deprecated, please use `predict` instead.
#' @param verbose Use `FALSE` to silence messages and warnings.
#' @param ... Other arguments passed, for instance, to [insight::get_datagrid()]
#' or functions from the **emmeans** or **marginaleffects** package.
#' @param ... Other arguments passed, for instance, to [insight::get_datagrid()],
#' to functions from the **emmeans** or **marginaleffects** package, or to process
#' Bayesian models via [bayestestR::describe_posterior()]. Examples:
#' - `insight::get_datagrid()`: Argument such as `length` or `range` can be used
#' to control the (number of) representative values.
#' - **marginaleffects**: Internally used functions are `avg_predictions()` for
#' means and contrasts, and `avg_slope()` for slopes. Therefore, arguments
#' for instance like `vcov`, `transform`, `equivalence` or `slope` can be
#' passed to those functions.
#' - **emmeans**: Internally used functions are `emmeans()` and `emtrends()`.
#' Additional arguments can be passed to these functions.
#' - Bayesian models: For Bayesian models, parameters are cleaned using
#' `describe_posterior()`, thus, arguments like, for example, `centrality`,
#' `rope_range`, or `test` are passed to that function.
#'
#' @inheritParams parameters::model_parameters.default
#' @inheritParams estimate_expectation
Expand Down
6 changes: 5 additions & 1 deletion R/estimate_slopes.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,12 @@ estimate_slopes <- function(model,
info <- attributes(estimated)

# Table formatting
table_footer <- paste("\nMarginal effects estimated for", info$trend)
if (!is.null(attributes(trends)$slope)) {
table_footer <- paste0(table_footer, "\nType of slope was ", attributes(trends)$slope)
}
attr(trends, "table_title") <- c("Estimated Marginal Effects", "blue")
attr(trends, "table_footer") <- c(paste("Marginal effects estimated for", info$trend), "blue")
attr(trends, "table_footer") <- c(table_footer, "blue")

# Add attributes
attr(trends, "model") <- model
Expand Down
42 changes: 41 additions & 1 deletion R/format.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
}

# arrange columns (not for contrast now)
by <- rev(attr(x, "focal_terms", exact = TRUE))

Check warning on line 13 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=13,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 13 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=13,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
if (!is.null(by) && all(by %in% colnames(x))) {
# arrange predictions
x <- datawizard::data_arrange(x, select = by)
Expand Down Expand Up @@ -114,6 +114,13 @@
if ("term" %in% colnames(x) && insight::n_unique(x$term) == 1) {
remove_columns <- c("Parameter", remove_columns)
}
# there are some exceptions for `estimate_slope()`, when the `Comparison`
# column contains information about the type of slope (dx/dy etc.). we want
# to remove this here, but add information as attribute.
if ("contrast" %in% colnames(x) && all(x$contrast %in% .marginaleffects_slopes())) {
remove_columns <- c("Comparison", "contrast", remove_columns)
attr(x, "slope") <- unique(x$contrast)
}
# reshape and format columns
params <- .standardize_marginaleffects_columns(
x,
Expand All @@ -132,10 +139,11 @@

#' @export
format.marginaleffects_contrasts <- function(x, model, p_adjust, comparison, ...) {
predict <- attributes(x)$predict

Check warning on line 142 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=142,col=3,[object_overwrite_linter] 'predict' is an exported object from package 'stats'. Avoid re-using such symbols.

Check warning on line 142 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=142,col=3,[object_overwrite_linter] 'predict' is an exported object from package 'stats'. Avoid re-using such symbols.
by <- attributes(x)$by

Check warning on line 143 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=143,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 143 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=143,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
contrast <- attributes(x)$contrast
focal_terms <- attributes(x)$focal_terms
dgrid <- attributes(x)$datagrid

# clean "by" and contrast variable names, for the special cases. for example,
# if we have `by = "name [fivenum]"`, we just want "name"
Expand Down Expand Up @@ -163,7 +171,7 @@
if (!is.null(comparison) && is.character(comparison) && comparison %in% valid_options) {
# the goal here is to create tidy columns with the comparisons.
# marginaleffects returns a single column that contains all levels that
# are contrastet. We want to have the contrasted levels per predictor in
# are contrasted. We want to have the contrasted levels per predictor in
# a separate column. This is what we do here...

# split parameter column into comparison groups.
Expand All @@ -172,6 +180,38 @@
lapply(x$Parameter, .split_at_minus_outside_parentheses)
))

# When we filter contrasts, e.g. `contrast = c("vs", "am='1'")` or
# `contrast = c("vs", "am"), by = "gear='5'"`, we get no contrasts if one
# of the focal terms only has one unique value in the data grid. Thus,
# we need to exclude all those focal terms that only have one unique value
# in the data grid now. Fingers crossed that it works...
focal_terms <- focal_terms[lengths(lapply(dgrid[focal_terms], unique)) > 1]

# in the second example, `contrast = c("vs", "am"), by = "gear='5'"`, the
# `by` column is the one with one unique value only, we thus have to update
# `by` as well, and also `contrast` (the latter not(!) for numerics)...
by <- by[lengths(lapply(dgrid[by], unique)) > 1]

Check warning on line 193 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=193,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 193 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=193,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

# for contrasts, we also filter variables with one unique value, but we
# keep numeric variables. When these are hold constant in the data grid,
# they are set to their mean value - meaning, they only have one unique
# value in the data grid, anyway. so we need to keep them
keep_contrasts <- lengths(lapply(dgrid[contrast], unique)) > 1 | vapply(dgrid[contrast], is.numeric, logical(1)) # nolint
contrast <- contrast[keep_contrasts]

# set to NULL, if all by-values have been removed here
if (!length(by)) by <- NULL

Check warning on line 203 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=203,col=22,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 203 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=203,col=22,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

# if we have no contrasts left, e.g. due to `contrast = "time = factor(2)"`,
# we error here - we have no contrasts to show
if (!length(contrast)) {
insight::format_error("No contrasts to show. Please adjust `contrast`.")
}

# contrasts can't be longer than focal terms - make sure we have not
# removed too much (and that we now have captured all exceptions...)
if (length(contrast) > length(focal_terms)) focal_terms <- contrast

# for more than one term, we have comma-separated levels.
if (length(focal_terms) > 1) {
# we now have a data frame with each comparison-pairs as single column.
Expand Down Expand Up @@ -201,7 +241,7 @@
# unite back columns with focal contrasts - only needed when not slopes
if (inherits(x, "estimate_slopes")) {
contrast <- by
by <- NULL

Check warning on line 244 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=244,col=9,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 244 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=244,col=9,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
}

for (i in seq_along(contrast)) {
Expand Down Expand Up @@ -315,7 +355,7 @@
params <- params[c(setdiff(colnames(params), relocate_columns), relocate_columns)]

# relocate focal terms to the beginning
by <- attr(x, "focal_terms", exact = TRUE)

Check warning on line 358 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=358,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 358 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=358,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
if (!is.null(by) && all(by %in% colnames(params))) {
params <- datawizard::data_reorder(params, by, verbose = FALSE)
}
Expand Down Expand Up @@ -431,7 +471,7 @@
if (substring(input_string, match_positions[i], match_positions[i]) == "-") {
inside_parentheses <- FALSE
for (j in seq_along(match_positions)) {
if (i != j && match_positions[i] > match_positions[j] && match_positions[i] < (match_positions[j] + match_lengths[j])) {

Check warning on line 474 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=474,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 128 characters.

Check warning on line 474 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=474,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 128 characters.
inside_parentheses <- TRUE
break
}
Expand Down
3 changes: 2 additions & 1 deletion R/get_marginalmeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ get_marginalmeans <- function(model,
.info_elements <- function() {
c(
"at", "by", "focal_terms", "adjusted_for", "predict", "trend", "comparison",
"contrast", "marginalize", "p_adjust", "datagrid", "preserve_range", "coef_name"
"contrast", "marginalize", "p_adjust", "datagrid", "preserve_range",
"coef_name", "slope"
)
}

Expand Down
6 changes: 6 additions & 0 deletions R/get_marginaltrends.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,9 @@ get_marginaltrends <- function(model,

trend
}


#' @keywords internal
.marginaleffects_slopes <- function() {
c("dY/dX", "eY/eX", "eY/dX", "dY/eX")
}
36 changes: 27 additions & 9 deletions R/visualisation_recipe_internal.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,49 @@

#' @keywords internal
.find_aes <- function(x) {
data <- as.data.frame(x)

Check warning on line 6 in R/visualisation_recipe_internal.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe_internal.R,line=6,col=3,[object_overwrite_linter] 'data' is an exported object from package 'utils'. Avoid re-using such symbols.
data$.group <- 1

att <- attributes(x)
aes <- list(
y = "Predicted",
group = 1
group = ".group"
)

# extract information for labels
model_data <- .safe(insight::get_data(attributes(x)$model, verbose = FALSE))
model_response <- attributes(x)$response

# Find predictors
by <- att$focal_terms

Check warning on line 20 in R/visualisation_recipe_internal.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe_internal.R,line=20,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

# Main geom
if ("estimate_contrasts" %in% att$class) {
insight::format_error("Automated plotting is not yet implemented for this class.")
} else if ("estimate_means" %in% att$class) {
aes$y <- att$coef_name
} else if ("estimate_slopes" %in% att$class) {
aes$y <- "Slope"
if ("Comparison" %in% names(data)) {
# Insert "Comparison" column as the 2nd by so that it gets plotted as color
if (length(by) > 1) by[3:(length(by) + 1)] <- by[2:length(by)]
by[2] <- "Comparison"
}
} else if ("estimate_grouplevel" %in% att$class) {
aes$x <- "Level"
aes$y <- "Coefficient"
aes$type <- "grouplevel"
if (length(unique(data$Parameter)) > 1) {
aes$color <- "Parameter"
aes$group <- "Parameter"
data$.group <- paste(data$.group, data$Parameter)
}
if (length(unique(data$Group)) > 1) aes$facet <- "Group"
aes <- .find_aes_ci(aes, data)
return(list(aes = aes, data = data))
}

# Find predictors
by <- att$focal_terms
# 2nd try

# Assign predictors to aes
if (is.null(by)) {
by <- att$by
}
Expand All @@ -53,16 +62,15 @@
}
if (length(by) > 1) {
aes$color <- by[2]
aes$group <- by[2]
data$.group <- paste(data$.group, data[[by[2]]])
}
if (length(by) > 2) {
if (is.numeric(data[[by[3]]])) {
aes$alpha <- by[3]
} else {
aes$facet <- stats::as.formula(paste("~", paste(utils::tail(by, -2), collapse = " * ")))
}
data$.group <- paste(data[[by[2]]], "_", data[[by[3]]])
aes$group <- ".group"
data$.group <- paste(data$.group, data[[by[3]]])
}
if (length(by) > 3) {
aes$facet <- NULL
Expand Down Expand Up @@ -305,7 +313,17 @@
stroke = stroke
)

# set default alpha, it not mapped by aes
# check if we have matching columns in the raw data - some functions,
# likes slopes, have mapped these aes to other columns that are not part
# of the raw data - we set them to NULL
if (!is.null(aes$color) && !aes$color %in% colnames(rawdata)) {
out$aes$color <- NULL
}
if (!is.null(aes$alpha) && !aes$alpha %in% colnames(rawdata)) {
out$aes$alpha <- NULL
}

# set default alpha, if not mapped by aes
if (is.null(aes$alpha)) {
out$alpha <- 1 / 3
} else {
Expand Down
28 changes: 22 additions & 6 deletions man/estimate_contrasts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 22 additions & 6 deletions man/estimate_means.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 22 additions & 6 deletions man/estimate_slopes.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading