jitting a module with segment sums #903
Comments
Welcome to JAX and Equinox! :) You can check recompilations using `eqx.debug.assert_max_traces`. Note that recompilation is triggered by passing new static arguments (non-arrays) through `eqx.filter_jit`, or by passing arrays with new shapes/dtypes. As such you should be able to isolate the cause of recompilations solely by checking what is passed across the JIT boundary :) (e.g. using [...]) I hope that helps!
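As a rough illustration of the point about the JIT boundary, here is a minimal sketch. It uses plain `jax.jit` as a stand-in for `eqx.filter_jit`, and the names `scale`, `factor` and `trace_count` are made up for this example. A Python side effect inside a jitted function only runs while JAX traces it, so counting that side effect shows when a retrace happens.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter; only incremented while `scale` is being traced


@functools.partial(jax.jit, static_argnames="factor")
def scale(x, factor):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return factor * x


scale(jnp.ones(4), factor=2)   # first call: traces
scale(jnp.zeros(4), factor=2)  # same shape/dtype, new values: reuses the cached trace
traces_same_shape = trace_count

scale(jnp.ones(5), factor=2)   # new array shape: retraces
traces_new_shape = trace_count

scale(jnp.ones(5), factor=3)   # new static (non-array) value: retraces
traces_new_static = trace_count
```

New array values alone leave the counter unchanged; only a new shape or a new static value increases it.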
Thanks for your response. By reading your excellent documentation more thoroughly I found out about `eqx.debug.assert_max_traces`, so I can confirm that there were indeed recompilations happening. I assumed this was due to the `segment_sum`s in my model, because the JAX documentation says that to jit a `segment_sum` one must ensure that `num_segments` does not change; so I have been chasing a red herring? It surprises me that the suspects are passing new static arguments or changing shapes/dtypes, as you suggest, but I will make sure of this. Thanks again!
Sorry! Me again! When you say "Note that recompilation is triggered by passing new static arguments (non-arrays) through eqx.filter_jit", do I interpret this correctly to mean that the arguments have to be arrays? A dict is not allowed? I am certainly passing a number of book-keeping arrays, but bundled up in a dictionary; the sizes and shapes of those book-keeping arrays are the same from sample to sample, but their values inevitably change. But that shouldn't trigger recompilation, should it?
Hi Eduardo, Hi Patrick, a dictionary with arrays is a PyTree of Arrays, this should not be a problem. However, it is important that the numbers in your dictionary are actually wrapped in arrays, like so:

```python
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, PyTree


@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def some_function(x: Array, book: PyTree[Array]) -> tuple[Array, PyTree[Array]]:
    y = 2 * x
    book['a'] += 1
    return y, book


x = jnp.linspace(0, 1, 10)
dict_of_arrays = {'a': jnp.array(1), 'b': jnp.array(2)}
dict_of_numbers = {'a': 1, 'b': 2}

# Array leaves are traced: their values change, but no retrace happens.
for _ in range(2):
    y, dict_of_arrays = some_function(x, dict_of_arrays)

# Python numbers are static: each new value is a new static argument.
for _ in range(2):
    y, dict_of_numbers = some_function(x, dict_of_numbers)
```

Note that passing a dictionary of numbers will trigger recompilation.
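If the bookkeeping dictionary is nested, one way to wrap every number in an array is to map over its leaves. This is only a sketch: the dictionary `book` and its keys are hypothetical, and it assumes every leaf is a number that can be converted with `jnp.asarray`.

```python
import jax
import jax.numpy as jnp

# Hypothetical nested bookkeeping dict holding plain Python numbers.
book = {"a": 1, "nested": {"b": 2, "c": 3.0}}

# Wrap every leaf in a JAX array so it is traced instead of treated as static.
book_arrays = jax.tree_util.tree_map(jnp.asarray, book)

# Check that no plain Python numbers are left among the leaves.
all_arrays = all(
    isinstance(leaf, jax.Array) for leaf in jax.tree_util.tree_leaves(book_arrays)
)
```

The tree structure (keys and nesting) is unchanged; only the leaves become arrays.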
Thank you @johannahaffner for your feedback. Indeed, at least one of the items in my dict of indices was itself a dict of numbers; I was not aware that this would trigger recompilation, so I have changed it following your remarks. I think I am closing in on the solution thanks to your and @patrick-kidger's comments, but my problem still persists; I need to try a little harder. Specifically, my application is a graph neural network, and more specifically a graph regression. I batch my graphs together so that all the batches have exactly the same total number of nodes and the same total number of edges (I did this thinking it would help with segment_sums, although according to @patrick-kidger this is not the issue). However, there remain other book-keeping indices, which track which node sends messages to which other nodes (and the same for edges), that cannot easily be made the same size, as these depend on each graph's particular pattern of connections. I suspect the problem may lie here. Could this be it? Thanks again to you both for your helpful comments.
That sounds like it could be it! Can you work with padded arrays for the bookkeeping too? I'm not too familiar with graph neural networks, but it sounds like this could be an incidence matrix?
This I did already (padding with zeros) with the node and edge features so as to make sure that each batch (a graph composed of multiple graphs) had the same total number of nodes and edges; I am not sure if I can do the same with the book-keeping indices, but I will give it a good try. Thank you so much! And yes, there is an incidence matrix, but there are other arrays as well that need to be dealt with, which makes it a little bit more tedious. I'm sure there's a way around it, though.
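One possible way to pad index-style bookkeeping arrays is sketched below. It is not taken from the thread: the constants `NUM_NODES` and `MAX_EDGES`, the helpers `pad_receivers` and `pad_messages`, and the dummy-segment trick are all assumptions made for illustration. Padded edges point at an extra "dummy" node whose row is discarded after the `jax.ops.segment_sum`, so every batch has the same shapes and the function is traced only once.

```python
import jax
import jax.numpy as jnp
import numpy as np

NUM_NODES = 4  # assumed fixed number of nodes per batch
MAX_EDGES = 6  # assumed upper bound on the number of edges per batch

trace_count = 0  # hypothetical counter, incremented only while tracing


def pad_receivers(receivers, max_edges, dummy_index):
    """Pad an edge-index array to `max_edges`; padding points at `dummy_index`."""
    padded = np.full(max_edges, dummy_index, dtype=np.int32)
    padded[: len(receivers)] = receivers
    return jnp.asarray(padded)


def pad_messages(messages, max_edges):
    """Pad per-edge features with zero rows up to `max_edges`."""
    padded = np.zeros((max_edges,) + messages.shape[1:], dtype=np.float32)
    padded[: len(messages)] = messages
    return jnp.asarray(padded)


@jax.jit
def aggregate(messages, receivers):
    global trace_count
    trace_count += 1
    # One extra segment collects the padded edges; its row is dropped afterwards.
    summed = jax.ops.segment_sum(messages, receivers, num_segments=NUM_NODES + 1)
    return summed[:NUM_NODES]


# Two batches with different numbers of real edges (4 and 6).
recv1 = pad_receivers(np.array([0, 1, 1, 3]), MAX_EDGES, dummy_index=NUM_NODES)
msg1 = pad_messages(np.ones((4, 2), dtype=np.float32), MAX_EDGES)
recv2 = pad_receivers(np.array([2, 2, 0, 1, 3, 3]), MAX_EDGES, dummy_index=NUM_NODES)
msg2 = pad_messages(np.ones((6, 2), dtype=np.float32), MAX_EDGES)

out1 = aggregate(msg1, recv1)
out2 = aggregate(msg2, recv2)
```

Because the padded messages are zeros, a plain sum would also be correct without the dummy segment; routing padding to a separate segment just keeps it from touching real nodes if the reduction is later swapped for something else.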
Great!
Hello; I am new to equinox and the jax ecosystem, so apologies if this is a trivial question, but I am having some difficulty in sorting it out:
I have written an eqx.Module and would like to use jit in its training; the problem is that the module has several spots where ops.segment_sum is used. According to the JAX documentation this is ok, provided one can make the num_segments argument to segment_sum static (which I interpret as constant across all samples in the dataset). I have done this by constructing batches of samples in such a way that the num_segments argument is indeed the same for every batch. However, when I decorate my training step function with @eqx.filter_jit, it is slower than not jitting at all, and I suspect that this is because jit is recompiling every time, assuming that num_segments changes from batch to batch when it does not.
I guess I need to declare those num_segments values as static somewhere, perhaps as an option to @eqx.filter_jit, but I have not found how to do it; can someone point me in the right direction? Thanks!
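For reference, a minimal sketch of what a static `num_segments` looks like with plain `jax.jit` (the function name `pool` and the toy data are made up for this example). As noted in the replies, `eqx.filter_jit` already treats non-array arguments as static, so a Python `int` passed through it should need no extra declaration; the explicit `static_argnames` below is only what the plain `jax.jit` API requires.

```python
import functools

import jax
import jax.numpy as jnp


# `num_segments` fixes the output shape, so it must be a static Python int.
@functools.partial(jax.jit, static_argnames="num_segments")
def pool(data, segment_ids, num_segments):
    return jax.ops.segment_sum(data, segment_ids, num_segments=num_segments)


data = jnp.arange(6.0)                          # [0, 1, 2, 3, 4, 5]
segment_ids = jnp.array([0, 0, 1, 1, 2, 2])     # two entries per segment
pooled = pool(data, segment_ids, num_segments=3)
```

As long as `num_segments` keeps the same value and the array shapes stay fixed, repeated calls reuse the same compiled function.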
And while I am at it: is there a way to confirm that jit is recompiling at every step? At present this is my impression (I was expecting my jitted code to be much faster than the unjitted), but I am not 100% sure.
Thank you in advance for any hints received!