Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

5 rk45 jax solver returning nan solution #8

Merged
merged 13 commits into from
Oct 25, 2024
10 changes: 3 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,15 @@ pip install jax[cuda12]

For more information on installing JAX see the [official documentation](https://jax.readthedocs.io/en/latest/installation.html).

As JAX is an experimental feature it is not enabled by default. You can tell `safe-autonomy-simulation` to use JAX by setting the following environment variable:
JAX can be enabled on a per-class basis for classes that support it. As JAX is an experimental feature it is not enabled by default.

```shell
USE_JAX="1"
```

Similarly, JAX can be turned off by setting the environment variable:
You can additionally force `safe-autonomy-simulation` to not use JAX at all by setting the following environment variable:

```shell
USE_JAX="0"
```

If the `USE_JAX` environment variable is not detected, the default behavior of `safe-autonomy-simulation` is to not use JAX.
This will turn off JAX globally for `safe-autonomy-simulation` even if a class requests it.

### Installing from source

Expand Down
13 changes: 7 additions & 6 deletions safe_autonomy_simulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
continuous simulations in python.
"""

from safe_autonomy_simulation.opts import use_jax
from safe_autonomy_simulation.opts import jax_available
from safe_autonomy_simulation import (
dynamics,
entities,
Expand All @@ -18,9 +18,6 @@
from safe_autonomy_simulation.controls import ControlQueue
from safe_autonomy_simulation.simulator import Simulator

if use_jax():
from safe_autonomy_simulation import jax


__all__ = [
# core classes
Expand All @@ -36,6 +33,10 @@
"controls",
"sims",
"utils",
"use_jax",
"jax",
"jax_available",
]


if jax_available():
from safe_autonomy_simulation import jax
__all__.append("jax")
7 changes: 1 addition & 6 deletions safe_autonomy_simulation/controls/control_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,7 @@
import typing
import queue
import warnings
import safe_autonomy_simulation

if safe_autonomy_simulation.use_jax():
import jax.numpy as np
else:
import numpy as np
import numpy as np


class ControlQueue:
Expand Down
7 changes: 1 addition & 6 deletions safe_autonomy_simulation/controls/no_control.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
"""A default control queue with an empty default control vector."""

import safe_autonomy_simulation
import numpy as np
import safe_autonomy_simulation.controls.control_queue as c

if safe_autonomy_simulation.use_jax():
import jax.numpy as np
else:
import numpy as np


class NoControl(c.ControlQueue):
"""A control queue with an empty default control vector.
Expand Down
37 changes: 27 additions & 10 deletions safe_autonomy_simulation/dynamics/dynamics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Base class for state transition dynamics models of entities in the simulation environment."""

import typing
import numpy as np
import safe_autonomy_simulation
from safe_autonomy_simulation.utils import cast_jax

if safe_autonomy_simulation.use_jax():
import jax.numpy as np
else:
import numpy as np
if safe_autonomy_simulation.jax_available():
import jax.numpy as jnp


class Dynamics:
Expand All @@ -26,19 +26,26 @@ class Dynamics:
When a float, represents single limit applied to entire state vector.
When an ndarray, each element represents the limit to the corresponding state vector element.
By default, np.inf
use_jax : bool, optional
EXPERIMENTAL: Use JAX to accelerate state transition computation, by default False.
"""

def __init__(
self,
state_min: typing.Union[float, np.ndarray] = -np.inf,
state_max: typing.Union[float, np.ndarray] = np.inf,
use_jax: bool = False,
):
self.np = np
self.state_min = state_min
self.state_max = state_max
self.use_jax = safe_autonomy_simulation.jax_available and use_jax
self.np = jnp if self.use_jax else np
self.state_min = self.np.copy(state_min)
self.state_max = self.np.copy(state_max)

def step(
self, step_size: float, state: np.ndarray, control: np.ndarray
self,
step_size: float,
state: np.ndarray,
control: np.ndarray,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""
Computes the dynamics state transition from the current state and control input.
Expand All @@ -57,12 +64,22 @@ def step(
Tuple[np.ndarray, np.ndarray]
Tuple of the system's next state and the state's instantaneous time derivative at the end of the step
"""
state = cast_jax(state, use_jax=self.use_jax)
control = cast_jax(control, use_jax=self.use_jax)
next_state, state_dot = self._step(step_size, state, control)
keatincf marked this conversation as resolved.
Show resolved Hide resolved
next_state = self.np.clip(next_state, self.state_min, self.state_max)

next_state = self.np.clip(
keatincf marked this conversation as resolved.
Show resolved Hide resolved
cast_jax(next_state, use_jax=self.use_jax),
self.state_min,
self.state_max,
)
return next_state, state_dot

def _step(
self, step_size: float, state: np.ndarray, control: np.ndarray
self,
step_size: float,
state: np.ndarray,
control: np.ndarray,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""Computes the next state and state derivative of the system

Expand Down
64 changes: 49 additions & 15 deletions safe_autonomy_simulation/dynamics/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import safe_autonomy_simulation
import safe_autonomy_simulation.dynamics as dynamics
import numpy as np
from safe_autonomy_simulation.utils import cast_jax


class ODEDynamics(dynamics.Dynamics):
Expand Down Expand Up @@ -40,6 +41,8 @@ class ODEDynamics(dynamics.Dynamics):
'RK45' is slow but very accurate. If jax is available, can be JIT compiled for speed.
'Euler' is fast but very inaccurate.
By default, 'RK45'.
use_jax : bool, optional
EXPERIMENTAL: Use JAX to accelerate state transition computation, by default False.
"""

def __init__(
Expand All @@ -50,10 +53,12 @@ def __init__(
state_dot_min: typing.Union[float, np.ndarray] = -np.inf,
state_dot_max: typing.Union[float, np.ndarray] = np.inf,
integration_method: str = "RK45",
use_jax: bool = False,
):
super().__init__(
state_min=state_min,
state_max=state_max,
use_jax=use_jax,
)

assert (
Expand All @@ -64,8 +69,8 @@ def __init__(
]
), f"invalid integration method {integration_method}, must be one of 'RK45', 'Euler'"
self.integration_method = integration_method
self.state_dot_min = state_dot_min
self.state_dot_max = state_dot_max
self.state_dot_min = self.np.copy(state_dot_min)
self.state_dot_max = self.np.copy(state_dot_max)

assert isinstance(
trajectory_samples, int
Expand All @@ -79,7 +84,10 @@ def __init__(
self.trajectory_t = None

def compute_state_dot(
self, t: float, state: np.ndarray, control: np.ndarray
self,
t: float,
state: np.ndarray,
control: np.ndarray,
) -> np.ndarray:
"""
Computes the instantaneous time derivative of the state vector
Expand All @@ -99,48 +107,63 @@ def compute_state_dot(
np.ndarray
Instantaneous time derivative of the state vector.
"""
state = cast_jax(state, use_jax=self.use_jax)
control = cast_jax(control, use_jax=self.use_jax)

# Get clip functions
clip_fn = (
keatincf marked this conversation as resolved.
Show resolved Hide resolved
safe_autonomy_simulation.dynamics.utils.clip_state_dot
if not safe_autonomy_simulation.use_jax()
if not self.use_jax
else safe_autonomy_simulation.jax.ode.clip_state_dot
)
clip_at_state_limits_fn = (
safe_autonomy_simulation.dynamics.utils.clip_state_dot_at_state_limits
if not safe_autonomy_simulation.use_jax()
if not self.use_jax
else safe_autonomy_simulation.jax.ode.clip_state_dot_at_state_limits
)
# Compute state derivative
state_dot = self._compute_state_dot(t, state, control)
# Clip state derivative values
state_dot = clip_fn(
state_dot=state_dot, s_min=self.state_dot_min, s_max=self.state_dot_max
state_dot=state_dot,
s_min=self.state_dot_min,
s_max=self.state_dot_max,
)
# Clip state derivative values to ensure state remains within bounds
state_dot = clip_at_state_limits_fn(
keatincf marked this conversation as resolved.
Show resolved Hide resolved
state=state,
state_dot=state_dot,
s_min=self.state_min,
s_max=self.state_max,
s_min=self.state_dot_min,
s_max=self.state_dot_max,
)
return state_dot

def _compute_state_dot(
self, t: float, state: np.ndarray, control: np.ndarray
self,
t: float,
state: np.ndarray,
control: np.ndarray,
) -> np.ndarray:
raise NotImplementedError

def _step(self, step_size: float, state: np.ndarray, control: np.ndarray):
def _step(
self,
step_size: float,
state: np.ndarray,
control: np.ndarray,
):
state = cast_jax(state, use_jax=self.use_jax)
control = cast_jax(control, use_jax=self.use_jax)
if self.integration_method == "RK45":
step_fn = (
safe_autonomy_simulation.dynamics.utils.step_rk45
if not safe_autonomy_simulation.use_jax()
if not self.use_jax
else safe_autonomy_simulation.jax.ode.step_rk45
)
else:
step_fn = (
safe_autonomy_simulation.dynamics.utils.step_euler
if not safe_autonomy_simulation.use_jax()
if not self.use_jax
else safe_autonomy_simulation.jax.ode.step_euler
)
next_state, state_dot, self.trajectory, self.trajectory_t = step_fn(
Expand Down Expand Up @@ -192,10 +215,17 @@ class ControlAffineODEDynamics(ODEDynamics):
By default, 'RK45'.
"""

def _compute_state_dot(self, t: float, state: np.ndarray, control: np.ndarray):
def _compute_state_dot(
self,
t: float,
state: np.ndarray,
control: np.ndarray,
):
state = cast_jax(state, use_jax=self.use_jax)
control = cast_jax(control, use_jax=self.use_jax)
transition_fn = (
safe_autonomy_simulation.dynamics.utils.affine_transition
if not safe_autonomy_simulation.use_jax()
if not self.use_jax
else safe_autonomy_simulation.jax.ode.affine_transition
)
state_dot = transition_fn(
Expand Down Expand Up @@ -280,6 +310,8 @@ class LinearODEDynamics(ControlAffineODEDynamics):
'RK45' is slow but very accurate. If jax is available, can be JIT compiled for speed.
'Euler' is fast but very inaccurate.
By default, 'RK45'.
use_jax : bool, optional
EXPERIMENTAL: Use JAX to accelerate state transition computation, by default False.
"""

def __init__(
Expand All @@ -292,6 +324,7 @@ def __init__(
state_dot_min: typing.Union[float, np.ndarray] = -np.inf,
state_dot_max: typing.Union[float, np.ndarray] = np.inf,
integration_method: str = "RK45",
use_jax: bool = False
):
super().__init__(
trajectory_samples=trajectory_samples,
Expand All @@ -300,6 +333,7 @@ def __init__(
state_dot_min=state_dot_min,
state_dot_max=state_dot_max,
integration_method=integration_method,
use_jax=use_jax
)

assert len(A.shape) == 2, f"A must be a 2D matrix. Instead got shape {A.shape}"
Expand All @@ -316,7 +350,7 @@ def __init__(
self.B = self.np.copy(B)

def state_transition_system(self, state: np.ndarray) -> np.ndarray:
return self.A @ state
return self.A @ cast_jax(state, use_jax=self.use_jax)

def state_transition_input(self, state: np.ndarray) -> np.ndarray:
return self.B
2 changes: 1 addition & 1 deletion safe_autonomy_simulation/dynamics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def clip_state_dot_at_state_limits(
upper_bounded_states = state >= s_max

lower_bounded_clipped = np.clip(state_dot, 0, np.inf)
upper_bounded_clipped = np.clip(state_dot, np.inf, 0)
upper_bounded_clipped = np.clip(state_dot, -np.inf, 0)

state_dot = np.where(lower_bounded_states, lower_bounded_clipped, state_dot)
state_dot = np.where(upper_bounded_states, upper_bounded_clipped, state_dot)
Expand Down
7 changes: 1 addition & 6 deletions safe_autonomy_simulation/entities/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@

import typing
import typing_extensions
import safe_autonomy_simulation
import numpy as np
import safe_autonomy_simulation.dynamics as d
import safe_autonomy_simulation.materials as materials
import safe_autonomy_simulation.controls.control_queue as controls
import safe_autonomy_simulation.utils.sets as sets

if safe_autonomy_simulation.use_jax():
import jax.numpy as np
else:
import numpy as np


class Entity:
"""
Expand Down
Loading
Loading