Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

5 rk45 jax solver returning nan solution #8

Merged
merged 13 commits into from
Oct 25, 2024
Merged
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,34 @@ The following command will install `safe-autonomy-simulation` into your local en
pip install safe-autonomy-simulation
```

### Installing with JAX (experimental)

The `safe-autonomy-simulation` package supports numerical computation acceleration via the [JAX](https://jax.readthedocs.io/en/latest/index.html) library. This is an experimental feature.

If you want to use JAX acceleration you can install JAX with `safe-autonomy-simulation` using `pip`:

```shell
pip install safe-autonomy-simulation[jax]
```

This will install JAX into your local environment for CPU acceleration. JAX also provides GPU acceleration for numerical computing. If you'd like this feature the easiest way to install it is using `pip`:

```shell
pip install jax[cuda12]
```

For more information on installing JAX see the [official documentation](https://jax.readthedocs.io/en/latest/installation.html).

JAX can be enabled on a per-class basis for classes that support it. As JAX is an experimental feature it is not enabled by default.

You can additionally force `safe-autonomy-simulation` to not use JAX at all by setting the following environment variable:

```shell
USE_JAX="0"
```

This will turn off JAX globally for `safe-autonomy-simulation` even if a class requests it.

### Installing from source

Alternatively, `safe-autonomy-simulation` can be installed from source using any of the following methods. Again, it is recommended to install this package in an isolated virtual environment. The following sections describe how to install `safe-autonomy-simulation` from source in an isolated virtual environment using `poetry`, `conda`, and `pip + virtualenv`.
Expand Down
989 changes: 572 additions & 417 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ scipy = "^1.13.0"
scikit-learn = "^1.5.0"
typing-extensions = "^4.12.2"
jax = {extras = ["cpu"], version = "^0.4.30", optional = true}
diffrax = {version = "^0.6.0", optional = true}

[tool.poetry.extras]
jax = ["jax"]
jax = ["jax", "diffrax"]

[tool.poetry.group.test]
optional = true
Expand Down
7 changes: 7 additions & 0 deletions safe_autonomy_simulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
continuous simulations in python.
"""

from safe_autonomy_simulation.opts import jax_available
from safe_autonomy_simulation import (
dynamics,
entities,
Expand Down Expand Up @@ -32,4 +33,10 @@
"controls",
"sims",
"utils",
"jax_available",
]


if jax_available():
from safe_autonomy_simulation import jax
__all__.append("jax")
6 changes: 1 addition & 5 deletions safe_autonomy_simulation/controls/control_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
import typing
import queue
import warnings

try:
import jax.numpy as np
except ImportError:
import numpy as np
import numpy as np


class ControlQueue:
Expand Down
6 changes: 1 addition & 5 deletions safe_autonomy_simulation/controls/no_control.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
"""A default control queue with an empty default control vector."""

import numpy as np
import safe_autonomy_simulation.controls.control_queue as c

try:
import jax.numpy as np
except ImportError:
import numpy as np


class NoControl(c.ControlQueue):
"""A control queue with an empty default control vector.
Expand Down
2 changes: 2 additions & 0 deletions safe_autonomy_simulation/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
ControlAffineODEDynamics,
LinearODEDynamics,
)
import safe_autonomy_simulation.dynamics.utils as utils

__all__ = [
"Dynamics", # base dynamics model
"PassThroughDynamics", # dynamics model that passes through control inputs
"ODEDynamics", # base dynamics model that integrates ordinary differential equations
"ControlAffineODEDynamics", # control-affine ODE dynamics model
"LinearODEDynamics", # linear ODE dynamics model
"utils", # utility functions for dynamics models
]
38 changes: 28 additions & 10 deletions safe_autonomy_simulation/dynamics/dynamics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Base class for state transition dynamics models of entities in the simulation environment."""

import typing
import numpy as np
import safe_autonomy_simulation
from safe_autonomy_simulation.utils import cast_jax

try:
import jax.numpy as np
except ImportError:
import numpy as np
if safe_autonomy_simulation.jax_available():
import jax.numpy as jnp


class Dynamics:
Expand All @@ -25,19 +26,26 @@ class Dynamics:
When a float, represents single limit applied to entire state vector.
When an ndarray, each element represents the limit to the corresponding state vector element.
By default, np.inf
use_jax : bool, optional
EXPERIMENTAL: Use JAX to accelerate state transition computation, by default False.
"""

def __init__(
self,
state_min: typing.Union[float, np.ndarray] = -np.inf,
state_max: typing.Union[float, np.ndarray] = np.inf,
use_jax: bool = False,
):
self.np = np
self.state_min = state_min
self.state_max = state_max
self.use_jax = safe_autonomy_simulation.jax_available and use_jax
self.np = jnp if self.use_jax else np
self.state_min = self.np.copy(state_min)
self.state_max = self.np.copy(state_max)

def step(
self, step_size: float, state: np.ndarray, control: np.ndarray
self,
step_size: float,
state: np.ndarray,
control: np.ndarray,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""
Computes the dynamics state transition from the current state and control input.
Expand All @@ -56,12 +64,22 @@ def step(
Tuple[np.ndarray, np.ndarray]
Tuple of the system's next state and the state's instantaneous time derivative at the end of the step
"""
state = cast_jax(state, use_jax=self.use_jax)
control = cast_jax(control, use_jax=self.use_jax)
next_state, state_dot = self._step(step_size, state, control)
keatincf marked this conversation as resolved.
Show resolved Hide resolved
next_state = self.np.clip(next_state, self.state_min, self.state_max)

next_state = self.np.clip(
keatincf marked this conversation as resolved.
Show resolved Hide resolved
cast_jax(next_state, use_jax=self.use_jax),
self.state_min,
self.state_max,
)
return next_state, state_dot

def _step(
self, step_size: float, state: np.ndarray, control: np.ndarray
self,
step_size: float,
state: np.ndarray,
control: np.ndarray,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""Computes the next state and state derivative of the system

Expand Down
Loading
Loading