Skip to content

Commit

Permalink
fix(jax): don't jit compile external function calls
Browse files Browse the repository at this point in the history
dont jit compile external function (#12)
  • Loading branch information
jamie-cunningham authored Oct 25, 2024
1 parent db66b3b commit 6f591c3
Show file tree
Hide file tree
Showing 10 changed files with 2,121 additions and 1,671 deletions.
3,761 changes: 2,103 additions & 1,658 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ packages = [
[tool.poetry.dependencies]
python = ">=3.9, <3.13"
scipy = ">=1.10.0"
jax = { version = "0.4.26", extras = ["cpu"] }
jax = { version = "^0.4.30", extras = ["cpu"] }
numpy = "^1.23.5"
matplotlib = "^3.8.0"
quadprog = "^0.1.11"
pydantic = "^2.8.0"
safe-autonomy-simulation = "^2.0.12"
safe-autonomy-simulation = "^2.1.2"
pyqt6 = "^6.7.1"

[tool.poetry.group.lint.dependencies]
Expand Down
7 changes: 5 additions & 2 deletions run_time_assurance/zoo/cwh/docking_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,16 @@ def _setup_docking_properties(
A=A,
B=B,
integration_method=integration_method,
use_jax=True,
)

assert integration_method in (
"RK45",
"Euler",
), f"Invalid integration method {integration_method}, must be 'RK45' or 'Euler'"

jit_compile_dict.setdefault("pred_state", True)
jit_compile_dict.setdefault("integrate", True)
# jit_compile_dict.setdefault("pred_state", True)
# jit_compile_dict.setdefault("integrate", True)

def _setup_docking_constraints(
self,
Expand Down Expand Up @@ -144,6 +145,8 @@ def _docking_pred_state(
self, state: jnp.ndarray, step_size: float, control: jnp.ndarray
) -> jnp.ndarray:
"""Predicts the next state given the current state and control action"""
state = np.array(state)
control = np.array(control)
out, _ = self.dynamics.step(step_size, state, control)
return out

Expand Down
5 changes: 3 additions & 2 deletions run_time_assurance/zoo/cwh/docking_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,16 @@ def _setup_docking_properties(
A=A,
B=B,
integration_method=integration_method,
use_jax=True
)

assert integration_method in (
"Euler",
"RK45",
), f"Invalid integration method {integration_method}, must be 'Euler' or 'RK45'"

jit_compile_dict.setdefault("pred_state", True)
jit_compile_dict.setdefault("integrate", True)
# jit_compile_dict.setdefault("pred_state", True)
# jit_compile_dict.setdefault("integrate", True)

def _setup_docking_constraints(
self,
Expand Down
2 changes: 1 addition & 1 deletion run_time_assurance/zoo/cwh/random_sample_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
bounds = np.array([[-10000, -10000, -5, -5], [10000, 10000, 5, 5]])

A, B = generate_cwh_matrices(defaults.M_DEFAULT, defaults.N_DEFAULT, mode="2d")
self.dynamics = dynamics.LinearODEDynamics(A=A, B=B, integration_method="RK45")
self.dynamics = dynamics.LinearODEDynamics(A=A, B=B, integration_method="RK45", use_jax=True)

super().__init__(
rta=rta,
Expand Down
5 changes: 3 additions & 2 deletions run_time_assurance/zoo/integrators/integrator_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def _setup_docking_properties(
m=m,
mode="1d",
integration_method=integration_method,
use_jax=True
)
self.A = jnp.array(self.dynamics.A)
self.B = jnp.array(self.dynamics.B)
Expand All @@ -70,8 +71,8 @@ def _setup_docking_properties(
integration_method in ("RK45", "Euler")
), f"Invalid integration method {integration_method}, must be either 'RK45' or 'Euler'"

jit_compile_dict.setdefault("pred_state", True)
jit_compile_dict.setdefault("integrate", True)
# jit_compile_dict.setdefault("pred_state", True)
# jit_compile_dict.setdefault("integrate", True)

def _setup_docking_constraints_explicit(self) -> OrderedDict:
"""generates explicit constraints used in the docking problem"""
Expand Down
2 changes: 1 addition & 1 deletion test/cwh_inspection/test_inspection_1v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, rta, constraint_keys=[], step_size=1, **kwargs):
self.rta.constraints = new_constraints

A, B = generate_cwh_matrices(M_DEFAULT, N_DEFAULT, mode="3d")
self.dynamics = LinearODEDynamics(A=A, B=B, integration_method='RK45')
self.dynamics = LinearODEDynamics(A=A, B=B, integration_method='RK45', use_jax=True)

# Specify LQR gains
Q = np.eye(6) * 0.05 # State cost
Expand Down
2 changes: 1 addition & 1 deletion test/docking_rta/test_docking_rta_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, rta, random_init=False):
self.docking_region = 1 # m

A, B = generate_cwh_matrices(M_DEFAULT, N_DEFAULT, mode="2d")
self.dynamics = LinearODEDynamics(A=A, B=B, integration_method='RK45')
self.dynamics = LinearODEDynamics(A=A, B=B, integration_method='RK45', use_jax=True)

# Specify LQR gains
Q = np.eye(4) * 0.05 # State cost
Expand Down
2 changes: 1 addition & 1 deletion test/docking_rta/test_docking_rta_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, rta, random_init=False):
self.docking_region = 1 # m

A, B = generate_cwh_matrices(M_DEFAULT, N_DEFAULT, mode="3d")
self.dynamics = LinearODEDynamics(A=A, B=B, integration_method='RK45')
self.dynamics = LinearODEDynamics(A=A, B=B, integration_method='RK45', use_jax=True)

# Specify LQR gains
Q = np.eye(6) * 0.05 # State cost
Expand Down
2 changes: 1 addition & 1 deletion test/integrator_rta/test_integrator_rta_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, rta):
self.docking_region = 0.1
self.docking_max_vel = 0.1
self.dynamics = PointMassIntegratorDynamics(
m=M_DEFAULT, mode="1d", integration_method="RK45"
m=M_DEFAULT, mode="1d", integration_method="RK45", use_jax=True
)

super().__init__(
Expand Down

0 comments on commit 6f591c3

Please sign in to comment.