Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

5 rk45 jax solver returning nan solution #8

Merged
merged 13 commits into from
Oct 25, 2024
Merged
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@ The following command will install `safe-autonomy-simulation` into your local en
pip install safe-autonomy-simulation
```

### JAX support (experimental)

The `safe-autonomy-simulation` package supports numerical computation acceleration via the [JAX](https://jax.readthedocs.io/en/latest/index.html) library. This is an experimental feature.

The `safe_autonomy_simulation` package comes installed with JAX for CPU acceleration. JAX also provides GPU acceleration for numerical computing. If you'd like this feature the easiest way to install it is using `pip`:

```shell
pip install jax[cuda12]
```

For more information on installing JAX see the [official documentation](https://jax.readthedocs.io/en/latest/installation.html).

JAX can be enabled on a per-class basis for classes that support it. As JAX is an experimental feature it is not enabled by default.

### Installing from source

Alternatively, `safe-autonomy-simulation` can be installed from source using any of the following methods. Again, it is recommended to install this package in an isolated virtual environment. The following sections describe how to install `safe-autonomy-simulation` from source in an isolated virtual environment using `poetry`, `conda`, and `pip + virtualenv`.
Expand Down
1,720 changes: 983 additions & 737 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@ pint = "^0.23"
scipy = "^1.13.0"
scikit-learn = "^1.5.0"
typing-extensions = "^4.12.2"
jax = {extras = ["cpu"], version = "^0.4.30", optional = true}

[tool.poetry.extras]
jax = ["jax"]
jax = {extras = ["cpu"], version = "^0.4.30"}
diffrax = {version = "^0.6.0"}

[tool.poetry.group.test]
optional = true
Expand All @@ -37,6 +35,7 @@ optional = true
pytest = "^8.0.0"
pytest-cov = "4.1.0"
pyinstrument = "^4.3.0"
pytest-mock = "^3.14.0"

[tool.poetry.group.docs]
optional = true
Expand Down
2 changes: 2 additions & 0 deletions safe_autonomy_simulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
controls,
sims,
utils,
jax,
)
from safe_autonomy_simulation.dynamics import Dynamics
from safe_autonomy_simulation.entities import Entity
Expand All @@ -32,4 +33,5 @@
"controls",
"sims",
"utils",
"jax",
]
6 changes: 1 addition & 5 deletions safe_autonomy_simulation/controls/control_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
import typing
import queue
import warnings

try:
import jax.numpy as np
except ImportError:
import numpy as np
import numpy as np


class ControlQueue:
Expand Down
6 changes: 1 addition & 5 deletions safe_autonomy_simulation/controls/no_control.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
"""A default control queue with an empty default control vector."""

import numpy as np
import safe_autonomy_simulation.controls.control_queue as c

try:
import jax.numpy as np
except ImportError:
import numpy as np


class NoControl(c.ControlQueue):
"""A control queue with an empty default control vector.
Expand Down
2 changes: 2 additions & 0 deletions safe_autonomy_simulation/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
ControlAffineODEDynamics,
LinearODEDynamics,
)
import safe_autonomy_simulation.dynamics.utils as utils

__all__ = [
"Dynamics", # base dynamics model
"PassThroughDynamics", # dynamics model that passes through control inputs
"ODEDynamics", # base dynamics model that integrates ordinary differential equations
"ControlAffineODEDynamics", # control-affine ODE dynamics model
"LinearODEDynamics", # linear ODE dynamics model
"utils", # utility functions for dynamics models
]
33 changes: 22 additions & 11 deletions safe_autonomy_simulation/dynamics/dynamics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
"""Base class for state transition dynamics models of entities in the simulation environment."""

import typing

try:
import jax.numpy as np
except ImportError:
import numpy as np
import numpy as np
import jax.numpy as jnp


class Dynamics:
Expand All @@ -25,19 +22,26 @@ class Dynamics:
When a float, represents single limit applied to entire state vector.
When an ndarray, each element represents the limit to the corresponding state vector element.
By default, np.inf
use_jax : bool, optional
EXPERIMENTAL: Use JAX to accelerate state transition computation, by default False.
"""

def __init__(
self,
state_min: typing.Union[float, np.ndarray] = -np.inf,
state_max: typing.Union[float, np.ndarray] = np.inf,
use_jax: bool = False,
):
self.np = np
self.state_min = state_min
self.state_max = state_max
self.use_jax = use_jax
self.np = jnp if self.use_jax else np
self.state_min = self.np.copy(state_min)
self.state_max = self.np.copy(state_max)

def step(
self, step_size: float, state: np.ndarray, control: np.ndarray
self,
step_size: float,
state: np.ndarray,
control: np.ndarray,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""
Computes the dynamics state transition from the current state and control input.
Expand All @@ -57,11 +61,18 @@ def step(
Tuple of the system's next state and the state's instantaneous time derivative at the end of the step
"""
next_state, state_dot = self._step(step_size, state, control)
keatincf marked this conversation as resolved.
Show resolved Hide resolved
next_state = self.np.clip(next_state, self.state_min, self.state_max)
next_state = np.clip(
next_state,
self.state_min,
self.state_max,
)
return next_state, state_dot

def _step(
self, step_size: float, state: np.ndarray, control: np.ndarray
self,
step_size: float,
state: np.ndarray,
control: np.ndarray,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""Computes the next state and state derivative of the system

Expand Down
Loading
Loading