Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make distributions and graph talk to each other #30

Open
wants to merge 26 commits into
base: main
Choose a base branch
from

Conversation

mscroggs
Copy link
Collaborator

@mscroggs mscroggs commented Mar 20, 2025

Builds on top of #16.

Removes placeholder distribution classes in graph, and plugs in the proper distribution classes in their place

@mscroggs mscroggs marked this pull request as draft March 20, 2025 11:14
@mscroggs mscroggs changed the base branch from main to mscroggs/normal-example March 20, 2025 11:14
outcome_node_label: str,
samples: int,
*,
rng_key: jax.Array,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #24

concrete_dist = self._dist.construct(
**parameters, **self._constant_parameters
)
output[sample] = concrete_dist.sample(new_key[sample], 1)[0][0]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#25

return concrete_dist.sample(rng_key, samples)
output = np.zeros(samples)
new_key = jax.random.split(rng_key, samples)
for sample in range(samples):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#26

new_key = jax.random.split(rng_key, samples)
for sample in range(samples):
parameters = {
i: sampled_dependencies[j][sample] for i, j in self._parameters.items()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#29

Base automatically changed from mscroggs/normal-example to main March 20, 2025 14:09
@mscroggs mscroggs marked this pull request as ready for review March 20, 2025 14:58
@mscroggs mscroggs requested a review from willGraham01 March 20, 2025 14:58
@@ -76,7 +78,7 @@ def __init__(self) -> None:
"""Create a family of normal distributions."""
super().__init__(Normal, family_name="Normal")

def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal:
def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy was very unhappy with everything I tried putting here. @willGraham01: Any idea what it should be?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, we originally had this as the following, but either ruff or mypy was not happy that the base class didn't have the *,, but having *, **parameters in the base class is a syntax error...

Suggested change
def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003
def construct(self, *, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant