Implement closures #2197

Open
bbbales2 opened this issue Nov 16, 2020 · 18 comments

@bbbales2
Member

Description

This is a placeholder issue for implementing closures as described in https://github.com/stan-dev/design-docs/blob/master/designs/0004-closures-fun-types.md

I call it a placeholder because the design doc is the main specification, and there will also be associated pull requests in stanc3.

Current Version:

v3.3.0

bbbales2 added a commit to nhuurre/math that referenced this issue Nov 16, 2020
bbbales2 added a commit to nhuurre/math that referenced this issue Nov 16, 2020
bbbales2 mentioned this issue Nov 16, 2020
@nhuurre
Collaborator

nhuurre commented Nov 18, 2020

The stanc3 pull request is open: stan-dev/stanc3#742
It probably has lots of bugs, but it's good enough that I can compile the Lotka-Volterra model:

data {
  int<lower = 0> N;          // number of measurement times
  real ts[N];                // measurement times > 0
  vector[2] y_init;            // initial measured populations
  real<lower = 0> y[N, 2];   // measured populations
}
parameters {
  real<lower = 0> alpha;
  real<lower = 0> beta;
  real<lower = 0> gamma;
  real<lower = 0> delta;
  vector<lower = 0>[2] z_init;  // initial population
  real<lower = 0> sigma[2];   // measurement errors
}
model {
  functions
  vector dz_dt(real t,      // time
               vector z     // system state {prey, predator}
               ) {
    real u = z[1];
    real v = z[2];

    real du_dt = (alpha - beta * v) * u;
    real dv_dt = (-gamma + delta * u) * v;

    return [ du_dt, dv_dt ]';
  }
  vector[2] z[N]
  = ode_bdf(dz_dt, z_init, 0., ts);

  alpha ~ normal(1, 0.5);
  gamma ~ normal(1, 0.5);
  beta ~ normal(0.05, 0.05);
  delta ~ normal(0.05, 0.05);
  sigma ~ lognormal(-1, 1);
  z_init ~ lognormal(log(10), 1);
  for (k in 1:2) {
    y_init[k] ~ lognormal(log(z_init[k]), sigma[k]);
    y[ , k] ~ lognormal(log(z[, k]), sigma[k]);
  }
}

@bbbales2
Member Author

I converted the SIR model over here to practice with this (@rok-cesnovar, thanks for the binaries).

It is an absolute delight to not worry about passing in parameters or data.

A few comments:

  1. I get an error when I define a function in the transformed parameters block (the example model is included further down the page, and the data is at the bottom):
sir_negbin_lambda.hpp:150:7: error: redefinition of 'sir_L23C9_cfunctor__'
class sir_L23C9_cfunctor__ {
      ^
sir_negbin_lambda.hpp:105:7: note: previous definition is here
class sir_L23C9_cfunctor__ {
      ^
sir_negbin_lambda.hpp:244:1: error: redefinition of 'sir_L23C9_impl__'
sir_L23C9_impl__(const int& N, const T1__& beta, const T2__& gamma,
^
sir_negbin_lambda.hpp:190:1: note: previous definition is here
sir_L23C9_impl__(const int& N, const T1__& beta, const T2__& gamma,

  2. If there is a parameter y defined in the transformed parameters block, I cannot have an argument to my function named y. I guess allowing this would require adding overloading to the language, which is out of scope here?

  3. I'm not sold on the naming. What about:

    function dz_dt = (real t, vector z) -> vector { ... }

    I just copied the C++ lambda return type syntax. I wouldn't mind automatically deduced types there. Would this then make it possible to pass functions as arguments without first defining them as variables, i.e., ode_bdf((real t, vector z) -> vector { ... }, ...)?
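For reference, here is roughly what the borrowed C++ syntax looks like: a lambda with a trailing return type, passed directly to a higher-order function without first being bound to a name. This is a C++ sketch, not proposed Stan syntax. apply_rhs and lv_rhs_inline are hypothetical helpers, and apply_rhs merely evaluates the function once at the point where something like ode_bdf would integrate it.

```cpp
#include <vector>

// Hypothetical stand-in for a solver: it just evaluates f once.
template <typename F>
std::vector<double> apply_rhs(const F& f, double t, const std::vector<double>& z) {
  return f(t, z);
}

std::vector<double> lv_rhs_inline(double alpha, double beta, double gamma,
                                  double delta, double t,
                                  const std::vector<double>& z) {
  // The lambda is written inline as an argument; alpha..delta are
  // captured by value with [=].
  return apply_rhs(
      [=](double /*t*/, const std::vector<double>& s) -> std::vector<double> {
        const double u = s[0];
        const double v = s[1];
        return {(alpha - beta * v) * u, (-gamma + delta * u) * v};
      },
      t, z);
}
```

In this particular lambda the trailing return type is actually required, because C++ cannot deduce a return type from a braced list. If the body returned std::vector<double>{...} explicitly, the annotation could be dropped and the return type would be deduced.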

Here is the SIR model:

data {
  int<lower=1> n_days;
  vector[3] y0;
  real t0;
  real ts[n_days];
  int N;
  int cases[n_days];
}
transformed data {
  real x_r[0];
  int x_i[1] = { N };
}
parameters {
  real<lower=0> gamma;
  real<lower=0> beta;
  real<lower=0> phi_inv;
}
transformed parameters{
  vector[3] y[n_days];
  real phi = 1. / phi_inv;

  functions
  vector sir(real t, vector x) {
    real S = x[1];
    real I = x[2];
    real R = x[3];

    real dS_dt = -beta * I * S / N;
    real dI_dt =  beta * I * S / N - gamma * I;
    real dR_dt =  gamma * I;

    return [ dS_dt, dI_dt, dR_dt ]';
  }

  y = ode_rk45(sir, y0, t0, ts);
}
model {
  // priors
  beta ~ normal(2, 1);
  gamma ~ normal(0.4, 0.5);
  phi_inv ~ exponential(5);

  // sampling distribution
  cases ~ neg_binomial_2(y[, 2], phi);
}

generated quantities {
  real R0 = beta / gamma;
  real recovery_time = 1 / gamma;
  real pred_cases[n_days];
  pred_cases = neg_binomial_2_rng(y[, 2], phi);
}

Here is the Lotka-Volterra model from above, but with the rk45 solver:

data {
  int<lower = 0> N;          // number of measurement times
  real ts[N];                // measurement times > 0
  vector[2] y_init;            // initial measured populations
  real<lower = 0> y[N, 2];   // measured populations
}
parameters {
  real<lower = 0> alpha;
  real<lower = 0> beta;
  real<lower = 0> gamma;
  real<lower = 0> delta;
  vector<lower = 0>[2] z_init;  // initial population
  real<lower = 0> sigma[2];   // measurement errors
}
model {
  functions
  vector dz_dt(real t,      // time
               vector z     // system state {prey, predator}
               ) {
    real u = z[1];
    real v = z[2];

    real du_dt = (alpha - beta * v) * u;
    real dv_dt = (-gamma + delta * u) * v;

    return [ du_dt, dv_dt ]';
  }

  vector[2] z[N] = ode_rk45_tol(dz_dt, z_init, 0., ts,
                                1e-5, 1e-3, 500);

  alpha ~ normal(1, 0.5);
  gamma ~ normal(1, 0.5);
  beta ~ normal(0.05, 0.05);
  delta ~ normal(0.05, 0.05);
  sigma ~ lognormal(-1, 1);
  z_init ~ lognormal(log(10), 1);
  for (k in 1:2) {
    y_init[k] ~ lognormal(log(z_init[k]), sigma[k]);
    y[ , k] ~ lognormal(log(z[, k]), sigma[k]);
  }
}

Lotka-Volterra data: lv.dat.txt
SIR data: sir.dat.txt (I used the code here to generate the data file: https://mc-stan.org/users/documentation/case-studies/boarding_school_case_study.html)

@bbbales2
Member Author

@nhuurre does this accurately reflect the todo on this:

I don't know how we could get this to work with map_rect because I do not know how map_rect works. @wds15, do you have the time or inclination to inspect this?

We're trying to plan out 2.26 (stan-dev/cmdstan#957) and this is a feature that's pretty far into development.

@nhuurre
Collaborator

nhuurre commented Dec 15, 2020

Stan Math has an integrate_dae function, but it's not exposed in the language, so I guess it doesn't need closure support.

I think the only tricky part with map_rect is supporting MPI.

I deviated quite a bit from the design-doc. Main points:

  • The doc does not mention the special suffixes _lpdf/_rng/_lp at all. These would be separate function types. I didn't allow them because they don't seem that necessary, and while _rng/_lp are straightforward to implement, the _lpdf/_lupdf distinction is a can of worms.
  • No lambda expressions because I don't think they add anything.
  • The design doc seems to imply that code like this is valid:
real(real) f;
if (sq) {
  f = (real x) { return x*x; }
} else {
  f = (real x) { return 2*x; }
}

but keep in mind that C++ lambdas do not allow this! Despite the two lambdas having the same signature their C++ types are distinct and incompatible.
Of course it would be possible to work around that but I don't see the point. Stan isn't a functional programming language that encourages passing around lots of small functions. So I'm going with no function reassignments. You write your function definition once and it's not going to change afterwards.
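The point about C++ lambdas can be checked directly. In this sketch, square and twice have the same signature but different closure types, and the static_assert confirms that. The function pick shows one possible workaround: type erasure through std::function, which costs an indirect call per evaluation. The names are illustrative, and this is not what the stanc3 branch generates.

```cpp
#include <functional>
#include <type_traits>

// Same signature, double(double), but each lambda expression gets its own
// unique, unnamed closure type.
auto square = [](double x) { return x * x; };
auto twice = [](double x) { return 2 * x; };

static_assert(!std::is_same<decltype(square), decltype(twice)>::value,
              "each lambda expression has its own closure type");

// decltype(square) f = twice;  // would not compile: incompatible types

// Workaround via type erasure: both closures convert to the same
// std::function type, so they can share one variable.
std::function<double(double)> pick(bool sq) {
  if (sq) {
    return square;
  }
  return twice;
}
```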

@bbbales2
Member Author

I deviated quite a bit from the design-doc

Cool. Reasons seem good to me. Are the known bugs fixed in the compiler now? I vaguely remember I got an ode and a reduce_sum running, but I might be misremembering the second thing :D.

If it is, I wanna build a cmdstan tarball with the mac binaries so that @bob-carpenter can test. He said he'd be down to check this against the design doc (and presumably update the design doc or suggest changes).

@nhuurre
Collaborator

nhuurre commented Dec 15, 2020

Are the known bugs fixed in the compiler now?

I did push a fix for the redcard model.

Now, looking into special suffixes, I see that this (invalid) model causes a C++ compiler error.

functions {
  real foo(real(real) f) {
    return f(0) + f(1);
  }
  real bar_rng(real s) {
    return normal_rng(s,1);
  }
}
transformed data {
  real z = foo(bar_rng);
}

But models that should compile still do compile correctly as far as I know.
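One way to read "these would be separate function types": if the suffix kind were part of the type, passing bar_rng where foo expects a plain real(real) function would be rejected up front, instead of surfacing as an unrelated C++ error. Here is a hypothetical C++ sketch; tagged_fn, plain_fn_tag, rng_fn_tag and make_fn are made-up names. A real fix would presumably live in stanc3's type checker, so that the model is rejected before any C++ is generated.

```cpp
#include <type_traits>
#include <utility>

// Tags recording which kind of Stan function a functor came from.
struct plain_fn_tag {};
struct rng_fn_tag {};

// Wraps any callable and carries its kind in the type.
template <typename Tag, typename F>
struct tagged_fn {
  F f;
  template <typename... Args>
  auto operator()(Args&&... args) const {
    return f(std::forward<Args>(args)...);
  }
};

template <typename Tag, typename F>
tagged_fn<Tag, F> make_fn(F f) {
  return tagged_fn<Tag, F>{f};
}

// Mirrors foo(real(real) f) from the model above: only plain functions
// are accepted. foo(make_fn<rng_fn_tag>(...)) fails to compile with the
// static_assert message below.
template <typename Tag, typename F>
double foo(const tagged_fn<Tag, F>& f) {
  static_assert(std::is_same<Tag, plain_fn_tag>::value,
                "foo expects a plain real(real) function, not an _rng one");
  return f(0.0) + f(1.0);
}
```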

@bbbales2
Member Author

Cool I'll build a tarball then!

@bbbales2
Member Author

Yo @bob-carpenter this contains a cmdstan & stanc3 that can compile closures: cmdstan-closures.zip

The goal here is to check the implementation with the design doc, and update stuff as needed.

To keep the zip file under 10 megabytes so GitHub would let me upload it, I removed the stan and math repos from it. Once you unzip it, you should be able to go into the folder and run the command bash checkout.sh, which will check out the stan and math repos you need to make this run.

The script is just this:

git clone --depth=1 --single-branch --branch=develop https://github.com/stan-dev/stan stan
git clone --depth=1 --single-branch --branch=feature/ode-closures https://github.com/nhuurre/math.git stan/lib/stan_math

There's a Lotka-Volterra model in this comment that you can use to test that your build is working.

@bbbales2
Member Author

@nhuurre I'm not sure what to do about map_rect still. I put up a quick pull request for integrate_1d (#2397), and it probably won't be too hard to do a similar one for the algebra solvers.

I guess the worst thing that could happen is that we make closures that aren't MPI-compatible, or are just really bad in some way. A less bad outcome would be that we make sure our closures are MPI-friendly but just don't implement them for map_rect. Best would be that we implement map_rect.

@wds15 I don't want to distract you from adjoint ODEs, so the easiest thing I can think of is that we do a Hangouts call next week or something (@nhuurre, you're invited too, of course, if you wanna come). I can walk you through closures, and then we can talk about what it'd take to make them map_rect-friendly?

@wds15
Contributor

wds15 commented Feb 27, 2021

I need to catch up on this. It looks super cool. Can we drop MPI map_rect support? Is that an option?

EDIT: So I am scheduled for a hangout next week?

@bbbales2
Member Author

So I am scheduled for a hangout next week?

Pre-Stan meeting work? Or I'm free any day at the Stan time or the hour before.

Can we drop MPI map_rect support?

I wouldn't mind doing that. I wanna just give it one try before going that direction, but I also don't wanna keep delaying closures. It would be pretty easy to just throw runtime errors in Math, and presumably stanc3 could have an error on compile ("map_rect does not support closures" or something), so it wouldn't be the worst outcome in the world. But it also might not be that bad to do map_rect, I just don't know how it works.
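Here is a sketch of the "throw runtime errors in Math" fallback. It assumes closures are generated as a distinguishable C++ type. is_closure, closure_wrapper and map_sketch are hypothetical names, not Stan Math APIs. The mpi_enabled flag stands in for whatever compile-time MPI switch would really be consulted.

```cpp
#include <stdexcept>
#include <type_traits>
#include <vector>

// Trait marking closure functors; false unless specialized.
template <typename F>
struct is_closure : std::false_type {};

// Stand-in for a stanc3-generated closure class wrapping captured state.
template <typename F>
struct closure_wrapper {
  F f;
  double operator()(double x) const { return f(x); }
};

template <typename F>
struct is_closure<closure_wrapper<F>> : std::true_type {};

// map_rect-like function: applies f elementwise, but refuses closures
// when running in MPI mode.
template <typename F>
std::vector<double> map_sketch(const F& f, const std::vector<double>& xs,
                               bool mpi_enabled) {
  if (mpi_enabled && is_closure<F>::value) {
    throw std::domain_error("map_rect: closures are not supported with MPI");
  }
  std::vector<double> out;
  out.reserve(xs.size());
  for (double x : xs) {
    out.push_back(f(x));
  }
  return out;
}
```

A stanc3 compile-time error would catch the same misuse earlier. The runtime check is the cheaper thing to add on the Math side.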

@wds15
Contributor

wds15 commented Mar 1, 2021

Today or tomorrow at 10 am your time / 16h my time would work. Maybe it's easiest to email me? CU

@rok-cesnovar
Member

I would vote against removing MPI map_rect support.

If making closures work with MPI here would be a ton of work, we could make stanc3 aware that it's compiling for MPI, like we do for OpenCL (when STAN_OPENCL is set, stanc3 is called with --use-opencl), and then report a semantic error if closures are used with map_rect while MPI is on.

@wds15
Contributor

wds15 commented Mar 1, 2021

I am not saying to cancel MPI map_rect in its current form. It's just that we may not want closures to be delayed significantly due to MPI stuff. Thus, we could roll out closures without MPI support (meaning that map_rect MPI will not work).

(map_rect with MPI would be quite nice, of course.)

@rok-cesnovar
Member

It's just that we may not want closures to be delayed significantly due to MPI stuff. Thus, we could roll out closures without MPI support (meaning that map_rect MPI will not work).

Oh, I am definitely on board with that.

@bbbales2
Member Author

bbbales2 commented Mar 1, 2021

Talked to @wds15 this morning about map_rect. It looks difficult. I think we would need to be able to serialize the functors somehow and also handle the optimization where data is only shipped once to the worker nodes (neither of which is straightforward).

Let's do integrate_1d first (#2397), and then algebra_solver, and then come back to make a try at map_rect.
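Here is a rough illustration of what "serialize the functors" could mean. It assumes a closure's only state is its captured values. Those values are flattened into a buffer that could be shipped to a worker once and rebuilt there. lv_closure, serialize and deserialize are hypothetical names. This only covers double-valued captures; captured autodiff variables would also need their gradients routed back, which is part of why map_rect looks hard.

```cpp
#include <vector>

// Hand-written stand-in for a generated closure over alpha..delta.
struct lv_closure {
  double alpha, beta, gamma, delta;  // captured values

  std::vector<double> operator()(double /*t*/,
                                 const std::vector<double>& z) const {
    const double u = z[0];
    const double v = z[1];
    return {(alpha - beta * v) * u, (-gamma + delta * u) * v};
  }

  // Flatten the captured state into a plain buffer (e.g. for MPI).
  std::vector<double> serialize() const { return {alpha, beta, gamma, delta}; }

  // Rebuild an equivalent closure from the buffer on the receiving side.
  static lv_closure deserialize(const std::vector<double>& buf) {
    return lv_closure{buf[0], buf[1], buf[2], buf[3]};
  }
};
```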

@wds15
Contributor

wds15 commented Mar 1, 2021

@nhuurre Have you already done a bit of benchmarking? This feature screams for application to ODEs... but those are super performance-critical, and I would be scared to put any friction into the turning wheels of the ODE integrators.

Other than that this looks amazing (and there are other domains than ODEs where this is highly useful)!

@SteveBronder
Collaborator

Not a biggie, but just wondering: is it possible in the spec to have anonymous lambdas like

  vector[2] z[N] = ode_rk45_tol(functions vector (real t, vector z) {
    real u = z[1];
    real v = z[2];

    real du_dt = (alpha - beta * v) * u;
    real dv_dt = (-gamma + delta * u) * v;

    return [ du_dt, dv_dt ]';
  }, z_init, 0., ts, 1e-5, 1e-3, 500);

If the compiler could handle type checking of the inputs and return type, in C++ this would just translate to

  std::vector<Eigen::Matrix<scalar_t, -1, 1>> z =
    ode_rk45_tol([&](auto&& t, auto&& z) {
      // compiled body same as current impl; alpha, beta, gamma, delta
      // are captured by reference instead of passed as arguments
  }, z_init, 0., ts, 1e-5, 1e-3, 500);
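Here is a sketch of why that translation needs a capturing generic lambda rather than a plain one. The body refers to alpha, beta, gamma and delta, so the lambda must capture them (here with [&]). The auto parameters let one body be instantiated for whatever scalar type the solver uses. eval_rhs and lv_inline are hypothetical stand-ins, not the Stan Math ODE interface, and float merely plays the role of a second scalar type.

```cpp
#include <type_traits>
#include <vector>

// Hypothetical stand-in for an ODE solver: it only needs to call f(t, z).
template <typename F, typename T>
std::vector<T> eval_rhs(const F& f, T t, const std::vector<T>& z) {
  return f(t, z);
}

template <typename T>
std::vector<T> lv_inline(T alpha, T beta, T gamma, T delta,
                         const std::vector<T>& z0) {
  // Anonymous generic lambda passed straight to the "solver". [&] is safe
  // here because the lambda does not outlive this call.
  return eval_rhs(
      [&](const auto& /*t*/, const auto& z) {
        // Scalar type of the state, deduced from the argument.
        using S = typename std::decay_t<decltype(z)>::value_type;
        const S u = z[0];
        const S v = z[1];
        return std::vector<S>{(alpha - beta * v) * u, (-gamma + delta * u) * v};
      },
      T(0), z0);
}
```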

5 participants