Equinox v0.10.10
Performance improvements
These are the real highlight of this release.
- equinox.internal.{while_loop, scan} now use new symbolic zero functionality, which may result in runtime speedups (and slight increases in compile times) as they can now skip calculating gradients for some quantities.
- equinox.internal.{while_loop, scan}(..., buffers=...) now do their best to work around an XLA bug (jax-ml/jax#10197). This can reduce computational cost from quadratic scaling to linear scaling.
- equinox.internal.{while_loop, scan} now include several optimisations for the common case in which every step is checkpointed. (#415)
Features
- equinox.filter_custom_{jvp,vjp} now support symbolic zeros.

  Previously, None was passed to represent symbolic zero tangent/cotangents for anything that wasn't a floating-point array -- but all floating-point arrays always had materialised tangent/cotangents.

  With this release, None may also sometimes be passed as the tangent of floating-point arrays. In this case it represents a zero tangent/cotangent, and moreover this zero is "symbolic" -- that is to say it is known to be zero at compile time, which may allow you to write more-efficient custom JVP/VJP rules. (The canonical example is the inverse function theorem -- this involves a linear solve, parts of which you can skip if you know parts of it are zero.)

  In addition, filter_custom_vjp now takes another argument, perturbed, indicating whether a value actually needs cotangents calculated for it. You can skip calculating cotangents for anything that is not perturbed.

  For more information see jax.custom_jvp.defjvp(..., symbolic_zeros=True) and jax.custom_vjp.defvjp(..., symbolic_zeros=True), which provide the underlying behaviour that is being forwarded.

  Note that this is provided through a new API: filter_custom_jvp.def_jvp instead of filter_custom_jvp.defjvp, and filter_custom_vjp.{def_fwd, def_bwd} instead of filter_custom_vjp.defvjp. The old API will continue to exhibit the previous behaviour, for backward compatibility.
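As a rough illustration of the underlying JAX behaviour mentioned above, here is a minimal sketch of a custom JVP rule that skips work for unperturbed inputs. It uses plain jax.custom_jvp with symbolic_zeros=True rather than the new Equinox API, so symbolic zeros show up as jax.custom_derivatives.SymbolicZero objects instead of None. The function weighted_sum and its rule are hypothetical names made up for this example.

```python
import jax
import jax.numpy as jnp
from jax.custom_derivatives import SymbolicZero


# Hypothetical example function: not part of Equinox or JAX.
@jax.custom_jvp
def weighted_sum(x, w):
    return jnp.sum(x * w)


def weighted_sum_jvp(primals, tangents):
    x, w = primals
    tx, tw = tangents
    out = weighted_sum(x, w)
    tangent_out = jnp.zeros_like(out)
    # A SymbolicZero tangent is known to be zero at trace time, so the
    # corresponding term can be skipped entirely instead of computed.
    if not isinstance(tx, SymbolicZero):
        tangent_out = tangent_out + jnp.sum(tx * w)
    if not isinstance(tw, SymbolicZero):
        tangent_out = tangent_out + jnp.sum(x * tw)
    return out, tangent_out


# Opt in to receiving symbolic zeros in the tangents argument.
weighted_sum.defjvp(weighted_sum_jvp, symbolic_zeros=True)

x = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([0.5, 0.5, 2.0])

# Only x is differentiated here, so w is unperturbed.
grad_x = jax.grad(weighted_sum, argnums=0)(x, w)
```

The rule is written so that it gives the same result whether or not a given tangent arrives as a symbolic zero. The symbolic case only removes unnecessary computation.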
Misc
- Apply functools.wraps to Module methods to preserve docstrings (Thanks @bowlingmh! #409)
- Enumerations now perform their checks at compile time if possible. This sometimes makes it possible to get more efficient code, by special-casing on these values or eliding branches. (#417)
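For context on the functools.wraps item above, the following standard-library sketch shows why wrapping a method without functools.wraps loses its docstring. The decorators and the Layer class are hypothetical and are not how Equinox wraps Module methods internally. They only illustrate the general effect.

```python
import functools


def without_wraps(method):
    # Plain wrapper: the returned function has its own name and no docstring.
    def wrapper(self, *args, **kwargs):
        return method(self, *args, **kwargs)

    return wrapper


def with_wraps(method):
    # functools.wraps copies __doc__, __name__, etc. from the wrapped method.
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        return method(self, *args, **kwargs)

    return wrapper


class Layer:
    @without_wraps
    def forward(self, x):
        """Apply the layer to x."""
        return x + 1

    @with_wraps
    def backward(self, x):
        """Undo the layer."""
        return x - 1
```

Here Layer.forward.__doc__ is None, while Layer.backward.__doc__ keeps the original docstring.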
New Contributors
- @bowlingmh made their first contribution in #409
Full Changelog: v0.10.6...v0.10.10
(Why no v0.10.{7,8,9}? We had a bit of a rocky release this time around, and these got yanked for having bugs. Thanks to everyone who reported issues so quickly! Things look like they're stable now...)