Skip to content

Commit

Permalink
fix(jax): move jax support to optional experimental feature
Browse files Browse the repository at this point in the history
5 rk45 jax solver returning nan solution (#8)

* add global USE_JAX environment variable flag

* implement jax helper ode functions

* provide jax jit compilation

* add JAX install section

* remove unused import

* remove timing script

* move jax flags to class level

* correct type casting for jax support

* use np clip exclusively in dynamics step

* expose jax in example entities

* expose jax to point mass integrators

* use np.clip over jnp.clip
  • Loading branch information
jamie-cunningham authored Oct 25, 2024
1 parent 5c04f4e commit 1b4a52d
Show file tree
Hide file tree
Showing 26 changed files with 2,193 additions and 1,065 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@ The following command will install `safe-autonomy-simulation` into your local en
pip install safe-autonomy-simulation
```

### JAX support (experimental)

The `safe-autonomy-simulation` package supports numerical computation acceleration via the [JAX](https://jax.readthedocs.io/en/latest/index.html) library. This is an experimental feature.

The `safe_autonomy_simulation` package comes installed with JAX for CPU acceleration. JAX also provides GPU acceleration for numerical computing. If you'd like this feature the easiest way to install it is using `pip`:

```shell
pip install jax[cuda12]
```

For more information on installing JAX see the [official documentation](https://jax.readthedocs.io/en/latest/installation.html).

JAX can be enabled on a per-class basis for classes that support it. As JAX is an experimental feature it is not enabled by default.

### Installing from source

Alternatively, `safe-autonomy-simulation` can be installed from source using any of the following methods. Again, it is recommended to install this package in an isolated virtual environment. The following sections describe how to install `safe-autonomy-simulation` from source in an isolated virtual environment using `poetry`, `conda`, and `pip + virtualenv`.
Expand Down
1,720 changes: 983 additions & 737 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@ pint = "^0.23"
scipy = "^1.13.0"
scikit-learn = "^1.5.0"
typing-extensions = "^4.12.2"
jax = {extras = ["cpu"], version = "^0.4.30", optional = true}

[tool.poetry.extras]
jax = ["jax"]
jax = {extras = ["cpu"], version = "^0.4.30"}
diffrax = {version = "^0.6.0"}

[tool.poetry.group.test]
optional = true
Expand All @@ -37,6 +35,7 @@ optional = true
pytest = "^8.0.0"
pytest-cov = "4.1.0"
pyinstrument = "^4.3.0"
pytest-mock = "^3.14.0"

[tool.poetry.group.docs]
optional = true
Expand Down
2 changes: 2 additions & 0 deletions safe_autonomy_simulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
controls,
sims,
utils,
jax,
)
from safe_autonomy_simulation.dynamics import Dynamics
from safe_autonomy_simulation.entities import Entity
Expand All @@ -32,4 +33,5 @@
"controls",
"sims",
"utils",
"jax",
]
6 changes: 1 addition & 5 deletions safe_autonomy_simulation/controls/control_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
import typing
import queue
import warnings

try:
import jax.numpy as np
except ImportError:
import numpy as np
import numpy as np


class ControlQueue:
Expand Down
6 changes: 1 addition & 5 deletions safe_autonomy_simulation/controls/no_control.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
"""A default control queue with an empty default control vector."""

import numpy as np
import safe_autonomy_simulation.controls.control_queue as c

try:
import jax.numpy as np
except ImportError:
import numpy as np


class NoControl(c.ControlQueue):
"""A control queue with an empty default control vector.
Expand Down
2 changes: 2 additions & 0 deletions safe_autonomy_simulation/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
ControlAffineODEDynamics,
LinearODEDynamics,
)
import safe_autonomy_simulation.dynamics.utils as utils

__all__ = [
"Dynamics", # base dynamics model
"PassThroughDynamics", # dynamics model that passes through control inputs
"ODEDynamics", # base dynamics model that integrates ordinary differential equations
"ControlAffineODEDynamics", # control-affine ODE dynamics model
"LinearODEDynamics", # linear ODE dynamics model
"utils", # utility functions for dynamics models
]
33 changes: 22 additions & 11 deletions safe_autonomy_simulation/dynamics/dynamics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
"""Base class for state transition dynamics models of entities in the simulation environment."""

import typing

try:
import jax.numpy as np
except ImportError:
import numpy as np
import numpy as np
import jax.numpy as jnp


class Dynamics:
Expand All @@ -25,19 +22,26 @@ class Dynamics:
When a float, represents single limit applied to entire state vector.
When an ndarray, each element represents the limit to the corresponding state vector element.
By default, np.inf
use_jax : bool, optional
EXPERIMENTAL: Use JAX to accelerate state transition computation, by default False.
"""

def __init__(
self,
state_min: typing.Union[float, np.ndarray] = -np.inf,
state_max: typing.Union[float, np.ndarray] = np.inf,
use_jax: bool = False,
):
self.np = np
self.state_min = state_min
self.state_max = state_max
self.use_jax = use_jax
self.np = jnp if self.use_jax else np
self.state_min = self.np.copy(state_min)
self.state_max = self.np.copy(state_max)

def step(
self, step_size: float, state: np.ndarray, control: np.ndarray
self,
step_size: float,
state: np.ndarray,
control: np.ndarray,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""
Computes the dynamics state transition from the current state and control input.
Expand All @@ -57,11 +61,18 @@ def step(
Tuple of the system's next state and the state's instantaneous time derivative at the end of the step
"""
next_state, state_dot = self._step(step_size, state, control)
next_state = self.np.clip(next_state, self.state_min, self.state_max)
next_state = np.clip(
next_state,
self.state_min,
self.state_max,
)
return next_state, state_dot

def _step(
self, step_size: float, state: np.ndarray, control: np.ndarray
self,
step_size: float,
state: np.ndarray,
control: np.ndarray,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""Computes the next state and state derivative of the system
Expand Down
Loading

0 comments on commit 1b4a52d

Please sign in to comment.