Skip to content

Commit

Permalink
Make more distributions symbolic so they work in different backends
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Feb 19, 2025
1 parent ce5f2a2 commit 9168e75
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 154 deletions.
12 changes: 8 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ jobs:
name: ${{ matrix.os }} ${{ matrix.floatx }}
fail_ci_if_error: false

external_samplers:
alternative_backends:
needs: changes
if: ${{ needs.changes.outputs.changes == 'true' }}
strategy:
Expand All @@ -293,7 +293,11 @@ jobs:
floatx: [float64]
python-version: ["3.12"]
test-subset:
- tests/sampling/test_jax.py tests/sampling/test_mcmc_external.py
- |
tests/distributions/test_random_alternative_backends.py
tests/sampling/test_jax.py
tests/sampling/test_mcmc_external.py
fail-fast: false
runs-on: ${{ matrix.os }}
env:
Expand All @@ -308,7 +312,7 @@ jobs:
persist-credentials: false
- uses: mamba-org/setup-micromamba@v2
with:
environment-file: conda-envs/environment-jax.yml
environment-file: conda-envs/environment-alternative-backends.yml
create-args: >-
python=${{matrix.python-version}}
environment-name: pymc-test
Expand All @@ -327,7 +331,7 @@ jobs:
with:
token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads
env_vars: TEST_SUBSET
name: JAX tests - ${{ matrix.os }} ${{ matrix.floatx }}
name: Alternative backend tests - ${{ matrix.os }} ${{ matrix.floatx }}
fail_ci_if_error: false

Check warning

Code scanning / zizmor

overly broad permissions Warning test

overly broad permissions

float32:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:
- cloudpickle
- h5py>=2.7
- zarr>=2.5.0,<3
- nutpie >= 0.13.4
# Jaxlib version must not be greater than jax version!
- blackjax>=1.2.2
- jax>=0.4.28
Expand Down
29 changes: 17 additions & 12 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2595,23 +2595,27 @@ def dist(cls, nu, **kwargs):
return Gamma.dist(alpha=nu / 2, beta=1 / 2, **kwargs)


class WeibullBetaRV(RandomVariable):
class WeibullBetaRV(SymbolicRandomVariable):
name = "weibull"
signature = "(),()->()"
dtype = "floatX"
extended_signature = "[rng],[size],(),()->[rng],()"
_print_name = ("Weibull", "\\operatorname{Weibull}")

def __call__(self, alpha, beta, size=None, **kwargs):
return super().__call__(alpha, beta, size=size, **kwargs)

@classmethod
def rng_fn(cls, rng, alpha, beta, size) -> np.ndarray:
if size is None:
size = np.broadcast_shapes(alpha.shape, beta.shape)
return np.asarray(beta * rng.weibull(alpha, size=size))
def rv_op(cls, alpha, beta, *, rng=None, size=None) -> np.ndarray:
alpha = pt.as_tensor(alpha)
beta = pt.as_tensor(beta)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

if rv_size_is_none(size):
size = pt.broadcast_shapes(alpha.shape, beta.shape)

weibull_beta = WeibullBetaRV()
next_rng, raw_weibull = pt.random.weibull(alpha, size=size).owner.outputs
draws = beta * raw_weibull
return cls(
inputs=[rng, size, alpha, beta],
outputs=[next_rng, draws],
)(rng, size, alpha, beta)


class Weibull(PositiveContinuous):
Expand Down Expand Up @@ -2660,7 +2664,8 @@ class Weibull(PositiveContinuous):
Scale parameter (beta > 0).
"""

rv_op = weibull_beta
rv_type = WeibullBetaRV
rv_op = WeibullBetaRV.rv_op

@classmethod
def dist(cls, alpha, beta, *args, **kwargs):
Expand Down
Loading

0 comments on commit 9168e75

Please sign in to comment.