Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

5 rk45 jax solver returning nan solution #8

Merged
merged 13 commits into from
Oct 25, 2024

Conversation

jamie-cunningham
Copy link
Collaborator

This change moves JAX functions to a separate module and allows toggling use of JAX via an environment variable.

@jamie-cunningham jamie-cunningham linked an issue Sep 6, 2024 that may be closed by this pull request
@jamie-cunningham jamie-cunningham self-assigned this Sep 6, 2024
@jamie-cunningham
Copy link
Collaborator Author

@keatincf I moved the JAX functionality to its own package and added an environment variable check for enabling JAX. Unfortunately using JAX seems to add significant overhead, resulting in a 10x slowdown vs numpy. I couldn't nail down why during profiling.

I also haven't directly addressed the NaN issue yet. I couldn't find any good guidance in the solver docs for how to deal with invalid solutions. Note that I used the diffrax package for JAX solvers as it appears the JAX ode solver has been declared out of scope by the JAX team and no longer is actively developed. Any thoughts are welcome.

@keatincf
Copy link
Collaborator

keatincf commented Sep 9, 2024

test_ode is failing because the _clip_state_dot_direct was moved out of ODEDynamics.

I spent a little bit of time trying to see if there was anything I could see that could be causing the slowdown due to JAX. I didn't see anything that would make me think JAX was recompling more than expected. Removing the jit decorators causes a further slow down, but only a second or two extra slowdown. I was worried that maybe the functions being compiled weren't pure enough be compiled properly.

I also tried to utilize a GPU enabled JAX, but that generated a larger slowdown. I'm not sure if that was partially caused by a driver/CUDA version mismatch (Ubuntu's repos have CUDA 12.4 and JAX is using CUDA 12.6).

It looks like the biggest slowdown might be using diffrax compared to scipy.integrate.

My guess is that JAX does not provide any major improvement on CPUs compared to numpy/scipy. It could provide better performance on a GPU/TPU, but we may need to adjust the code more to properly take advantage of JAX.

It should be ok to keep JAX in since you've noted that the use of JAX is experimental.

@jamie-cunningham
Copy link
Collaborator Author

test_ode is failing because the _clip_state_dot_direct was moved out of ODEDynamics.

I spent a little bit of time trying to see if there was anything I could see that could be causing the slowdown due to JAX. I didn't see anything that would make me think JAX was recompling more than expected. Removing the jit decorators causes a further slow down, but only a second or two extra slowdown. I was worried that maybe the functions being compiled weren't pure enough be compiled properly.

I also tried to utilize a GPU enabled JAX, but that generated a larger slowdown. I'm not sure if that was partially caused by a driver/CUDA version mismatch (Ubuntu's repos have CUDA 12.4 and JAX is using CUDA 12.6).

It looks like the biggest slowdown might be using diffrax compared to scipy.integrate.

My guess is that JAX does not provide any major improvement on CPUs compared to numpy/scipy. It could provide better performance on a GPU/TPU, but we may need to adjust the code more to properly take advantage of JAX.

It should be ok to keep JAX in since you've noted that the use of JAX is experimental.

Yeah I also tried just using scipy.integrate in a JIT wrapper but there were numpy vs jax issues that led me to diffrax. It doesn't seem like JAX handles third-party packages using numpy very well.

Copy link
Collaborator

@keatincf keatincf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kyle brought up a use case that this solution doesn't cover, or makes it harder to achieve. When running the inspection environment, you may only want JAX to be used for the dynamcis solver used by RTA and not use JAX when stepping the dynamics of the entities. Using the environment variable would make it harder, if not impossible, to have multiple solvers in an instance, with some using JAX and some not.

@jamie-cunningham
Copy link
Collaborator Author

JAX flags have been moved to the class level. Class objects can now decide if they want to use JAX or not. JAX will only be used if it is available in the environment.

Copy link
Collaborator

@keatincf keatincf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't tell if the issue I'm having is Python's loose typing or the intermingling of jax and non-jax code. It's hard to track whether types are staying consistent and whether casting is happening too often.

safe_autonomy_simulation/opts.py Outdated Show resolved Hide resolved
safe_autonomy_simulation/utils/conversion.py Outdated Show resolved Hide resolved
safe_autonomy_simulation/dynamics/dynamics.py Show resolved Hide resolved
@jamie-cunningham
Copy link
Collaborator Author

casting

Yeah Jax requires that input arrays to JIT functions are Jax types hence the casting. Since we aren't globally setting a Jax flag then it doesn't make sense to have Jax arrays in non-Jax compatible code. When there was a global flag we could just make everything a Jax type internally. Now we need to make sure that inputs to JIT functions are Jax types.

I can go back in and see if I can beef up the type hinting. Unfortunately checking a Jax type requires Jax to be available at runtime. We could make Jax a non-optional dependency again since we are being explicit about its use.

Copy link
Collaborator

@keatincf keatincf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have some new comments. We could potentially move some of the comments to other issues, if we find that is appropriate.

safe_autonomy_simulation/dynamics/ode.py Outdated Show resolved Hide resolved
safe_autonomy_simulation/dynamics/ode.py Show resolved Hide resolved
safe_autonomy_simulation/dynamics/ode.py Show resolved Hide resolved
@jamie-cunningham jamie-cunningham merged commit 1b4a52d into main Oct 25, 2024
4 checks passed
@jamie-cunningham jamie-cunningham deleted the 5-rk45-jax-solver-returning-nan-solution branch October 25, 2024 17:53
jamie-cunningham pushed a commit that referenced this pull request Oct 25, 2024
## [2.1.1](v2.1.0...v2.1.1) (2024-10-25)

### Bug Fixes

* **jax:** move jax support to optional experimental feature ([1b4a52d](1b4a52d)), closes [#8](#8)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

RK45 Jax Solver Returning NaN Solution
2 participants