Skip to content

Commit af45283

Browse files
author
Jing Xie
committed
Add Uniform pareto conjugates
1 parent 64b0e50 commit af45283

File tree

2 files changed

+110
-1
lines changed

2 files changed

+110
-1
lines changed

aemcmc/conjugates.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from aesara.graph.rewriting.basic import in2out, node_rewriter
33
from aesara.graph.rewriting.db import LocalGroupDB
44
from aesara.graph.rewriting.unify import eval_if_etuple
5-
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV
5+
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV, UniformRV
66
from etuples import etuple, etuplize
77
from kanren import eq, lall, run
88
from unification import var
@@ -268,13 +268,99 @@ def local_beta_negative_binomial_posterior(fgraph, node):
268268
return rv_var.owner.outputs
269269

270270

271+
def uniform_pareto_conjugateo(observed_val, observed_rv_expr, posterior_expr):
272+
r"""Produce a goal that represents the application of Bayes theorem
273+
for a pareto prior with a uniform with 0 as the lower bound observation model.
274+
275+
.. math::
276+
Y \sim \operatorname{Uniform}\left(0, \theta\right)
277+
278+
279+
280+
Parameters
281+
----------
282+
observed_val
283+
The observed value.
284+
observed_rv_expr
285+
An expression that represents the observed variable.
286+
posterior_exp
287+
An expression that represents the posterior distribution of the latent
288+
variable.
289+
290+
"""
291+
# beta-negative_binomial observation model
292+
x_lv, k_lv = var(), var()
293+
theta_rng_lv = var()
294+
theta_size_lv = var()
295+
theta_type_idx_lv = var()
296+
theta_et = etuple(
297+
etuplize(at.random.pareto),
298+
theta_rng_lv,
299+
theta_size_lv,
300+
theta_type_idx_lv,
301+
k_lv,
302+
x_lv,
303+
)
304+
Y_et = etuple(etuplize(at.random.beta), var(), var(), var(), 1, theta_et)
305+
306+
# new_x_et = at.max(observed_val)
307+
new_x_et = at.max(observed_val, x_lv)
308+
new_k_et = etuple(etuplize(at.add), k_lv, 1)
309+
310+
theta_posterior_et = etuple(
311+
etuplize(at.random.pareto),
312+
new_k_et,
313+
new_x_et,
314+
rng=theta_rng_lv,
315+
size=theta_size_lv,
316+
dtype=theta_type_idx_lv,
317+
)
318+
319+
return lall(
320+
eq(observed_rv_expr, Y_et),
321+
eq(posterior_expr, theta_posterior_et),
322+
)
323+
324+
325+
@node_rewriter([UniformRV])
326+
def local_uniform_pareto_posterior(fgraph, node):
327+
sampler_mappings = getattr(fgraph, "sampler_mappings", None)
328+
329+
rv_var = node.outputs[1]
330+
key = ("local_beta_negative_binomial_posterior", rv_var)
331+
332+
if sampler_mappings is None or key in sampler_mappings.rvs_seen:
333+
return None # pragma: no cover
334+
335+
q = var()
336+
337+
rv_et = etuplize(rv_var)
338+
339+
res = run(None, q, uniform_pareto_conjugateo(rv_var, rv_et, q))
340+
res = next(res, None)
341+
342+
if res is None:
343+
return None # pragma: no cover
344+
345+
pareto_rv = rv_et[-1].evaled_obj
346+
pareto_posterior = eval_if_etuple(res)
347+
348+
sampler_mappings.rvs_to_samplers.setdefault(pareto_rv, []).append(
349+
("local_uniform_pareto_posterior", pareto_posterior, None)
350+
)
351+
sampler_mappings.rvs_seen.add(key)
352+
353+
return rv_var.owner.outputs
354+
355+
271356
conjugates_db = LocalGroupDB(apply_all_rewrites=True)
272357
conjugates_db.name = "conjugates_db"
273358
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
274359
conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic")
275360
conjugates_db.register(
276361
"negative_binomial", local_beta_negative_binomial_posterior, "basic"
277362
)
363+
conjugates_db.register("uniform", local_uniform_pareto_posterior, "basic")
278364

279365

280366
sampler_finder_db.register(

tests/test_conjugates.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
beta_binomial_conjugateo,
1111
beta_negative_binomial_conjugateo,
1212
gamma_poisson_conjugateo,
13+
uniform_pareto_conjugateo,
1314
)
1415

1516

@@ -157,3 +158,25 @@ def test_beta_negative_binomial_conjugate_expand():
157158
expanded = eval_if_etuple(expanded_expr)
158159

159160
assert isinstance(expanded.owner.op, type(at.random.beta))
161+
162+
163+
def test_uniform_pareto_conjugate_contract():
164+
"""Produce the closed-form posterior for the uniform observation model with
165+
a pareto prior.
166+
167+
"""
168+
srng = RandomStream(0)
169+
170+
xm_tt = at.scalar("xm")
171+
k_tt = at.scalar("k")
172+
theta_rv = srng.pareto(k_tt, xm_tt, name="theta")
173+
174+
Y_rv = srng.uniform(0, theta_rv)
175+
y_vv = Y_rv.clone()
176+
y_vv.tag.name = "y"
177+
178+
q_lv = var()
179+
(posterior_expr,) = run(1, q_lv, uniform_pareto_conjugateo(y_vv, Y_rv, q_lv))
180+
posterior = eval_if_etuple(posterior_expr)
181+
182+
assert isinstance(posterior.owner.op, type(at.random.pareto))

0 commit comments

Comments
 (0)