From 64b0e508124efa919649862f0483f9f7c2e90e71 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sat, 18 Feb 2023 17:23:09 -0600 Subject: [PATCH] Updates for new Aesara deprecations and AePPL version --- aemcmc/transforms.py | 2 +- environment.yml | 4 ++-- setup.py | 4 ++-- tests/test_basic.py | 1 - tests/test_conjugates.py | 1 - tests/test_rewriting.py | 2 +- 6 files changed, 6 insertions(+), 8 deletions(-) diff --git a/aemcmc/transforms.py b/aemcmc/transforms.py index ad4d56b..106d586 100644 --- a/aemcmc/transforms.py +++ b/aemcmc/transforms.py @@ -113,7 +113,7 @@ def invgamma_exponential(invgamma_expr, invexponential_expr): size=size_lv, dtype=dtype_lv, ) - invexponential_et = etuple(at.true_div, at.as_tensor(1.0), exponential_et) + invexponential_et = etuple(at.true_divide, at.as_tensor(1.0), exponential_et) return lall( eq(invgamma_expr, invgamma_et), eq(invexponential_expr, invexponential_et) diff --git a/environment.yml b/environment.yml index 19e87f1..7628a9f 100644 --- a/environment.yml +++ b/environment.yml @@ -11,8 +11,8 @@ dependencies: - compilers - numpy>=1.18.1 - scipy>=1.4.0 - - aesara>=2.8.3 - - aeppl>=0.1.0 + - aesara>=2.8.11 + - aeppl>=0.1.2 - aehmc>=0.0.10 - polyagamma>=1.3.2 - cons diff --git a/setup.py b/setup.py index b78c690..d5eb997 100644 --- a/setup.py +++ b/setup.py @@ -38,8 +38,8 @@ def get_versions(): install_requires=[ "numpy>=1.18.1", "scipy>=1.4.0", - "aesara>=2.8.3", - "aeppl>=0.1.0", + "aesara>=2.8.11", + "aeppl>=0.1.2", "aehmc>=0.0.10", "polyagamma>=1.3.2", "cons", diff --git a/tests/test_basic.py b/tests/test_basic.py index 936328a..37ae013 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -145,7 +145,6 @@ def test_nuts_with_closed_form(): assert beta_rv in sampler.sample_steps -@pytest.mark.xfail(reason="An `OpFromGraph` fix is needed to remove extra updates") def test_create_gibbs(): srng = RandomStream(0) diff --git a/tests/test_conjugates.py b/tests/test_conjugates.py index 64a6d00..a86fa06 100644 --- a/tests/test_conjugates.py +++ b/tests/test_conjugates.py @@ -31,7 +31,6 @@ def test_gamma_poisson_conjugate_contract(): q_lv = var() (posterior_expr,) = run(1, q_lv, gamma_poisson_conjugateo(y_vv, Y_rv, q_lv)) posterior = eval_if_etuple(posterior_expr) - aesara.dprint(posterior) assert isinstance(posterior.owner.op, type(at.random.gamma)) diff --git a/tests/test_rewriting.py b/tests/test_rewriting.py index 41b0795..07945f8 100644 --- a/tests/test_rewriting.py +++ b/tests/test_rewriting.py @@ -119,7 +119,7 @@ def test_SubsumingElemwise_constant_inputs(): srng = at.random.RandomStream(0) s = at.lscalar("s") - # The `1` is the constant input to a `true_div` `Elemwise` that should be + # The `1` is the constant input to a `true_divide` `Elemwise` that should be # "subsumed" Z = srng.exponential(1, size=s, name="Z") mu = 1 / Z