JAX could be an interesting framework to use, given its autodiff and just-in-time compilation capabilities. Personally, I thought there was value in starting with a purely NumPy-based package because it has a lower barrier to entry: many people in our lab group were familiar with Python, NumPy and Lie groups, but not with JAX.
There's potentially a solution that mixes JAX with the current code. JAX would be used for JIT-compiled state/process/measurement model evaluations and for autodiff Jacobians, but the model outputs would be converted back to regular NumPy arrays so the existing filter implementations can be used unchanged. While this would not produce end-to-end differentiable/compilable code, it should speed up model evaluations and Jacobian calculations compared to the current finite differencing.
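A minimal sketch of that hybrid approach, assuming a hypothetical 2D range-to-landmark measurement model written with `jax.numpy`. The model and its forward-mode autodiff Jacobian (`jax.jacfwd`) are JIT-compiled with `jax.jit`, and the results are converted back to plain NumPy arrays with `np.asarray`. The names `range_model`, `evaluate_np`, `jacobian_np`, `jacobian_fd` and the landmark position are illustrative assumptions, not part of the package; the central finite-difference Jacobian is only there for comparison.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Use float64 in JAX so results match NumPy's default precision.
jax.config.update("jax_enable_x64", True)

# Hypothetical landmark position for this example.
LANDMARK = jnp.array([3.0, 4.0])


def range_model(x):
    """Range from a 2D position x to the landmark, written in jax.numpy."""
    return jnp.atleast_1d(jnp.linalg.norm(x - LANDMARK))


# JIT-compile the model and its autodiff Jacobian once.
_range_jit = jax.jit(range_model)
_range_jac_jit = jax.jit(jax.jacfwd(range_model))


def evaluate_np(x):
    """Evaluate with JAX, then hand a regular NumPy array back to the filter."""
    return np.asarray(_range_jit(jnp.asarray(x)))


def jacobian_np(x):
    """Autodiff Jacobian (shape: measurement dim x state dim) as a NumPy array."""
    return np.asarray(_range_jac_jit(jnp.asarray(x)))


def jacobian_fd(x, h=1e-6):
    """Central finite-difference Jacobian, shown only for comparison."""
    x = np.asarray(x, dtype=float)
    y0 = evaluate_np(x)
    J = np.zeros((y0.size, x.size))
    for i in range(x.size):
        dx = np.zeros_like(x)
        dx[i] = h
        J[:, i] = (evaluate_np(x + dx) - evaluate_np(x - dx)) / (2 * h)
    return J


# Usage: evaluate at the origin.
x = np.zeros(2)
y = evaluate_np(x)
H = jacobian_np(x)
```

Analytically, the Jacobian of this model is (x - landmark) / ||x - landmark||, i.e. [-0.6, -0.8] at the origin, so `jacobian_np` and `jacobian_fd` should agree up to rounding. The filters only ever see NumPy arrays.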
To start exploring this, I recommend we define some new JAX-based abstract classes for state/process/measurement models, say `JaxState`, `JaxProcessModel` and `JaxMeasurementModel`. These classes could have their default Jacobian implementations use JAX's autodiff, with the user responsible for writing JAX-compatible code in the `evaluate` method.
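One way such base classes might look, as a rough sketch only. It assumes states are plain vectors (a real `JaxState` would need to handle Lie group perturbations, which is omitted here), and the wrapper methods `evaluate_np` and `jacobian`, as well as the example subclasses `SingleIntegrator` and `RangeToLandmark`, are hypothetical names rather than the package's existing API. The user implements `evaluate` in `jax.numpy`, and the base class supplies JIT-compiled, autodiff-based defaults that return NumPy arrays.

```python
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


class JaxProcessModel(ABC):
    """Hypothetical base class; assumes vector states (no Lie group handling)."""

    def __init__(self):
        # Compile once per instance; the Jacobian is taken w.r.t. x only.
        self._eval_jit = jax.jit(self.evaluate)
        self._jac_jit = jax.jit(jax.jacfwd(self.evaluate, argnums=0))

    @abstractmethod
    def evaluate(self, x, u, dt):
        """User code: must be JAX-compatible (jax.numpy, no in-place updates)."""

    def evaluate_np(self, x, u, dt):
        return np.asarray(self._eval_jit(jnp.asarray(x), jnp.asarray(u), dt))

    def jacobian(self, x, u, dt):
        # Default Jacobian via forward-mode autodiff instead of finite differences.
        return np.asarray(self._jac_jit(jnp.asarray(x), jnp.asarray(u), dt))


class JaxMeasurementModel(ABC):
    """Hypothetical base class; assumes vector states (no Lie group handling)."""

    def __init__(self):
        self._eval_jit = jax.jit(self.evaluate)
        self._jac_jit = jax.jit(jax.jacfwd(self.evaluate))

    @abstractmethod
    def evaluate(self, x):
        """User code: must be JAX-compatible."""

    def evaluate_np(self, x):
        return np.asarray(self._eval_jit(jnp.asarray(x)))

    def jacobian(self, x):
        return np.asarray(self._jac_jit(jnp.asarray(x)))


class SingleIntegrator(JaxProcessModel):
    """Hypothetical example: x_{k+1} = x_k + dt * u_k."""

    def evaluate(self, x, u, dt):
        return x + dt * u


class RangeToLandmark(JaxMeasurementModel):
    """Hypothetical example: range from a 2D position to a fixed landmark."""

    def __init__(self, landmark):
        self.landmark = jnp.asarray(landmark)
        super().__init__()

    def evaluate(self, x):
        return jnp.atleast_1d(jnp.linalg.norm(x - self.landmark))


process = SingleIntegrator()
meas = RangeToLandmark([3.0, 4.0])
F = process.jacobian(np.array([1.0, 2.0]), np.array([0.5, -1.0]), 0.1)
H = meas.jacobian(np.zeros(2))
```

Compiling in `__init__` keeps per-call overhead low, but attributes read inside `evaluate` (such as `landmark`) are baked in when the function is first traced, so changing them afterwards would require re-creating the jitted functions.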
If you'd like to tackle this, please reach out :)