Skip to content

Commit ec3e350

Browse files
author
Jing Xie
committed
Add Uniform pareto conjugates
1 parent 64b0e50 commit ec3e350

File tree

2 files changed

+132
-1
lines changed

2 files changed

+132
-1
lines changed

aemcmc/conjugates.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from aesara.graph.rewriting.basic import in2out, node_rewriter
33
from aesara.graph.rewriting.db import LocalGroupDB
44
from aesara.graph.rewriting.unify import eval_if_etuple
5-
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV
5+
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV, UniformRV
66
from etuples import etuple, etuplize
77
from kanren import eq, lall, run
88
from unification import var
@@ -268,13 +268,97 @@ def local_beta_negative_binomial_posterior(fgraph, node):
268268
return rv_var.owner.outputs
269269

270270

271+
def uniform_pareto_conjugateo(observed_val, observed_rv_expr, posterior_expr):
272+
r"""Produce a goal that represents the application of Bayes theorem
273+
for a pareto prior with a uniform with 0 as the lower bound observation model.
274+
275+
.. math::
276+
Y \sim \operatorname{Uniform}\left(0, \theta\right)
277+
278+
279+
280+
Parameters
281+
----------
282+
observed_val
283+
The observed value.
284+
observed_rv_expr
285+
An expression that represents the observed variable.
286+
posterior_exp
287+
An expression that represents the posterior distribution of the latent
288+
variable.
289+
290+
"""
291+
# beta-negative_binomial observation model
292+
x_lv, k_lv = var(), var()
293+
theta_rng_lv = var()
294+
theta_size_lv = var()
295+
theta_type_idx_lv = var()
296+
theta_et = etuple(
297+
etuplize(at.random.pareto),
298+
theta_rng_lv,
299+
theta_size_lv,
300+
theta_type_idx_lv,
301+
k_lv,
302+
x_lv,
303+
)
304+
Y_et = etuple(etuplize(at.random.uniform), var(), var(), var(), var(), theta_et)
305+
306+
new_x_et = etuple(at.math.max, observed_val)
307+
new_k_et = etuple(etuplize(at.add), k_lv, 1)
308+
309+
theta_posterior_et = etuple(
310+
etuplize(at.random.pareto),
311+
new_k_et,
312+
new_x_et,
313+
rng=theta_rng_lv,
314+
size=theta_size_lv,
315+
dtype=theta_type_idx_lv,
316+
)
317+
return lall(
318+
eq(observed_rv_expr, Y_et),
319+
eq(posterior_expr, theta_posterior_et),
320+
)
321+
322+
323+
@node_rewriter([UniformRV])
324+
def local_uniform_pareto_posterior(fgraph, node):
325+
sampler_mappings = getattr(fgraph, "sampler_mappings", None)
326+
327+
rv_var = node.outputs[1]
328+
key = ("local_beta_negative_binomial_posterior", rv_var)
329+
330+
if sampler_mappings is None or key in sampler_mappings.rvs_seen:
331+
return None # pragma: no cover
332+
333+
q = var()
334+
335+
rv_et = etuplize(rv_var)
336+
337+
res = run(None, q, uniform_pareto_conjugateo(rv_var, rv_et, q))
338+
res = next(res, None)
339+
340+
if res is None:
341+
return None # pragma: no cover
342+
343+
pareto_rv = rv_et[-1].evaled_obj
344+
pareto_posterior = eval_if_etuple(res)
345+
346+
sampler_mappings.rvs_to_samplers.setdefault(pareto_rv, []).append(
347+
("local_uniform_pareto_posterior", pareto_posterior, None)
348+
)
349+
sampler_mappings.rvs_seen.add(key)
350+
351+
return rv_var.owner.outputs
352+
353+
271354
conjugates_db = LocalGroupDB(apply_all_rewrites=True)
272355
conjugates_db.name = "conjugates_db"
273356
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
274357
conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic")
275358
conjugates_db.register(
276359
"negative_binomial", local_beta_negative_binomial_posterior, "basic"
277360
)
361+
conjugates_db.register("uniform", local_uniform_pareto_posterior, "basic")
278362

279363

280364
sampler_finder_db.register(

tests/test_conjugates.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
beta_binomial_conjugateo,
1111
beta_negative_binomial_conjugateo,
1212
gamma_poisson_conjugateo,
13+
uniform_pareto_conjugateo,
1314
)
1415

1516

@@ -157,3 +158,49 @@ def test_beta_negative_binomial_conjugate_expand():
157158
expanded = eval_if_etuple(expanded_expr)
158159

159160
assert isinstance(expanded.owner.op, type(at.random.beta))
161+
162+
163+
def test_uniform_pareto_conjugate_contract():
164+
"""Produce the closed-form posterior for the uniform observation model with
165+
a pareto prior.
166+
167+
"""
168+
srng = RandomStream(0)
169+
170+
xm_tt = at.scalar("xm")
171+
k_tt = at.scalar("k")
172+
theta_rv = srng.pareto(k_tt, xm_tt, name="theta")
173+
174+
# zero = at.iscalar("zero")
175+
Y_rv = srng.uniform(0, theta_rv)
176+
y_vv = Y_rv.clone()
177+
y_vv.tag.name = "y"
178+
179+
q_lv = var()
180+
(posterior_expr,) = run(1, q_lv, uniform_pareto_conjugateo(y_vv, Y_rv, q_lv))
181+
posterior = eval_if_etuple(posterior_expr)
182+
183+
assert isinstance(posterior.owner.op, type(at.random.pareto))
184+
185+
# Build the sampling function and check the results on limiting cases.
186+
sample_fn = aesara.function((xm_tt, k_tt, y_vv), posterior)
187+
assert sample_fn(1.0, 1000, 1) == pytest.approx(1.0, abs=0.01) # k = 1000
188+
assert sample_fn(1.0, 1, 0) == pytest.approx(0.0, abs=0.01) # all zeros
189+
190+
191+
def test_uniform_pareto_binomial_conjugate_expand():
192+
"""Expand a contracted beta-binomial observation model."""
193+
194+
srng = RandomStream(0)
195+
196+
k_tt = at.scalar("k")
197+
y_vv = at.iscalar("y")
198+
n_tt = at.scalar("n")
199+
200+
Y_rv = srng.pareto(at.max(y_vv), k_tt + n_tt)
201+
202+
e_lv = var()
203+
(expanded_expr,) = run(1, e_lv, uniform_pareto_conjugateo(e_lv, y_vv, Y_rv))
204+
expanded = eval_if_etuple(expanded_expr)
205+
206+
assert isinstance(expanded.owner.op, type(at.random.pareto))

0 commit comments

Comments
 (0)