|
2 | 2 | from aesara.graph.rewriting.basic import in2out, node_rewriter
|
3 | 3 | from aesara.graph.rewriting.db import LocalGroupDB
|
4 | 4 | from aesara.graph.rewriting.unify import eval_if_etuple
|
5 |
| -from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV |
| 5 | +from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV, UniformRV |
6 | 6 | from etuples import etuple, etuplize
|
7 | 7 | from kanren import eq, lall, run
|
8 | 8 | from unification import var
|
@@ -268,13 +268,99 @@ def local_beta_negative_binomial_posterior(fgraph, node):
|
268 | 268 | return rv_var.owner.outputs
|
269 | 269 |
|
270 | 270 |
|
| 271 | +def uniform_pareto_conjugateo(observed_val, observed_rv_expr, posterior_expr): |
| 272 | + r"""Produce a goal that represents the application of Bayes theorem |
| 273 | + for a pareto prior with a uniform with 0 as the lower bound observation model. |
| 274 | +
|
| 275 | + .. math:: |
| 276 | + Y \sim \operatorname{Uniform}\left(0, \theta\right) |
| 277 | +
|
| 278 | +
|
| 279 | +
|
| 280 | + Parameters |
| 281 | + ---------- |
| 282 | + observed_val |
| 283 | + The observed value. |
| 284 | + observed_rv_expr |
| 285 | + An expression that represents the observed variable. |
| 286 | + posterior_exp |
| 287 | + An expression that represents the posterior distribution of the latent |
| 288 | + variable. |
| 289 | +
|
| 290 | + """ |
| 291 | + # beta-negative_binomial observation model |
| 292 | + x_lv, k_lv = var(), var() |
| 293 | + theta_rng_lv = var() |
| 294 | + theta_size_lv = var() |
| 295 | + theta_type_idx_lv = var() |
| 296 | + theta_et = etuple( |
| 297 | + etuplize(at.random.pareto), |
| 298 | + theta_rng_lv, |
| 299 | + theta_size_lv, |
| 300 | + theta_type_idx_lv, |
| 301 | + k_lv, |
| 302 | + x_lv, |
| 303 | + ) |
| 304 | + Y_et = etuple(etuplize(at.random.beta), var(), var(), var(), 1, theta_et) |
| 305 | + |
| 306 | + # new_x_et = at.max(observed_val) |
| 307 | + new_x_et = at.max(observed_val, x_lv) |
| 308 | + new_k_et = etuple(etuplize(at.add), k_lv, 1) |
| 309 | + |
| 310 | + theta_posterior_et = etuple( |
| 311 | + etuplize(at.random.pareto), |
| 312 | + new_k_et, |
| 313 | + new_x_et, |
| 314 | + rng=theta_rng_lv, |
| 315 | + size=theta_size_lv, |
| 316 | + dtype=theta_type_idx_lv, |
| 317 | + ) |
| 318 | + |
| 319 | + return lall( |
| 320 | + eq(observed_rv_expr, Y_et), |
| 321 | + eq(posterior_expr, theta_posterior_et), |
| 322 | + ) |
| 323 | + |
| 324 | + |
| 325 | +@node_rewriter([UniformRV]) |
| 326 | +def local_uniform_pareto_posterior(fgraph, node): |
| 327 | + sampler_mappings = getattr(fgraph, "sampler_mappings", None) |
| 328 | + |
| 329 | + rv_var = node.outputs[1] |
| 330 | + key = ("local_beta_negative_binomial_posterior", rv_var) |
| 331 | + |
| 332 | + if sampler_mappings is None or key in sampler_mappings.rvs_seen: |
| 333 | + return None # pragma: no cover |
| 334 | + |
| 335 | + q = var() |
| 336 | + |
| 337 | + rv_et = etuplize(rv_var) |
| 338 | + |
| 339 | + res = run(None, q, uniform_pareto_conjugateo(rv_var, rv_et, q)) |
| 340 | + res = next(res, None) |
| 341 | + |
| 342 | + if res is None: |
| 343 | + return None # pragma: no cover |
| 344 | + |
| 345 | + pareto_rv = rv_et[-1].evaled_obj |
| 346 | + pareto_posterior = eval_if_etuple(res) |
| 347 | + |
| 348 | + sampler_mappings.rvs_to_samplers.setdefault(pareto_rv, []).append( |
| 349 | + ("local_uniform_pareto_posterior", pareto_posterior, None) |
| 350 | + ) |
| 351 | + sampler_mappings.rvs_seen.add(key) |
| 352 | + |
| 353 | + return rv_var.owner.outputs |
| 354 | + |
| 355 | + |
271 | 356 | conjugates_db = LocalGroupDB(apply_all_rewrites=True)
|
272 | 357 | conjugates_db.name = "conjugates_db"
|
273 | 358 | conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
|
274 | 359 | conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic")
|
275 | 360 | conjugates_db.register(
|
276 | 361 | "negative_binomial", local_beta_negative_binomial_posterior, "basic"
|
277 | 362 | )
|
| 363 | +conjugates_db.register("uniform", local_uniform_pareto_posterior, "basic") |
278 | 364 |
|
279 | 365 |
|
280 | 366 | sampler_finder_db.register(
|
|
0 commit comments