Skip to content

Commit

Permalink
Faster quaternion multiplication (#23)
Browse files Browse the repository at this point in the history
* Much faster JIT for quaternion multiplication

* minor

* add has_aux test

* fix has_aux test
  • Loading branch information
brentyi authored Nov 28, 2024
1 parent 41da742 commit a6e8adc
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 38 deletions.
19 changes: 3 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Where each group supports:
**`adjoint()`**, **`apply()`**, **`multiply()`**, **`inverse()`**,
**`identity()`**, **`from_matrix()`**, and **`as_matrix()`** operations. (see
[./examples/se3_basics.py](./examples/se3_basics.py))
- Taylor approximations near singularities.
- Helpers for optimization on manifolds (see
[./examples/se3_optimization.py](./examples/se3_optimization.py),
<code>jaxlie.<strong>manifold.\*</strong></code>).
Expand All @@ -78,24 +79,10 @@ pip install jaxlie

---

### In the wild

- [jaxfg](https://github.com/brentyi/jaxfg) applies `jaxlie` to nonlinear least
squares problems with block-sparse structure. (for pose graph optimization,
bundle adjustment, etc)
- [tensorf-jax](https://github.com/brentyi/tensorf-jax) is an unofficial
implementation of
[Tensorial Radiance Fields (Chen et al, ECCV 2022)](https://apchenstu.github.io/TensoRF/)
using `jaxlie`.
![Render of a lego](https://github.com/brentyi/tensorf-jax/raw/main/lego_render.gif)

---

### Misc

`jaxlie` was originally written for our IROS 2021 paper
([link](https://github.com/brentyi/dfgo)). If it's useful for you, you're
welcome to cite:
`jaxlie` was originally written when I was learning about Lie groups for our IROS 2021 paper
([link](https://github.com/brentyi/dfgo)):

```
@inproceedings{yi2021iros,
Expand Down
4 changes: 2 additions & 2 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def exp(cls, tangent: hints.Array) -> "SE2":
jax.Array,
jnp.where(
use_taylor,
jnp.ones_like(theta), # Any non-zero value should do here.
1.0, # Any non-zero value should do here.
theta,
),
)
Expand Down Expand Up @@ -193,7 +193,7 @@ def log(self) -> jax.Array:
# reverse-mode AD.
safe_cos_minus_one = jnp.where(
use_taylor,
jnp.ones_like(cos_minus_one), # Any non-zero value should do here.
1.0, # Any non-zero value should do here.
cos_minus_one,
)

Expand Down
4 changes: 2 additions & 2 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def exp(cls, tangent: hints.Array) -> SE3:
jax.Array,
jnp.where(
use_taylor,
jnp.ones_like(theta_squared), # Any non-zero value should do here.
1.0, # Any non-zero value should do here.
theta_squared,
),
)
Expand Down Expand Up @@ -173,7 +173,7 @@ def log(self) -> jax.Array:
# reverse-mode AD.
theta_squared_safe = jnp.where(
use_taylor,
jnp.ones_like(theta_squared), # Any non-zero value should do here.
1.0, # Any non-zero value should do here.
theta_squared,
)
del theta_squared
Expand Down
63 changes: 53 additions & 10 deletions jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,16 +311,59 @@ def apply(self, target: hints.Array) -> jax.Array:

@override
def multiply(self, other: SO3) -> SO3:
    """Compose two rotations: the quaternion (Hamilton) product of
    `self.wxyz` and `other.wxyz`, with leading batch axes broadcast by
    the einsum below.
    """
    # A straightforward implementation spells out the 16 scalar products of
    # the Hamilton product and stacks the four output components:
    #
    #   w0, x0, y0, z0 = jnp.moveaxis(self.wxyz, -1, 0)
    #   w1, x1, y1, z1 = jnp.moveaxis(other.wxyz, -1, 0)
    #   return SO3(
    #       wxyz=jnp.stack(
    #           [
    #               -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1,
    #               x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1,
    #               -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1,
    #               x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1,
    #           ],
    #           axis=-1,
    #       )
    #   )
    #
    # That is correct, but the many small operations put a heavy burden on
    # the JIT compiler. Expressing the same product as one outer product +
    # gather + sum keeps tracing cheap: for pose graph optimization on the
    # sphere2500 dataset this speeds up *overall* JIT times by over 35%,
    # without any runtime penalty.

    # Hamilton product constants: output component k is
    #   sum_m signs[k, m] * self.wxyz[..., terms_i[k, m]]
    #                     * other.wxyz[..., terms_j[k, m]]
    # Rows correspond to the (w, x, y, z) outputs in that order.
    terms_i = jnp.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]])
    terms_j = jnp.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 0, 1, 3], [3, 0, 2, 1]])
    signs = jnp.array(
        [
            [1, -1, -1, -1],
            [1, 1, 1, -1],
            [1, 1, 1, -1],
            [1, 1, 1, -1],
        ]
    )

    # Outer product of the two quaternions; fancy indexing with
    # (terms_i, terms_j) then gathers the 4x4 table of needed pairwise
    # products in one shot.
    q_outer = jnp.einsum("...i,...j->...ij", self.wxyz, other.wxyz)
    return SO3(
        wxyz=jnp.sum(
            signs * q_outer[..., terms_i, terms_j],
            axis=-1,
        )
    )
Expand All @@ -342,7 +385,7 @@ def exp(cls, tangent: hints.Array) -> SO3:
safe_theta = jnp.sqrt(
jnp.where(
use_taylor,
jnp.ones_like(theta_squared), # Any constant value should do here.
1.0, # Any constant value should do here.
theta_squared,
)
)
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/manifold/_deltas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Helpers for recursively applying tangent-space deltas."""

from typing import Any, Callable, TypeVar, Union, cast, overload
from typing import Any, TypeVar, Union, overload

import jax
import numpy as onp
Expand Down
5 changes: 0 additions & 5 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,7 @@
from typing import Tuple, Type

import numpy as onp
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import numpy as jnp
from utils import (
assert_arrays_close,
assert_transforms_close,
general_group_test,
sample_transform,
)
Expand Down
11 changes: 9 additions & 2 deletions tests/test_manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from typing import Tuple, Type

import jax
import jaxlie
import numpy as onp
import pytest
from jax import numpy as jnp
from jax import tree_util

from utils import (
assert_arrays_close,
assert_transforms_close,
Expand All @@ -15,8 +17,6 @@
sample_transform,
)

import jaxlie


@general_group_test
def test_rplus_rminus(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]):
Expand Down Expand Up @@ -68,6 +68,13 @@ def loss(transform: jaxlie.MatrixLieGroup):
transform = Group.exp(sample_transform(Group, batch_axes).log())
original_loss = loss(transform)

assert_arrays_close(
jaxlie.manifold.grad(lambda transform: (loss(transform), None), has_aux=True)(
transform
)[0],
jaxlie.manifold.grad(loss)(transform),
)

@jax.jit
def step(t):
return jaxlie.manifold.rplus(t, -1e-3 * jaxlie.manifold.grad(loss)(t))
Expand Down

0 comments on commit a6e8adc

Please sign in to comment.