Skip to content

Commit

Permalink
allow for more O1 optimization #1382
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Aug 12, 2022
1 parent bd99fe0 commit 8d7eeb7
Show file tree
Hide file tree
Showing 29 changed files with 172 additions and 140 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ specified via `s(..., fx = TRUE)`.
* Reuse user-specified control arguments originally passed
to the Stan backend in `update` and related methods. (#1373, #1378)

### Other Changes

* Allow for more `O1` optimization of brms-generated Stan models
thanks to Aki Vehtari. (#1382)

### Bug Fixes

* Fix problems with missing boundaries of `sdme` parameters in models
Expand Down
7 changes: 4 additions & 3 deletions R/distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
#' @param mu Vector of location values.
#' @param sigma Vector of scale values.
#' @param df Vector of degrees of freedom.
#' @param log,log.p Logical; If \code{TRUE}, values are returned on the log scale.
#' @param log Logical; If \code{TRUE}, values are returned on the log scale.
#' @param log.p Logical; If \code{TRUE}, values are returned on the log scale.
#' @param lower.tail Logical; If \code{TRUE} (default), return P(X <= x).
#' Else, return P(X > x) .
#'
Expand Down Expand Up @@ -929,7 +930,7 @@ rinv_gaussian <- function(n, mu = 1, shape = 1) {
#' @param size Vector of number of trials (zero or more).
#' @param mu Vector of means.
#' @param phi Vector of precisions.
#'
#'
#' @export
dbeta_binomial <- function(x, size, mu, phi, log = FALSE) {
require_package("extraDistr")
Expand All @@ -944,7 +945,7 @@ pbeta_binomial <- function(q, size, mu, phi, lower.tail = TRUE, log.p = FALSE) {
require_package("extraDistr")
alpha <- mu * phi
beta <- (1 - mu) * phi
extraDistr::pbbinom(q, size, alpha = alpha, beta = beta,
extraDistr::pbbinom(q, size, alpha = alpha, beta = beta,
lower.tail = lower.tail, log.p = log.p)
}

Expand Down
8 changes: 4 additions & 4 deletions R/make_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ make_stancode <- function(formula, data, family = gaussian(),
scode_predictor[[i]][["model_def"]],
collapse_stanvars(stanvars, "likelihood", "start"),
scode_predictor[[i]][["model_comp_basic"]],
scode_predictor[[i]][["model_comp_eta"]],
scode_predictor[[i]][["model_comp_eta_loop"]],
scode_predictor[[i]][["model_comp_dpar_link"]],
scode_predictor[[i]][["model_comp_mu_link"]],
scode_predictor[[i]][["model_comp_dpar_trans"]],
scode_predictor[[i]][["model_comp_mix"]],
scode_predictor[[i]][["model_comp_arma"]],
Expand Down Expand Up @@ -156,9 +156,9 @@ make_stancode <- function(formula, data, family = gaussian(),
collapse_stanvars(stanvars, "likelihood", "start"),
scode_predictor[["model_no_pll_comp_basic"]],
scode_predictor[["model_comp_basic"]],
scode_predictor[["model_comp_eta"]],
scode_predictor[["model_comp_eta_loop"]],
scode_predictor[["model_comp_dpar_link"]],
scode_predictor[["model_comp_mu_link"]],
scode_predictor[["model_comp_dpar_trans"]],
scode_predictor[["model_comp_mix"]],
scode_predictor[["model_comp_arma"]],
Expand Down Expand Up @@ -308,7 +308,7 @@ make_stancode <- function(formula, data, family = gaussian(),
scode <- parse_model(scode, backend, silent = silent)
}
if (backend == "cmdstanr") {
if (requireNamespace("cmdstanr", quietly = TRUE) &&
if (requireNamespace("cmdstanr", quietly = TRUE) &&
cmdstanr::cmdstan_version() >= "2.29.0") {
tmp_file <- cmdstanr::write_stan_file(scode)
scode <- .canonicalize_stan_model(tmp_file, overwrite_file = FALSE)
Expand Down Expand Up @@ -446,7 +446,7 @@ normalize_stancode <- function(x) {
trimws(x)
}

# check if the currently installed Stan version requires older syntax
# check if the currently installed Stan version requires older syntax
# than the Stan version with which the model was initially fitted
require_old_stan_syntax <- function(object, backend, version) {
stopifnot(is.brmsfit(object))
Expand Down
70 changes: 31 additions & 39 deletions R/stan-predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -1666,19 +1666,16 @@ stan_nl <- function(bterms, data, nlpars, threads, inv_link = rep("", 2), ...) {
" // initialize non-linear predictor term\n",
" vector[N{resp}] {par};\n"
)
# make sure mu comes last as it might depend on other parameters
is_mu <- isTRUE("mu" %in% dpar_class(bterms[["dpar"]]))
position <- str_if(is_mu, "model_comp_mu_link", "model_comp_dpar_link")
if (bterms$loop) {
str_add(out[[position]]) <- glue(
str_add(out$model_comp_dpar_link) <- glue(
" for (n in 1:N{resp}) {{\n",
stan_nn_def(threads),
" // compute non-linear predictor values\n",
" {par}[n] = {inv_link[1]}{eta}{inv_link[2]};\n",
" }}\n"
)
} else {
str_add(out[[position]]) <- glue(
str_add(out$model_comp_dpar_link) <- glue(
" // compute non-linear predictor values\n",
" {par} = {inv_link[1]}{eta}{inv_link[2]};\n"
)
Expand Down Expand Up @@ -1819,8 +1816,11 @@ stan_eta_combine <- function(out, bterms, ranef, threads, primitive,
out$eta <- sub("^[ \t\r\n]+\\+", "", out$eta, perl = TRUE)
str_add(out$model_def) <- glue(
" // initialize linear predictor term\n",
" vector[N{resp}] {eta} ={out$eta};\n"
" vector[N{resp}] {eta} = rep_vector(0.0, N{resp});\n"
)
if (nzchar(out$eta)) {
str_add(out$model_comp_eta) <- glue(" {eta} +={out$eta};\n")
}
out$eta <- NULL
str_add(out$loopeta) <- stan_eta_re(ranef, threads = threads, px = px)
if (nzchar(out$loopeta)) {
Expand All @@ -1837,14 +1837,8 @@ stan_eta_combine <- function(out, bterms, ranef, threads, primitive,
out$loopeta <- NULL
# possibly transform eta before it is passed to the likelihood
if (sum(nzchar(inv_link))) {
# make sure mu comes last as it might depend on other parameters
is_mu <- isTRUE("mu" %in% dpar_class(bterms[["dpar"]]))
position <- str_if(is_mu, "model_comp_mu_link", "model_comp_dpar_link")
str_add(out[[position]]) <- glue(
" for (n in 1:N{resp}) {{\n",
" // apply the inverse link function\n",
" {eta}[n] = {inv_link[1]}{eta}[n]{inv_link[2]};\n",
" }}\n"
str_add(out$model_comp_dpar_link) <- glue(
" {eta} = {inv_link[1]}{eta}{inv_link[2]};\n"
)
}
out
Expand All @@ -1856,33 +1850,31 @@ stan_eta_combine <- function(out, bterms, ranef, threads, primitive,
# @param primitive use Stan's GLM likelihood primitives?
# @return a single character string
stan_eta_fe <- function(fixef, bterms, threads, primitive) {
if (length(fixef) && !primitive) {
p <- usc(combine_prefix(bterms))
center_X <- stan_center_X(bterms)
decomp <- get_decomp(bterms$fe)
sparse <- is_sparse(bterms$fe)
if (sparse) {
stopifnot(!center_X && decomp == "none")
csr_args <- sargs(
paste0(c("rows", "cols"), "(X", p, ")"),
paste0(c("wX", "vX", "uX", "b"), p)
)
eta_fe <- glue("csr_matrix_times_vector({csr_args})")
} else {
sfx_X <- sfx_b <- ""
if (decomp == "QR") {
sfx_X <- sfx_b <- "Q"
} else if (center_X) {
sfx_X <- "c"
}
slice <- stan_slice(threads)
eta_fe <- glue("X{sfx_X}{p}{slice} * b{sfx_b}{p}")
}
if (!length(fixef) || primitive) {
return("")
}
p <- usc(combine_prefix(bterms))
center_X <- stan_center_X(bterms)
decomp <- get_decomp(bterms$fe)
sparse <- is_sparse(bterms$fe)
if (sparse) {
stopifnot(!center_X && decomp == "none")
csr_args <- sargs(
paste0(c("rows", "cols"), "(X", p, ")"),
paste0(c("wX", "vX", "uX", "b"), p)
)
eta_fe <- glue(" + csr_matrix_times_vector({csr_args})")
} else {
resp <- usc(bterms$resp)
eta_fe <- glue("rep_vector(0.0, N{resp})")
sfx_X <- sfx_b <- ""
if (decomp == "QR") {
sfx_X <- sfx_b <- "Q"
} else if (center_X) {
sfx_X <- "c"
}
slice <- stan_slice(threads)
eta_fe <- glue(" + X{sfx_X}{p}{slice} * b{sfx_b}{p}")
}
glue(" + {eta_fe}")
eta_fe
}

# write the group-level part of the linear predictor
Expand Down
14 changes: 7 additions & 7 deletions inst/chunks/fun_cauchit.stan
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
/* compute the cauchit link
* Args:
* p: a scalar in (0, 1)
* p: a vector in (0, 1)
* Returns:
* a scalar in (-Inf, Inf)
* a vector in (-Inf, Inf)
*/
real cauchit(real p) {
vector cauchit(vector p) {
return tan(pi() * (p - 0.5));
}
/* compute the inverse of the cauchit link
* Args:
* y: a scalar in (-Inf, Inf)
* y: a vector in (-Inf, Inf)
* Returns:
* a scalar in (0, 1)
* a vector in (0, 1)
*/
real inv_cauchit(real y) {
return cauchy_cdf(y, 0, 1);
vector inv_cauchit(vector y) {
return atan(y) / pi() + 0.5;
}
6 changes: 3 additions & 3 deletions inst/chunks/fun_cloglog.stan
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
/* compute the cloglog link
* Args:
* p: a scalar in (0, 1)
* p: a vector in (0, 1)
* Returns:
* a scalar in (-Inf, Inf)
* a vector in (-Inf, Inf)
*/
real cloglog(real p) {
vector cloglog(vector p) {
return log(-log1m(p));
}
16 changes: 8 additions & 8 deletions inst/chunks/fun_logm1.stan
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
/* compute the logm1 link
* Args:
* p: a positive scalar
* p: a positive vector
* Returns:
* a scalar in (-Inf, Inf)
* a vector in (-Inf, Inf)
*/
real logm1(real y) {
return log(y - 1);
vector logm1(vector y) {
return log(y - 1.0);
}
/* compute the inverse of the logm1 link
* Args:
* y: a scalar in (-Inf, Inf)
* y: a vector in (-Inf, Inf)
* Returns:
* a positive scalar
* a positive vector
*/
real expp1(real y) {
return exp(y) + 1;
vector expp1(vector y) {
return exp(y) + 1.0;
}
12 changes: 6 additions & 6 deletions inst/chunks/fun_softit.stan
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
/* compute the softit link
* Args:
* p: a scalar in (0, 1)
* p: a vector in (0, 1)
* Returns:
* a scalar in (-Inf, Inf)
* a vector in (-Inf, Inf)
*/
real softit(real p) {
vector softit(vector p) {
return log(expm1(-p / (p - 1)));
}
/* compute the inverse of the sofit link
* Args:
* y: a scalar in (-Inf, Inf)
* y: a vector in (-Inf, Inf)
* Returns:
* a scalar in (0, 1)
* a vector in (0, 1)
*/
real inv_softit(real y) {
vector inv_softit(vector y) {
return log1p_exp(y) / (1 + log1p_exp(y));
}
6 changes: 3 additions & 3 deletions inst/chunks/fun_softplus.stan
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
/* softplus link function inverse to 'log1p_exp'
* Args:
* x: a positive scalar
* x: a positive vector
* Returns:
* a scalar in (-Inf, Inf)
* a vector in (-Inf, Inf)
*/
real log_expm1(real x) {
vector log_expm1(vector x) {
return log(expm1(x));
}
16 changes: 8 additions & 8 deletions inst/chunks/fun_squareplus.stan
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
/* squareplus inverse link function (squareplus itself)
* Args:
* x: a scalar in (-Inf, Inf)
* x: a vector in (-Inf, Inf)
* Returns:
* a positive scalar
* a positive vector
*/
real squareplus(real x) {
return (x + sqrt(x^2 + 4)) / 2;
vector squareplus(vector x) {
return (x + sqrt(square(x) + 4)) / 2;
}
/* squareplus link function (inverse squareplus)
* Args:
* x: a positive scalar
* x: a positive vector
* Returns:
* a scalar in (-Inf, Inf)
* a vector in (-Inf, Inf)
*/
real inv_squareplus(real x) {
return (x^2 - 1) / x;
vector inv_squareplus(vector x) {
return (square(x) - 1) ./ x;
}
12 changes: 6 additions & 6 deletions inst/chunks/fun_tan_half.stan
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
/* compute the tan_half link
* Args:
* x: a scalar in (-pi, pi)
* x: a vector in (-pi, pi)
* Returns:
* a scalar in (-Inf, Inf)
* a vector in (-Inf, Inf)
*/
real tan_half(real x) {
vector tan_half(vector x) {
return tan(x / 2);
}
/* compute the inverse of the tan_half link
* Args:
* y: a scalar in (-Inf, Inf)
* y: a vector in (-Inf, Inf)
* Returns:
* a scalar in (-pi, pi)
* a vector in (-pi, pi)
*/
real inv_tan_half(real y) {
vector inv_tan_half(vector y) {
return 2 * atan(y);
}
4 changes: 3 additions & 1 deletion man/AsymLaplace.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/BetaBinomial.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/Dirichlet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/ExGaussian.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 8d7eeb7

Please sign in to comment.