Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend-agnostic ABC #28

Merged
merged 4 commits into from
Mar 21, 2025
Merged

Backend-agnostic ABC #28

merged 4 commits into from
Mar 21, 2025

Conversation

willGraham01
Copy link
Collaborator

@willGraham01 willGraham01 commented Mar 20, 2025

Concerns #8 |

Adds a base class for backend-agnostic objects that will be needed in the package.

A BackendAgnostic class is setup such that it must have certain attributes, and a reference to some _backend_obj that will actually be doing the heavy-lifting for accessing / computing these attributes. This allows us to define methods on the BackendAgnostic class that conform to our expected frontend behaviours / syntax, and have these methods translate the arguments provided by the frontend into the arguments that need to be passed to the _backend_obj.

The most obvious example of these objects will be Distributions - these are required to have a sample method, but different backends (jax.random, distrax, numpyro, etc) have different syntaxes for sampling from their distributions. A BackendAgnostic distribution necessitates having a sample method, but for the different backends it would be implemented differently;

class Distribution(BackendAgnostic):

    @property
    def _frontend_provides(self):
        return ("sample",)

class JaxDistribution(Distribution):

    def sample(self, rng_key, sample_shape):
        return self._backend_obj.sample(rng_key, sample_shape)

class NumPyroDistribution(Distribution):

    def sample(self, rng_key, sample_shape):
        return self._backend_obj.sample(key=rng_key, sample_shape=sample_shape)

NOTE: This will be one of several "breakdown" PRs that all-together form #23.

NOTE 2: In light if #24, the jax-specific rng_keys should eventually be replaced with some backend-agnostic class that accepts a .split method.

@willGraham01 willGraham01 marked this pull request as ready for review March 20, 2025 10:55
@willGraham01 willGraham01 requested a review from mscroggs March 20, 2025 10:57
@mscroggs mscroggs mentioned this pull request Mar 20, 2025
@willGraham01 willGraham01 merged commit 65e26ee into main Mar 21, 2025
5 checks passed
@willGraham01 willGraham01 deleted the wgraham/backend-agnostic-abc branch March 21, 2025 09:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants