jitting a module with segment sums #903

Open
errhernandez opened this issue Nov 28, 2024 · 8 comments
Labels
question User queries

Comments

@errhernandez

Hello; I am new to equinox and the jax ecosystem, so apologies if this is a trivial question, but I am having some difficulty in sorting it out:

I have written an eqx.Module and would like to use jit in its training; the problem is that the module uses jax.ops.segment_sum in several places. According to the JAX documentation this is OK, provided that the num_segments argument to segment_sum is static (which I interpret as constant across all samples in the dataset). I have done this by constructing batches of samples in such a way that num_segments is indeed the same for every batch. However, when I decorate my training step function with @eqx.filter_jit it is slower than not jitting at all, and I suspect this is because jit recompiles every time, assuming that num_segments changes from batch to batch when it does not.

I guess I need to declare those num_segments values as static somewhere, perhaps as an option to @eqx.filter_jit, but I have not found how to do it; can someone point me in the right direction? Thanks!

And while I am at it: is there a way to confirm that jit is recompiling at every step? At present this is my impression (I was expecting my jitted code to be much faster than the unjitted), but I am not 100% sure.

Thank you in advance for any hints received!

@patrick-kidger
Owner

Welcome to JAX and Equinox! :)

You can check recompilations using eqx.debug.assert_max_traces. This will also tell you which argument caused the recompilation.

Note that recompilation is triggered by passing new static arguments (non-arrays) through eqx.filter_jit, or by changing the shape/dtype of the arrays passed to eqx.filter_jit. It is not triggered by anything inside segment_sum. Instead, if a traced (non-static) argument is passed to segment_sum then it should simply raise an error during compilation.

As such you should be able to isolate the cause of recompilations solely by checking what is passed across the JIT boundary :) (e.g. using eqx.debug.assert_max_traces or by any other means)

I hope that helps!

@patrick-kidger patrick-kidger added the question User queries label Nov 29, 2024
@errhernandez
Author

Thanks for your response. By reading your excellent documentation more thoroughly I had found out about eqx.debug.assert_max_traces, so I can confirm that recompilations were indeed happening. I had assumed this was due to the segment_sums in my model, because the JAX documentation says that to jit a segment_sum one must ensure that num_segments does not change; so I have been chasing a red herring? It surprises me that the suspects are passing new static arguments or changing shapes/dtypes, as you suggest, but I will make sure of this. Thanks again!

@errhernandez
Author

Sorry! Me again!

When you say "Note that recompilation is triggered by passing new static arguments (non-arrays) through eqx.filter_jit", do I interpret this correctly to mean that the arguments have to be arrays? Is a dict not allowed? I am certainly passing a number of book-keeping arrays, but bundled up in a dictionary; the sizes and shapes of those book-keeping arrays are the same from sample to sample, but their values inevitably change. That shouldn't trigger recompilation, should it?

@johannahaffner

Hi Eduardo, Hi Patrick,

a dictionary with arrays is a PyTree of Arrays, this should not be a problem. However, it is important that the numbers in your dictionary are actually wrapped in arrays, like so:

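```python
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, PyTree

@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def some_function(x: Array, book: PyTree[Array]) -> tuple[Array, PyTree[Array]]:
    y = 2 * x
    book['a'] += 1
    return y, book

x = jnp.linspace(0, 1, 10)

dict_of_arrays = {'a': jnp.array(1), 'b': jnp.array(2)}
dict_of_numbers = {'a': 1, 'b': 2}

# Arrays: only their shape/dtype enters the compiled signature, so the
# changing value of book['a'] does not cause a retrace.
for _ in range(2):
    y, dict_of_arrays = some_function(x, dict_of_arrays)

# Python numbers: filter_jit treats them as static, so each new value means
# a new trace, and assert_max_traces raises here.
for _ in range(2):
    y, dict_of_numbers = some_function(x, dict_of_numbers)
```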

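Note that passing a dictionary of plain Python numbers will trigger recompilation: eqx.filter_jit treats non-array leaves as static, so every new value is compiled separately.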

@errhernandez
Copy link
Author

Thank you @johannahaffner for your feedback. Indeed, at least one of the items in my dict of indices was itself a dict of numbers; I was not aware that this would trigger recompilation, so I have changed it following your remarks. I think I am closing in on the solution thanks to your and @patrick-kidger's comments, but the problem still persists; I need to try a little harder.

Specifically, my application is a graph neural network, and more specifically a graph regression. I batch my graphs together so that all batches have exactly the same total number of nodes and the same total number of edges (I did this thinking it would help with segment_sum, although according to @patrick-kidger that is not the issue). However, there remain other book-keeping index arrays, tracking which node sends messages to which other nodes (and the same for edges), that cannot easily be made the same size, as they depend on each graph's particular pattern of connections. I suspect the problem may lie here. Could this be it?

Thanks again to you both for your helpful comments.

@johannahaffner

johannahaffner commented Dec 2, 2024

That sounds like it could be it!

Can you work with padded arrays for the bookkeeping too? I'm not too familiar with graph neural networks, but it sounds like this could be an incidence matrix?

@errhernandez
Author

errhernandez commented Dec 2, 2024

I already did this (padding with zeros) for the node and edge features, to make sure that each batch (a graph composed of multiple graphs) has the same total number of nodes and edges; I am not sure if I can do the same with the book-keeping indices, but I will give it a good try. Thank you so much!

And yes, there is an incidence matrix, but there are other arrays as well that need to be dealt with, which makes it a little bit more tedious. I'm sure there's a way around it, though.
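Once edges are padded to a fixed count, a dense incidence matrix also has a fixed shape (edges × nodes), so it should not cause retracing either. A minimal sketch, assuming the incidence matrix maps each edge to its receiving node; `receiver_incidence` and `NUM_NODES` are hypothetical. Multiplying by the transposed matrix gives the same per-node sums as the masked `segment_sum`.

```python
import jax
import jax.numpy as jnp

NUM_NODES = 4  # hypothetical fixed (padded) node count


def receiver_incidence(receivers, mask, num_nodes):
    """Hypothetical helper: (num_edges, num_nodes) one-hot matrix; padded edges become zero rows."""
    return jax.nn.one_hot(receivers, num_nodes) * mask[:, None]


# Six edge slots, of which the last three are padding.
receivers = jnp.array([1, 2, 3, 0, 0, 0])
mask = jnp.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
messages = jnp.array([[1.0], [2.0], [3.0], [9.0], [9.0], [9.0]])

inc = receiver_incidence(receivers, mask, NUM_NODES)
dense = inc.T @ messages  # sum of incoming messages per node, via a matmul
sparse = jax.ops.segment_sum(messages * mask[:, None], receivers, num_segments=NUM_NODES)
```

The dense version uses memory proportional to edges × nodes, so `segment_sum` on index arrays is usually cheaper for large, sparse graphs.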

@johannahaffner

> I'm sure there's a way around it, though.

Great!
