5 rk45 jax solver returning nan solution #8
…:act3-ace/safe-autonomy-simulation into 5-rk45-jax-solver-returning-nan-solution
@keatincf I moved the JAX functionality to its own package and added an environment variable check for enabling JAX. Unfortunately using JAX seems to add significant overhead, resulting in a 10x slowdown vs numpy. I couldn't nail down why during profiling. I also haven't directly addressed the NaN issue yet. I couldn't find any good guidance in the solver docs for how to deal with invalid solutions. Note that I used the |
test_ode is failing because _clip_state_dot_direct was moved out of ODEDynamics.

I spent a little time looking for anything that could explain the JAX slowdown. I didn't see anything that made me think JAX was recompiling more than expected. Removing the jit decorators slows things down further, but only by an extra second or two. I was worried that the functions being compiled might not be pure enough to compile properly. I also tried a GPU-enabled JAX, but that produced an even larger slowdown. I'm not sure whether that was partly caused by a driver/CUDA version mismatch (Ubuntu's repos have CUDA 12.4 and JAX is using CUDA 12.6).

It looks like the biggest slowdown might come from using diffrax instead of scipy.integrate. My guess is that JAX does not provide any major improvement over numpy/scipy on CPUs. It could perform better on a GPU/TPU, but we may need to adjust the code further to properly take advantage of JAX. It should be OK to keep JAX in, since you've noted that its use is experimental.
Yeah I also tried just using |
Kyle brought up a use case that this solution doesn't cover, or at least makes harder to achieve. When running the inspection environment, you may want JAX to be used only for the dynamics solver used by RTA, and not when stepping the dynamics of the entities. Using the environment variable would make it harder, if not impossible, to have multiple solvers in one instance, with some using JAX and some not.
JAX flags have been moved to the class level. Class objects can now decide if they want to use JAX or not. JAX will only be used if it is available in the environment.
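For anyone following along, here is a minimal sketch of what a class-level flag with an availability check could look like. `ExampleDynamics`, `use_jax`, and `JAX_AVAILABLE` are hypothetical names chosen for illustration, not the names used in this PR, and the dynamics are a toy double integrator.

```python
import numpy as np

# Optional import: JAX is used only if it is installed in this environment.
try:
    import jax.numpy as jnp
    JAX_AVAILABLE = True
except ImportError:
    jnp = None
    JAX_AVAILABLE = False


class ExampleDynamics:
    """Hypothetical dynamics class; name and signature are illustrative only."""

    def __init__(self, use_jax: bool = False):
        # The per-instance request is honored only when JAX can be imported.
        self.use_jax = bool(use_jax and JAX_AVAILABLE)
        # Pick the array module once so the math below is written a single time.
        self.xp = jnp if self.use_jax else np

    def state_dot(self, state, control):
        # Toy double integrator: d(position)/dt = velocity, d(velocity)/dt = control.
        xp = self.xp
        state = xp.asarray(state)
        control = xp.asarray(control)
        return xp.concatenate([state[1:2], control[0:1]])


# Two solvers in one process: one plain numpy, one requesting JAX.
entity_dynamics = ExampleDynamics(use_jax=False)
rta_dynamics = ExampleDynamics(use_jax=True)
```

With this shape, the RTA solver can opt in to JAX while entity stepping stays on numpy, which is the use case Kyle raised.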
I can't tell if the issue I'm having is Python's loose typing or the intermingling of jax and non-jax code. It's hard to track whether types are staying consistent and whether casting is happening too often.
Yeah, JAX requires that input arrays to JIT functions are JAX types, hence the casting. Since we aren't setting a JAX flag globally, it doesn't make sense to have JAX arrays in non-JAX-compatible code. When there was a global flag, we could just make everything a JAX type internally. Now we need to make sure that inputs to JIT functions are JAX types. I can go back in and see if I can beef up the type hinting. Unfortunately, checking a JAX type requires JAX to be available at runtime. We could make JAX a non-optional dependency again, since we are being explicit about its use.
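One possible way to tighten the hints while JAX stays optional is to import it only for static type checkers and cast at the boundary of the JAX code path. This is a sketch under that assumption, not what the PR does; `ArrayLike`, `to_jax`, and `to_numpy` are hypothetical helpers. Note that it only helps static checking: a runtime `isinstance` check against `jax.Array` would still need JAX installed.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Union

import numpy as np

if TYPE_CHECKING:
    # Seen only by static type checkers; never imported at runtime.
    import jax

# The string form keeps this alias valid even when JAX is not installed.
ArrayLike = Union[np.ndarray, "jax.Array"]


def to_jax(x: ArrayLike) -> "jax.Array":
    """Cast once on entry to JAX code (requires JAX at call time)."""
    import jax.numpy as jnp  # lazy import, so this module loads without JAX

    return jnp.asarray(x)


def to_numpy(x: ArrayLike) -> np.ndarray:
    """Cast once on exit so non-JAX code only ever sees numpy arrays."""
    return np.asarray(x)
```

Keeping the casts in two small helpers makes it easier to see where arrays cross between JAX and non-JAX code, which was the tracking concern raised above.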
Have some new comments. We could potentially move some of the comments to other issues, if we find that is appropriate.
## [2.1.1](v2.1.0...v2.1.1) (2024-10-25)

### Bug Fixes

* **jax:** move jax support to optional experimental feature ([1b4a52d](1b4a52d)), closes [#8](#8)
This change moves JAX functions to a separate module and allows toggling use of JAX via an environment variable.
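As a rough illustration of an environment-variable toggle of this kind, the sketch below reads a flag and also checks that JAX can be imported. The variable name `USE_JAX` and both function names are assumptions made for the example; the name actually used by the package is not stated here.

```python
import os


def jax_requested(env=None, var="USE_JAX"):
    """Return True if the (hypothetical) USE_JAX variable is set to a truthy value."""
    env = os.environ if env is None else env
    return env.get(var, "").strip().lower() in {"1", "true", "yes", "on"}


def jax_enabled(env=None):
    """JAX is enabled only if it was requested and can actually be imported."""
    if not jax_requested(env):
        return False
    try:
        import jax  # noqa: F401  (imported only to test availability)
    except ImportError:
        return False
    return True
```

Passing a plain dict as `env` makes the check easy to exercise in tests without touching the real process environment.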