sampling speed improvements #1382

Open
avehtari opened this issue Jul 27, 2022 · 8 comments

@avehtari
Contributor

avehtari commented Jul 27, 2022

Two potential speed-ups for spline and HSGP models (without random effects)

1. use _glm more often.

E.g. formula y ~ s(x) generates code with

    vector[N] mu = Intercept + rep_vector(0.0, N) + Xs * bs + Zs_1_1 * s_1_1;
    target += normal_lpdf(Y | mu, sigma);

This can be changed to use _glm (in cases where a _glm function exists for the family used):

    target += normal_id_glm_lpdf(Y | append_col(Xs,Zs_1_1), Intercept, append_row(bs,s_1_1), sigma);

This gave an 80% drop in sampling time. It should also apply, for example, to models of the form y ~ x1 + x2 + s(x1) + s(x2).
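The two Stan snippets compute the same linear predictor. Intercept + Xs * bs + Zs_1_1 * s_1_1 equals Intercept + append_col(Xs, Zs_1_1) * append_row(bs, s_1_1), so the rewrite changes only how the likelihood is evaluated.

The pure-Python sketch below checks that equivalence on made-up toy numbers. Some caveats:

- The helper functions only mimic the Stan functions of the same name and are not Stan's implementation.
- The sketch says nothing about speed. The speedup presumably comes from the specialized gradient code behind Stan's _glm functions.

```python
import math


def matvec(X, b):
    """Matrix-vector product for a list-of-rows matrix X."""
    return [sum(x * w for x, w in zip(row, b)) for row in X]


def append_col(A, B):
    """Column-wise concatenation of two list-of-rows matrices (mimics Stan's append_col)."""
    return [ra + rb for ra, rb in zip(A, B)]


def normal_lpdf(y, mu, sigma):
    """Summed normal log density of y given per-observation means mu and scale sigma."""
    return sum(
        -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((yi - mi) / sigma) ** 2
        for yi, mi in zip(y, mu)
    )


def normal_id_glm_lpdf(y, X, alpha, beta, sigma):
    """Same density, but the linear predictor alpha + X * beta is built internally."""
    mu = [alpha + eta for eta in matvec(X, beta)]
    return normal_lpdf(y, mu, sigma)


# Made-up toy data: one unpenalized spline column and two penalized basis columns.
Y = [1.2, 2.0, 2.5]
Xs = [[0.1], [0.5], [0.9]]
Zs_1_1 = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
Intercept, bs, s_1_1, sigma = 1.0, [2.0], [0.3, -0.2], 0.7

# Current form: mu = Intercept + Xs * bs + Zs_1_1 * s_1_1
mu = [Intercept + a + b for a, b in zip(matvec(Xs, bs), matvec(Zs_1_1, s_1_1))]
lp_separate = normal_lpdf(Y, mu, sigma)

# Proposed form: one GLM call on the concatenated design matrix and coefficients
# (list concatenation bs + s_1_1 plays the role of append_row(bs, s_1_1)).
lp_glm = normal_id_glm_lpdf(Y, append_col(Xs, Zs_1_1), Intercept, bs + s_1_1, sigma)

print(lp_separate, lp_glm)
```

The two values agree up to floating-point rounding, because only the grouping of the sums differs.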

2. make code compatible with --O1 SoA (DONE)

See stanc --O1 optimization discourse post https://discourse.mc-stan.org/t/30-40-drop-in-sampling-time-using-stanc-o1-optimizations/28347

Even without switching to _glm, the code can be changed to benefit from the stanc --O1 optimization and its structure-of-arrays (SoA) memory layout, e.g., by changing

    // initialize linear predictor term
    vector[N] mu = Intercept + rep_vector(0.0, N) + Xs * bs + Zs_1_1 * s_1_1;
    // initialize linear predictor term
    vector[N] sigma = Intercept_sigma + rep_vector(0.0, N) + Xs_sigma * bs_sigma + Zs_sigma_1_1 * s_sigma_1_1;
    for (n in 1:N) {
      // apply the inverse link function
      sigma[n] = exp(sigma[n]);
    }
    target += normal_lpdf(Y | mu, sigma);

to

    vector[N] mu = Intercept + Xs * bs + Zs_1_1 * s_1_1;
    // initialize linear predictor term
    vector[N] sigma = Intercept_sigma + Xs_sigma * bs_sigma + Zs_sigma_1_1 * s_sigma_1_1;
    // apply the inverse link function
    sigma = exp(sigma);
    target += normal_lpdf(Y | mu, sigma);

reduces sampling time by 40% when using --O1. In the first version, SoA is blocked by adding the data vector rep_vector(0.0, N) and by the for loop over sigma. This speedup also works for families that don't have a _glm function in Stan.
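Roughly, the difference between the two layouts in reverse-mode autodiff is this:

- **Array of structs (AoS):** each vector element is its own autodiff variable, holding a value and an adjoint.
- **Structure of arrays (SoA):** the whole vector keeps its values in one array and its adjoints in another.

The Python sketch below only mimics that layout difference for an element-wise exp. Python lists are not contiguous float arrays, so it cannot show the speed effect. The class and function names (Var, VarVector, exp_aos, exp_soa) are hypothetical stand-ins, not Stan Math types.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, List, Tuple


@dataclass
class Var:
    """AoS element: each scalar carries its own value and adjoint."""
    val: float
    adj: float = 0.0


@dataclass
class VarVector:
    """SoA vector: all values in one list, all adjoints in another
    (in C++ these would be two contiguous arrays)."""
    val: List[float]
    adj: List[float] = field(default_factory=list)

    def __post_init__(self) -> None:
        if len(self.adj) != len(self.val):
            self.adj = [0.0] * len(self.val)


def exp_aos(xs: List[Var]) -> Tuple[List[Var], Callable[[], None]]:
    """Element-wise exp on a list of Var objects (AoS)."""
    out = [Var(math.exp(x.val)) for x in xs]

    def backward() -> None:
        # visit each (value, adjoint) pair separately; d exp(x)/dx = exp(x)
        for x, o in zip(xs, out):
            x.adj += o.adj * o.val

    return out, backward


def exp_soa(x: VarVector) -> Tuple[VarVector, Callable[[], None]]:
    """Same operation on an SoA vector: one pass over the value and adjoint arrays."""
    out = VarVector([math.exp(v) for v in x.val])

    def backward() -> None:
        x.adj = [a + oa * ov for a, oa, ov in zip(x.adj, out.adj, out.val)]

    return out, backward


# Gradient of sum(exp(x)) at x = values, computed with both layouts.
values = [0.0, 0.5, -1.0]

aos = [Var(v) for v in values]
aos_out, aos_backward = exp_aos(aos)
for o in aos_out:
    o.adj = 1.0  # seed: d sum / d out_i = 1
aos_backward()

soa = VarVector(list(values))
soa_out, soa_backward = exp_soa(soa)
soa_out.adj = [1.0] * len(values)
soa_backward()

print([x.adj for x in aos], soa.adj)
```

Both layouts give the same gradient. The point of SoA is that the vectorized `sigma = exp(sigma)` lets the whole vector be handled as one object, whereas per-element assignment in a loop works on individual elements.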

@avehtari
Contributor Author

To get the --O1 benefit for models of the type y ~ x1 + x2 + s(x1) + s(x2), it is enough to change

    vector[N] mu = Intercept + rep_vector(0.0, N) + Xs * bs + Zs_1_1 * s_1_1 + Zs_2_1 * s_2_1;

to

    vector[N] mu = Intercept + (rep_vector(0.0, N) + Xs * bs + Zs_1_1 * s_1_1 + Zs_2_1 * s_2_1);

see stan-dev/stanc3#1232 for details

@paul-buerkner paul-buerkner added this to the brms 2.17.0++ milestone Aug 5, 2022
@wds15
Contributor

wds15 commented Aug 8, 2022

And one more thing: the current partial log-likelihood functions, as used with the cmdstanr backend, prevent the SoA optimisation because of their for loops. A (hopefully) simple change in brms avoids the for loops by using vectorized index expressions, which work with SoA. Here is the key change:

  // compute partial sums of the log-likelihood
  real partial_log_lik_lpmf(int[] seq, int start, int end, data vector Y, data matrix Xc, vector b, real Intercept, real sigma, data int[] J_1, data vector Z_1_1, vector r_1_1) {
    real ptarget = 0;
    int N = end - start + 1;
    // initialize linear predictor term
    // old: vector[N] mu = Intercept + rep_vector(0.0, N);
    vector[N] mu = rep_vector(Intercept, N);
    /* old loop, replaced by the vectorized expression below:
    for (n in 1:N) {
      // add more terms to the linear predictor
      int nn = n + start - 1;
      mu[n] += r_1_1[J_1[nn]] * Z_1_1[nn];
    }
    */
    mu += r_1_1[J_1[start:end]] .* Z_1_1[start:end];
    ptarget += normal_id_glm_lpdf(Y[start:end] | Xc[start:end], mu, b, sigma);
    return ptarget;
  }

Note that this still needs speed testing (SoA & AoS compiled versions).
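The multi-indexed line `mu += r_1_1[J_1[start:end]] .* Z_1_1[start:end];` is meant to compute exactly what the commented-out loop did. Stan's `start:end` is 1-based and inclusive, and `r_1_1[J_1[...]]` is a gather.

The pure-Python sketch below checks that equivalence on made-up toy data. The function names are hypothetical, and Stan indices are translated to Python's 0-based slices.

```python
def mu_loop(Intercept, r_1_1, J_1, Z_1_1, start, end):
    """Loop version from the commented-out block (1-based Stan indices)."""
    N = end - start + 1
    mu = [Intercept] * N
    for n in range(1, N + 1):
        nn = n + start - 1
        mu[n - 1] += r_1_1[J_1[nn - 1] - 1] * Z_1_1[nn - 1]
    return mu


def mu_multi_index(Intercept, r_1_1, J_1, Z_1_1, start, end):
    """Vectorized version: Intercept + r_1_1[J_1[start:end]] .* Z_1_1[start:end]."""
    J_slice = J_1[start - 1:end]  # Stan's start:end is 1-based and inclusive
    Z_slice = Z_1_1[start - 1:end]
    gathered = [r_1_1[j - 1] for j in J_slice]  # multi-indexing acts as a gather
    return [Intercept + g * z for g, z in zip(gathered, Z_slice)]


# Made-up toy data: 3 groups, 6 observations, partial sum over observations 2..5.
r_1_1 = [0.5, -1.0, 2.0]
J_1 = [1, 3, 2, 3, 1, 2]
Z_1_1 = [1.0, 2.0, 3.0, 0.5, -1.0, 4.0]

print(mu_loop(0.1, r_1_1, J_1, Z_1_1, 2, 5))
print(mu_multi_index(0.1, r_1_1, J_1, Z_1_1, 2, 5))
```

Both versions return the same four linear-predictor values. Whether the vectorized form is faster in compiled Stan is the separate question raised in the note above.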

The goal is to get SoA to work with reduce_sum, which is currently blocked by the compiler. See here: stan-dev/stanc3#1236

We will have to let the stanc folks weigh in here, but I think that SoA is compatible with reduce_sum.

Here is a small R script that generates the full program, as well as the full, slightly modified brms program:

soa-example-brms-soa_stan.txt

horseshoe-scale_R.txt

@paul-buerkner
Owner

paul-buerkner commented Aug 12, 2022


@avehtari would the following work to prevent the rep_vector(0.0, N) problem?

    vector[N] mu = rep_vector(0.0, N);
    mu += Intercept + Xs * bs + Zs_1_1 * s_1_1;

Edit: Figured out how to check that myself.

@paul-buerkner
Owner

paul-buerkner commented Aug 12, 2022

@wds15 About the vectorized index expressions for multilevel terms: they still seem to be slower with Stan 2.30 than the loops currently used. Here is an example for a varying-intercept, varying-slope model.

Old Stan code (with loops over r_* variables):

// generated with brms 2.17.6
functions {
 /* compute correlated group-level effects
  * Args:
  *   z: matrix of unscaled group-level effects
  *   SD: vector of standard deviation parameters
  *   L: cholesky factor correlation matrix
  * Returns:
  *   matrix of scaled group-level effects
  */
  matrix scale_r_cor(matrix z, vector SD, matrix L) {
    // r is stored in another dimension order than z
    return transpose(diag_pre_multiply(SD, L) * z);
  }
}
data {
  int<lower=1> N;  // total number of observations
  vector[N] Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_1;
  vector[N] Z_1_2;
  int<lower=1> NC_1;  // number of group-level correlations
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;  // population-level effects
  real Intercept;  // temporary intercept for centered predictors
  real<lower=0> sigma;  // dispersion parameter
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  matrix[M_1, N_1] z_1;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_1;  // cholesky factor of correlation matrix
}
transformed parameters {
  matrix[N_1, M_1] r_1;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_1] r_1_1;
  vector[N_1] r_1_2;
  real lprior = 0;  // prior contributions to the log posterior
  // compute actual group-level effects
  r_1 = scale_r_cor(z_1, sd_1, L_1);
  r_1_1 = r_1[, 1];
  r_1_2 = r_1[, 2];
  lprior += student_t_lpdf(Intercept | 3, 288.7, 59.3);
  lprior += student_t_lpdf(sigma | 3, 0, 59.3)
    - 1 * student_t_lccdf(0 | 3, 0, 59.3);
  lprior += student_t_lpdf(sd_1 | 3, 0, 59.3)
    - 2 * student_t_lccdf(0 | 3, 0, 59.3);
  lprior += lkj_corr_cholesky_lpdf(L_1 | 1);
}
model {
  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    mu += Intercept;
    for (n in 1:N) {
      // add more terms to the linear predictor
      mu[n] += r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n];
    }
    target += normal_id_glm_lpdf(Y | Xc, mu, b, sigma);
  }
  // priors including constants
  target += lprior;
  target += std_normal_lpdf(to_vector(z_1));
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept - dot_product(means_X, b);
  // compute group-level correlations
  corr_matrix[M_1] Cor_1 = multiply_lower_tri_self_transpose(L_1);
  vector<lower=-1,upper=1>[NC_1] cor_1;
  // extract upper diagonal of correlation matrix
  for (k in 1:M_1) {
    for (j in 1:(k - 1)) {
      cor_1[choose(k - 1, 2) + j] = Cor_1[j, k];
    }
  }
}

New Stan code (without loops over r_* variables):

// generated with brms 2.17.6
functions {
 /* compute correlated group-level effects
  * Args:
  *   z: matrix of unscaled group-level effects
  *   SD: vector of standard deviation parameters
  *   L: cholesky factor correlation matrix
  * Returns:
  *   matrix of scaled group-level effects
  */
  matrix scale_r_cor(matrix z, vector SD, matrix L) {
    // r is stored in another dimension order than z
    return transpose(diag_pre_multiply(SD, L) * z);
  }
}
data {
  int<lower=1> N;  // total number of observations
  vector[N] Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_1;
  vector[N] Z_1_2;
  int<lower=1> NC_1;  // number of group-level correlations
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;  // population-level effects
  real Intercept;  // temporary intercept for centered predictors
  real<lower=0> sigma;  // dispersion parameter
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  matrix[M_1, N_1] z_1;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_1;  // cholesky factor of correlation matrix
}
transformed parameters {
  matrix[N_1, M_1] r_1;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_1] r_1_1;
  vector[N_1] r_1_2;
  real lprior = 0;  // prior contributions to the log posterior
  // compute actual group-level effects
  r_1 = scale_r_cor(z_1, sd_1, L_1);
  r_1_1 = r_1[, 1];
  r_1_2 = r_1[, 2];
  lprior += student_t_lpdf(Intercept | 3, 288.7, 59.3);
  lprior += student_t_lpdf(sigma | 3, 0, 59.3)
    - 1 * student_t_lccdf(0 | 3, 0, 59.3);
  lprior += student_t_lpdf(sd_1 | 3, 0, 59.3)
    - 2 * student_t_lccdf(0 | 3, 0, 59.3);
  lprior += lkj_corr_cholesky_lpdf(L_1 | 1);
}
model {
  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    mu += Intercept + r_1_1[J_1] .* Z_1_1 + r_1_2[J_1] .* Z_1_2;
    target += normal_id_glm_lpdf(Y | Xc, mu, b, sigma);
  }
  // priors including constants
  target += lprior;
  target += std_normal_lpdf(to_vector(z_1));
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept - dot_product(means_X, b);
  // compute group-level correlations
  corr_matrix[M_1] Cor_1 = multiply_lower_tri_self_transpose(L_1);
  vector<lower=-1,upper=1>[NC_1] cor_1;
  // extract upper diagonal of correlation matrix
  for (k in 1:M_1) {
    for (j in 1:(k - 1)) {
      cor_1[choose(k - 1, 2) + j] = Cor_1[j, k];
    }
  }
}

Corresponding R code:

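library(cmdstanr)
library(lme4)  # provides the sleepstudy data
library(brms)  # provides make_standata()

# stack 10 copies of sleepstudy to get a larger data set
sleepstudy_long <- Reduce("rbind", replicate(10, sleepstudy, simplify = FALSE))
sdata <- make_standata(Reaction ~ Days + (Days|Subject), sleepstudy_long)

# old model: loops over r_* variables
mod_old <- cmdstan_model("test_old.stan")
# print stanc's memory-pattern report (which variables are SoA vs AoS) under --O1
mod_old$check_syntax(stanc_options = list("debug-mem-patterns", "O1"))
fit_old <- mod_old$sample(data = sdata)

# new model: vectorized index expressions
mod_new <- cmdstan_model("test_new.stan")
mod_new$check_syntax(stanc_options = list("debug-mem-patterns", "O1"))
fit_new <- mod_new$sample(data = sdata)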

@paul-buerkner
Owner

@avehtari both the problem with rep_vector(0.0, N) and the application of the inverse link in a loop are now fixed and don't block SoA anymore. Thank you for your input!

The extended use of GLM functions is a somewhat more complicated issue (due to the interaction of multiple model parts), and I cannot tackle it right now, but I hope to work on it soon.

@paul-buerkner paul-buerkner removed this from the brms 2.17.0++ milestone Aug 12, 2022
@avehtari
Contributor Author

> stan 2.30

I assume 2.30.1?

> the vectorized index expressions for multilevel terms, they still seem to be slower with stan 2.30 than the currently used loops.

SoA alone is not sufficient for a speedup. Even if everything else in the Stan-generated C++ were perfect, memory access can still be slow if the index jumps around. The blog post https://viralinstruction.com/posts/hardware/ nicely illustrates the big differences in speed depending on whether memory is accessed in order or not.
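One crude way to see how "jumpy" a grouping index such as J_1 is: count how often consecutive gathers r[J[n]] and r[J[n+1]] hit neither the same element nor the next one.

The sketch below is a hypothetical helper (non_sequential_fraction is not part of brms or Stan). It is only a rough proxy: it ignores cache-line sizes and everything else the hardware does.

```python
def non_sequential_fraction(J):
    """Fraction of consecutive index pairs (J[n], J[n+1]) that neither stay on the
    same element nor move to the next one; 0.0 means fully in-order access."""
    if len(J) < 2:
        return 0.0
    jumps = sum(1 for a, b in zip(J, J[1:]) if b - a not in (0, 1))
    return jumps / (len(J) - 1)


grouped = [1, 1, 1, 2, 2, 2, 3, 3, 3]  # observations stored group by group
scattered = [2, 5, 1, 4, 3]            # every step jumps to a different location

print(non_sequential_fraction(grouped), non_sequential_fraction(scattered))
```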

@paul-buerkner
Owner

Yes. 2.30.1. Thank you for your additional insights.

@wds15
Contributor

wds15 commented Aug 12, 2022

Thanks @paul-buerkner for trying out the array index thing. I am surprised it slows things down... I will look into it once more myself. I know that Steve worked on making this type of indexing fast (with brms in mind), so let's hope it can be resolved.

In any case, getting SoA for reduce_sum is not going to happen soon (one thing I investigated), since SoA does not work with user-defined functions (UDFs). However, simple UDFs are inlined by the stanc3 compiler and may therefore work with SoA.

What I will try to write myself is a binomial_logit_glm, unless someone else does it before me; we really need this function.
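As a reference point for what such a function would evaluate, let the linear predictor be eta = alpha + X * beta. The binomial log-likelihood with a logit link is then the sum over n of:

log C(trials[n], y[n]) + y[n] * eta[n] - trials[n] * log1p_exp(eta[n])

A minimal pure-Python sketch of that value follows (values only, no gradients). The name binomial_logit_glm_lpmf and its argument order are assumptions modeled on the _glm signatures used earlier in this thread; this is not Stan's implementation.

```python
import math


def log1p_exp(x):
    """Numerically stable log(1 + exp(x))."""
    return x + math.log1p(math.exp(-x)) if x > 0 else math.log1p(math.exp(x))


def log_choose(n, k):
    """log of the binomial coefficient C(n, k)."""
    return math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)


def binomial_logit_glm_lpmf(y, trials, X, alpha, beta):
    """Sum over n of log Binomial(y[n] | trials[n], inv_logit(alpha + X[n] * beta))."""
    total = 0.0
    for yn, Nn, row in zip(y, trials, X):
        eta = alpha + sum(x * b for x, b in zip(row, beta))
        # log p = eta - log1p_exp(eta) and log(1 - p) = -log1p_exp(eta), so
        # y * log p + (N - y) * log(1 - p) simplifies to y * eta - N * log1p_exp(eta)
        total += log_choose(Nn, yn) + yn * eta - Nn * log1p_exp(eta)
    return total


print(binomial_logit_glm_lpmf([3, 0, 5], [5, 4, 5],
                              [[0.2, -1.0], [1.5, 0.3], [-0.4, 0.8]], 0.1, [0.7, -0.5]))
```

Writing the density in terms of eta and log1p_exp keeps it stable for large |eta|, where computing p first would round to 0 or 1.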

Great to see O1 compatibility getting into brms!
