Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix jax backend sampling with variable names that are not valid identifiers #135

Merged
merged 3 commits into from
Jul 9, 2024

Conversation

aseyboldt
Copy link
Member

No description provided.

@@ -127,18 +127,34 @@ def test_det(backend, gradient_backend):
assert trace.posterior.b.shape[-1] == 2


@parameterize_backends
def test_non_identifier_names(backend, gradient_backend):
Copy link
Member

@ricardoV94 ricardoV94 Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are non identifier names? What was failing?

Copy link
Member Author

@aseyboldt aseyboldt Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we have a model like

    with pm.Model() as model:
        a = pm.Data("a/b", shape=2)

or a nested model, then the variables in the generated logp function have different names (they can't be called a/b after all...).
But I was using kwargs with the real variable names as keys to pass the values to the logp function.
So I just switched to positional args, since we know the order.

Copy link
Member

@ricardoV94 ricardoV94 Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call the test invalid_kwarg_names?

How do we do it in PyMC when we have the point func with names like this? Seems like it should also fail because we unpack a dict, or that somehow works?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shared variables are typically hidden in the pytensor function, and that uses positional args.
And the strange pymc point functions also assume things are in the right order (which has bitten me more than once...)
I called it non-identifier because this happens more or less if name.isidentifier() is false.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the point func just called the underlying function with **state. It's true it doesn't have to interact with shareds, but the name thing should also be an issue there?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do we convert it to positional arguments? Inside PyTensor? Because here it's just unpacking the dict: https://github.com/pymc-devs/pymc/blob/main/pymc%2Fpytensorf.py#L617-L624

Copy link
Member

@ricardoV94 ricardoV94 Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyway I thought the problem was you cant unpack non identifiers in Python

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, no limitation anywhere :-)

It is perfectly legal in python (and I think also in jax) to have a function

def foo(**kwargs):
    pass

And then pass it something like {"a/b": 1}. As long as foo takes the kwargs as **kwargs that's allowed.

The problem is that the jax function we generate in dispatch generates something like

def jax_dispatch(a_b):
    pass

and we can't call that function with foo(**{"a/b": 1}), because the name of the variable and the key in the dictionary don't match.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm right I didn't expect those internal variables to ever be called with kwargs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or for the names to even respect those from the original graph

@aseyboldt aseyboldt merged commit 523c193 into pymc-devs:main Jul 9, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants