
Equinox v0.10.10

Released by github-actions on 11 Jul, 15:27.

Performance improvements

These are the real highlights of this release.

  • equinox.internal.{while_loop, scan} now use the new symbolic zero functionality, which may result in runtime speedups (at the cost of slightly longer compile times), as they can now skip calculating gradients for some quantities.
  • equinox.internal.{while_loop, scan}(..., buffers=...) now do their best to work around an XLA bug (jax-ml/jax#10197). This can reduce computational cost from quadratic scaling to linear scaling.
  • equinox.internal.{while_loop, scan} now include several optimisations for the common case in which every step is checkpointed. (#415)
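To get a feel for why the buffers workaround can change the scaling, here is a toy cost model in plain Python. It is not Equinox's or XLA's code: it simply assumes that, in the bad case, each loop iteration ends up copying the whole buffer before writing a single entry into it. The function names are hypothetical.

```python
def cost_copy_per_step(num_steps: int, buffer_size: int) -> int:
    """Toy cost (elements touched) if every step copies the whole buffer, then writes one slot."""
    total = 0
    for _ in range(num_steps):
        total += buffer_size  # assumed: the entire buffer is copied
        total += 1            # one element is written into the copy
    return total


def cost_in_place(num_steps: int) -> int:
    """Toy cost (elements touched) if every step writes its slot in place, with no copy."""
    total = 0
    for _ in range(num_steps):
        total += 1            # one element is written directly
    return total


# When the buffer holds one entry per step, buffer_size == num_steps, so the
# copying version grows quadratically while the in-place version grows linearly.
for n in (10, 100, 1000):
    print(n, cost_copy_per_step(n, n), cost_in_place(n))
```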

Features

  • equinox.filter_custom_{jvp,vjp} now support symbolic zeros.

    Previously, None was passed to represent a symbolic zero tangent/cotangent for anything that wasn't a floating-point array, but floating-point arrays always had materialised tangents/cotangents.

    With this release, None may also sometimes be passed as the tangent of a floating-point array. In this case it represents a zero tangent/cotangent, and moreover this zero is "symbolic": it is known to be zero at compile time, which may allow you to write more efficient custom JVP/VJP rules. (The canonical example is the inverse function theorem: this involves a linear solve, parts of which you can skip if you know that some of its inputs are zero.)

    In addition, filter_custom_vjp now takes another argument, perturbed, indicating whether a value actually needs cotangents calculated for it. You can skip calculating cotangents for anything that is not perturbed.

    For more information see jax.custom_jvp.defjvp(..., symbolic_zeros=True) and jax.custom_vjp.defvjp(..., symbolic_zeros=True), which provide the underlying behaviour that is being forwarded.

    Note that this is provided through a new API: filter_custom_jvp.def_jvp instead of filter_custom_jvp.defjvp, and filter_custom_vjp.{def_fwd, def_bwd} instead of filter_custom_vjp.defvjp. The old API will continue to exhibit the previous behaviour, for backward compatibility.
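As an illustration of how symbolic zeros and perturbed can save work, here is a plain-Python sketch of the inverse-function-theorem pattern for the scalar "linear solve" a * x = b. It only mimics the conventions described above (None as a symbolic zero tangent, a per-input perturbed flag in the backward pass). The names solve, solve_jvp, solve_fwd and solve_bwd, and their signatures, are hypothetical: this is not the filter_custom_jvp/filter_custom_vjp API.

```python
from typing import Optional, Tuple


def solve(a: float, b: float) -> float:
    """Primal computation: solve a * x = b for x."""
    return b / a


def solve_jvp(
    primals: Tuple[float, float],
    tangents: Tuple[Optional[float], Optional[float]],
) -> Tuple[float, Optional[float]]:
    """JVP of x = b / a. From a*x = b: a*dx + da*x = db, so dx = (db - da*x) / a.

    A tangent of None stands for a symbolic zero.
    """
    a, b = primals
    ta, tb = tangents
    x = solve(a, b)
    if ta is None and tb is None:
        return x, None  # output tangent is also zero: skip the tangent solve entirely
    rhs = 0.0
    if tb is not None:
        rhs += tb
    if ta is not None:
        rhs -= ta * x  # this term is only needed when `a` has a nonzero tangent
    return x, rhs / a


def solve_fwd(a: float, b: float) -> Tuple[float, Tuple[float, float]]:
    """Forward pass: return the output and the residuals needed by the backward pass."""
    x = solve(a, b)
    return x, (a, x)


def solve_bwd(
    residuals: Tuple[float, float],
    ct_x: float,
    perturbed: Tuple[bool, bool],
) -> Tuple[Optional[float], Optional[float]]:
    """Backward pass: only compute cotangents for inputs that are perturbed."""
    a, x = residuals
    perturbed_a, perturbed_b = perturbed
    if not (perturbed_a or perturbed_b):
        return None, None  # nothing needs a cotangent: skip the adjoint solve
    lam = ct_x / a  # adjoint solve: a * lam = ct_x
    ct_a = -lam * x if perturbed_a else None  # d(b/a)/da = -x/a
    ct_b = lam if perturbed_b else None       # d(b/a)/db = 1/a
    return ct_a, ct_b


print(solve_jvp((2.0, 6.0), (None, 1.0)))
print(solve_bwd((2.0, 3.0), 1.0, (True, False)))
```

In Equinox itself, rules like these would be registered through def_jvp and def_fwd/def_bwd; check their docstrings for the exact argument order.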

Misc

  • Apply functools.wraps to Module methods to preserve docstrings (Thanks @bowlingmh! #409)
  • Enumerations now perform their checks at compile time if possible. This sometimes makes it possible to get more efficient code, by special-casing on these values or eliding branches. (#417)
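The docstring fix in the first Misc item relies on a standard Python mechanism. The sketch below is plain Python with hypothetical helper names, not Equinox's implementation; it shows why a wrapper hides the wrapped function's docstring unless functools.wraps is applied.

```python
import functools


def wrap_without_wraps(fn):
    """Wrap fn without copying its metadata."""
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def wrap_with_wraps(fn):
    """Wrap fn and copy __name__, __doc__, etc. onto the wrapper."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def scale(x, factor=2.0):
    """Multiply x by factor."""
    return x * factor


plain = wrap_without_wraps(scale)
wrapped = wrap_with_wraps(scale)
print(plain.__doc__)    # the wrapper's own (missing) docstring
print(wrapped.__doc__)  # the original docstring, copied across by functools.wraps
```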

New Contributors

Full Changelog: v0.10.6...v0.10.10

(Why no v0.10.{7,8,9}? We had a bit of a rocky release this time around, and these got yanked for having bugs. Thanks to everyone who reported issues so quickly! Things look like they're stable now...)