Releases: patrick-kidger/equinox
Equinox v0.11.1
This is a minor bugfix release.
Bugfixes
- Checkpointed while loops (`eqx.internal.while_loop(..., kind="checkpointed")`) now perform a more careful analysis of which arguments need to be differentiated. (#548) This fix is the primary reason for this release -- it unlocks some efficiency improvements when solving SDEs in Diffrax: patrick-kidger/diffrax#320
- Fixed `Abstract{Class,}Var` misbehaving around multiple inheritance. (#544)
- Better compatibility with the beartype library. In a few cases this was throwing some spurious errors to do with forward references. (#543)
Other
- Static type checkers should now use Equinox's type hints correctly. (Specifically, we now have the `py.typed` marker file. Thanks @vidhanio! #547)
- Added the `EQX_ON_ERROR_BREAKPOINT_FRAMES` environment variable, to work around JAX bug jax-ml/jax#16732 when using `EQX_ON_ERROR=breakpoint`. This new variable sets the number of stack frames you can access via the `u` debugger command when the on-error debugger is triggered. Set this to a small enough number, e.g. `EQX_ON_ERROR_BREAKPOINT_FRAMES=1`, and it should fix unusual trace-time errors when using `EQX_ON_ERROR=breakpoint`.
Full Changelog: v0.11.0...v0.11.1
Equinox v0.11.0
Better errors
Equinox now includes several additional checks to guard against various bugs. If you now see a new error, this is probably an indication that your code always had a silent bug, and should be updated.
- `eqx.nn.LayerNorm` now correctly validates the shape of its input. This was a common cause of silent bugs. (Thanks @dlwh for pointing this one out!)
- Equinox now prints out a warning if you supply both `__init__` and `__post_init__` -- the former actually overwrites the latter. (This is normal Python dataclass behaviour, but probably unexpected.)
- Equinox now prevents you from assigning Module attributes with a bound method of your current instance, e.g.
  ```python
  class Model(eqx.Module):
      foo: Callable

      def __init__(self):
          self.foo = self.bar

      def bar(self):
          ...
  ```
  Otherwise, you end up with two different copies of your model! One at `self`, the other at `self.foo.__self__`. (The latter being in the bound method.)
- `eqx.tree_at` now gives a better error message if you try to use it to update something that isn't a PyTree leaf. (Thanks @LouisDesdoigts!)
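To see why the bound-method check matters, note what a bound method actually contains. This is plain Python behaviour, independent of Equinox: the method stores its instance at `__self__`, so a Module attribute holding a bound method would carry a second copy of the whole model inside it.

```python
class Plain:
    def __init__(self):
        # Assigning a bound method stores a reference to the instance
        # inside the method object itself.
        self.foo = self.bar

    def bar(self):
        return "hello"

p = Plain()
# The bound method carries the entire instance with it at `__self__`.
# For an immutable PyTree like an eqx.Module, this would mean the model's
# parameters appear twice when the tree is flattened.
assert p.foo.__self__ is p
assert p.foo() == "hello"
```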
API changes
These should all be very minor.
- Breaking change: `eqx.nn.StateIndex` now takes the initial value, rather than a function that returns the initial value.
- Breaking change: If using `eqx.field(converter=...)`, then conversion now happens before `__post_init__`, rather than after it.
- Prefer `eqx.nn.make_with_state` over `eqx.nn.State`. The latter will continue to work, but the former is more memory-efficient. (It deletes the original copy of the initial state.)
- Prefer `eqx.nn.inference_mode` over `eqx.tree_inference`. The latter will continue to exist for backward compatibility. These are the same function; this is really just a matter of moving it into the `eqx.nn` namespace where it always belonged.
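The converter ordering change can be pictured in plain Python. This is only a sketch of the new ordering, not Equinox's actual implementation: the point is that by the time `__post_init__` runs, the field has already been converted.

```python
# A plain-Python sketch of the new ordering in `eqx.field(converter=...)`:
# conversion runs before `__post_init__`, so `__post_init__` already sees
# the converted value. (An illustration only, not Equinox's code.)
class Sketch:
    def __init__(self, x):
        self.x = tuple(x)     # 1. converter runs first (here: `tuple`)
        self.__post_init__()  # 2. then __post_init__ observes the result

    def __post_init__(self):
        assert isinstance(self.x, tuple)  # already converted!

s = Sketch([1, 2, 3])
assert s.x == (1, 2, 3)
```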
Sharing layers
Equinox now supports sharing a layer between multiple parts of your model! This has probably been our longest-requested feature -- in large part because of how intractable it seemed. Equinox models are PyTrees, not PyDAGs, so how exactly are we supposed to have two different parts of our model point at the same layer?
The answer turned out to be the following -- in this example, we're reusing the embedding weight matrix between the initial embedding layer, and the final readout layer, of a language model.
```python
class LanguageModel(eqx.Module):
    shared: eqx.nn.Shared

    def __init__(self):
        embedding = eqx.nn.Embedding(...)
        linear = eqx.nn.Linear(...)
        # These two weights will now be tied together.
        where = lambda embed_and_lin: embed_and_lin[1].weight
        get = lambda embed_and_lin: embed_and_lin[0].weight
        self.shared = eqx.nn.Shared((embedding, linear), where, get)

    def __call__(self, tokens):
        # Expand back out so we can evaluate these layers.
        embedding, linear = self.shared()
        assert embedding.weight is linear.weight  # same parameter!
        # Now go ahead and evaluate your language model.
        ...
```
Here, `eqx.nn.Shared(...)` simply removes all of the nodes at `where`, so that we don't have two separate copies. Then, when it is called at `self.shared()`, it puts them back again. Note that this isn't a copy and doesn't incur any additional memory overhead; this all happens at the Python level, not the XLA level.
(The curious may like to take a look at the implementation in `equinox/nn/_shared.py`, which turned out to be very simple.)
On a meta level, I'd like to comment that I'm quite proud of having gotten this one in! It means that Equinox now supports both stateful layers and shared layers, which have always been the two pieces that seemed out of reach when using something as simple as PyTrees to represent models. But it turns out that PyTrees really are all you need. :D
Other changes
Documentation
- Many documentation fixes courtesy of @colehaus and @Artur-Galstyan!
- Added two new examples to the documentation. Thank you to @ahmed-alllam for both of them!
- Deep convolutional GAN
- Vision Transformer
- Added an FAQ entry on comparisons between Equinox and PyTorch/Keras/Julia/Flax. It's a common enough question that should probably have had an answer before now.
- Added an FAQ entry on debugging recompilation.
Features
- Added `eqx.filter_checkpoint`, which as you might expect is a filtered version of `jax.checkpoint`. (Thanks @dlwh!)
- Added `eqx.Module.__check_init__`. This is run in a similar fashion to `__post_init__`; see the documentation. This can be used to check that invariants of your module hold after initialisation.
- Added support for vmap'ing stateful layers, by adding `eqx.nn.State.{substate, update}`. This offers a way to subset or update a `State` object, so that only the parts of it that need to be vmap'd are passed in. See the stateful documentation for an example of how to do this.
- Runtime errors should now produce much more readable results, without any of the terrifying `INTERNAL: Generated function failed: CpuCallback error` stuff! This clean-up of the runtime error message is done by `eqx.filter_jit`, so that will need to be your top-level way of JIT'ing your computation.
- Added `eqx.nn.StatefulLayer` -- this is used (only!) with `eqx.nn.Sequential`, to indicate that the layer should be called with `x, state`, and not just `x`. If you would like a custom stateful layer to be compatible with `Sequential` then go ahead and subclass this, and potentially implement the `is_stateful` method. (Thanks @paganpasta!)
- The forward pass of each `eqx.nn.*` layer is now wrapped in a `jax.named_scope`, for a better debugging experience. (Thanks @ahmed-alllam!)
- `eqx.module_update_wrapper` no longer requires a second argument; it will look at the `__wrapped__` attribute of its first argument.
- Added `eqx.internal.closure_to_pytree`, for... you guessed it, turning function closures into PyTrees. The closed-over variables are treated as the subnodes in the PyTree. This will operate recursively so that closed-over closures will themselves become PyTrees, etc. Note that closed-over global variables are not included.
Bugfixes
- `eqx.tree_{serialise,deserialise}_leaves` now correctly handle unusual NumPy scalars, like `bfloat16`. (Thanks @colehaus!)
- `eqx.field(metadata=...)` arguments no longer result in the `static`/`converter` arguments being ignored. (Thanks @mjo22!)
- `eqx.filter_custom_vjp` now supports residuals that are not arrays. (The residuals are the pytree that is passed between the forward and backward pass.)
- `eqx.{AbstractVar,AbstractClassVar}` should now support overridden generics in subclasses. That is, something like this:
  ```python
  class Foo(eqx.Module):
      x: eqx.AbstractVar[list[str]]

  class Bar(Foo):
      x: list[str]
  ```
  should no longer raise spurious errors under certain conditions.
- `eqx.internal.while_loop` now supports using custom (non-Equinox) pytrees in the state.
- `eqx.tree_check` no longer raises some false positives.
- Equinox modules now support `__init_subclass__` with additional class creation kwargs. (Thanks @ASEM000, @Roger-luo!)
New Contributors
- @homerjed made their first contribution in #445
- @LouisDesdoigts made their first contribution in #460
- @knyazer made their first contribution in #474
Full Changelog: v0.10.11...v0.11.0
Equinox v0.10.11
New features
- Equinox now offers true runtime errors! This is available as `equinox.error_if`. This is something new under the JAX sun: these are raised eagerly during execution, they work on TPU, and if you set the environment variable `EQX_ON_ERROR=breakpoint`, then they'll even drop you into a debugger as soon as you hit an error. (These are basically a strict improvement over `jax.experimental.checkify`, which doesn't offer many of these advantages.)
- Added a suite of debugging tools:
  - `equinox.debug.announce_transform`: prints to stdout when it is transformed via jvp/vmap etc.; very useful for keeping track of how many times a particular operation is getting transformed or compiled, when trying to minimise your compilation times.
  - `equinox.debug.backward_nan`: for debugging NaNs that only arise on the backward pass.
  - `equinox.debug.breakpoint_if`: opens a breakpoint if a condition is satisfied.
  - `equinox.debug.{store_dce, inspect_dce}`: used for checking whether certain variables are removed via the dead-code-elimination pass of the XLA compiler.
- `equinox.filter_jvp` now supports keyword arguments (which are treated as not differentiated).
Bugfixes
- Nested `filter_jvp`s will now no longer materialise symbolic zero tangents. (#422)
Documentation
- The marvellous Levanter library is now linked to in the documentation!
Full Changelog: v0.10.10...v0.10.11
Equinox v0.10.10
Performance improvements
These are the real highlight of this release.
- `equinox.internal.{while_loop, scan}` now use new symbolic zero functionality, which may result in runtime speedups (and slight increases in compile times) as they can now skip calculating gradients for some quantities.
- `equinox.internal.{while_loop, scan}(..., buffers=...)` now do their best to work around an XLA bug (jax-ml/jax#10197). This can reduce computational cost from quadratic scaling to linear scaling.
- `equinox.internal.{while_loop, scan}` now include several optimisations for the common case in which every step is checkpointed. (#415)
Features
- `equinox.filter_custom_{jvp,vjp}` now support symbolic zeros.

  Previously, `None` was passed to represent symbolic zero tangents/cotangents for anything that wasn't a floating-point array -- but all floating-point arrays always had materialised tangents/cotangents.

  With this release, `None` may also sometimes be passed as the tangent of floating-point arrays. In this case it represents a zero tangent/cotangent, and moreover this zero is "symbolic" -- that is to say, it is known to be zero at compile time, which may allow you to write more-efficient custom JVP/VJP rules. (The canonical example is the inverse function theorem -- this involves a linear solve, parts of which you can skip if you know parts of it are zero.)

  In addition, `filter_custom_vjp` now takes another argument, `perturbed`, indicating whether a value actually needs cotangents calculated for it. You can skip calculating cotangents for anything that is not perturbed.

  For more information see `jax.custom_jvp.defjvp(..., symbolic_zeros=True)` and `jax.custom_vjp.defvjp(..., symbolic_zeros=True)`, which provide the underlying behaviour that is being forwarded.

  Note that this is provided through a new API: `filter_custom_jvp.def_jvp` instead of `filter_custom_jvp.defjvp`, and `filter_custom_vjp.{def_fwd, def_bwd}` instead of `filter_custom_vjp.defvjp`. The old API will continue to exhibit the previous behaviour, for backward compatibility.
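The symbolic-zero convention itself is easy to state in plain Python. The following is a sketch of the calling convention only, not Equinox's API (`jvp_of_product` is a hypothetical rule): a tangent of `None` is known-zero at trace time, so the corresponding term of the JVP can be skipped entirely.

```python
def jvp_of_product(x, y, x_dot, y_dot):
    # JVP of f(x, y) = x * y is x_dot * y + x * y_dot.
    # A tangent of None is a symbolic zero: skip its term entirely.
    if x_dot is None and y_dot is None:
        return None  # the output tangent is itself symbolically zero
    out = 0.0
    if x_dot is not None:
        out += x_dot * y
    if y_dot is not None:
        out += x * y_dot
    return out

assert jvp_of_product(2.0, 3.0, 1.0, None) == 3.0  # skipped the x * y_dot term
assert jvp_of_product(2.0, 3.0, None, None) is None
```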
Misc
- Apply `functools.wraps` to Module methods to preserve docstrings. (Thanks @bowlingmh! #409)
- Enumerations now perform their checks at compile time if possible. This sometimes makes it possible to get more efficient code, by special-casing on these values or eliding branches. (#417)
New Contributors
- @bowlingmh made their first contribution in #409
Full Changelog: v0.10.6...v0.10.10
(Why no v0.10.{7,8,9}? We had a bit of a rocky release this time around, and these got yanked for having bugs. Thanks to everyone who reported issues so quickly! Things look like they're stable now...)
Equinox v0.10.6
Features
- Added `eqx.field`: this supports `converter=...` and `static=...`. The former is an extension to dataclasses that applies that conversion function when the field is assigned. The latter supersedes the old `eqx.static_field`. (#390)
- Added `eqx.Enumeration`, which are JAX-compatible Enums. (Moved from `eqx.internal.Enumeration`.) (#392)
- Added `eqx.clear_caches` to clear internal caches and reduce memory usage. (#380)
- Added `eqx.nn.BatchNorm(..., dtype=...)`. (Thanks @Benjamin-Walker! #384)
- Inside `eqx.internal.while_loop`: buffers now support `buffer.at[index].add(...)` etc. (Thanks @packquickly! #395)
Changes
- Updated `typing` -> `collections.abc` where appropriate; `Tuple` -> `tuple` etc. (#385)
Bugfixes
- `eqx.module_update_wrapper` no longer assigns `__wrapped__`. (#381)
Full Changelog: v0.10.5...v0.10.6
Equinox v0.10.5
Quite a small release.
Bugfixes
- Fixed modules initialising twice (#369; this bug was introduced in the last couple of Equinox versions.)
Documentation
- Fix docstring typos in `MLP.__init__`. (Thanks @schmrlng! #366)
- Added an example of serialisation of hyperparameters. (Thanks @bytbox! #374)
Misc
- Added `equinox.internal.eval_full` (like `equinox.internal.eval_{zeros, empty}`). (Thanks @RaderJason! #367)
- Added JAX-compatible enums: `equinox.internal.Enumeration`. (#375)
- The minimum supported Python version has been bumped to 3.9. (#379)
Full Changelog: v0.10.4...v0.10.5
Equinox v0.10.4
Features
- `eqx.nn.{LayerNorm, GroupNorm}` can now accept a call-time `state` argument that they thread through unchanged. This means that they have the same API as `eqx.nn.BatchNorm`, so that they may be used interchangeably.
- `eqx.Module`s now work with the new `jax.tree_util.tree_flatten_with_path` API. (#363)
- `eqx.nn.MLP` now supports `use_bias` and `use_final_bias` arguments. (Thanks @jlperla! #358)
- Added `eqx.tree_check` to assert that a pytree does not contain duplicate elements, and does not contain any reference cycles. This may be useful to call on your models prior to training, to check that they are well-formed. (#355)
- Added `eqx.tree_flatten_one_level` to flatten a pytree by one level only. (#355)
Internal (semi-undocumented / unstable) Features
- `eqx.internal.{error_if, branched_error_if, debug_backward_nans}` now have TPU support! This means that they now support all backends, and are (to my knowledge) the single best option for adding runtime checks to JAX programs. In addition, they will now eagerly raise errors at trace time if the predicate is a raw Python `True`. (#351)
- `eqx.internal.scan` now supports `buffers` and `checkpoints` arguments for finer-grained control over its autodiff. (#349)
- Added `eqx.internal.scan_trick`, which can be used to minimise compilation time by wrapping nearby function invocations into a single scan. See this PR against Diffrax for an example.
Bugfixes
- Removed implicit rank promotion in `eqx.nn.ConvTranspose`. (Thanks @khdlr! #335)
- `eqx.static_field()`s were sometimes being put in leaves; this is now fixed. (This issue existed in v0.10.3 only.) (#338)
- `eqx.filter_custom_jvp` will no longer raise the occasional spurious leaked-tracer error. (When using traced non-floating arrays.) (#349)
- Fixed a crash when using zero-sized arrays inside `eqxi.while_loop(..., kind='checkpointed')`. (#331)
Other
- Now using `pyproject.toml` to handle everything (no more `setup.py`, `.flake8`, etc.!)
- Added example docs for autoparallel APIs (link)
- `eqx.internal.while_loop` should now have a slightly faster compile time. (#353)
Full Changelog: v0.10.3...v0.10.4
Equinox v0.10.3
Features
- Added `equinox.nn.{State, StateIndex}`. This has been one of the longest-requested features for Equinox: we now have proper stateful operations! (In a carefully-controlled way -- see the new stateful docs and the new stateful example.)
- As an application of these new stateful operations: added `equinox.nn.{BatchNorm, SpectralNorm}`, which have graduated from experimental! Note that these have a slightly different API to their previous experimental versions.
- Added `equinox.Partial`, which is a tidied-up version of `jax.tree_util.Partial`.
- `equinox.filter_{jit, pmap}` are now compatible with ahead-of-time compilation. (#325)
- `equinox.nn.LayerNorm` now supports `use_weight` and `use_bias` arguments to disable each individually. This reflects the fact that many modern transformer architectures now use layer normalisation without bias. (#310; thanks @lockwo!)
- Added `equinox.internal.{AbstractVar, AbstractClassVar}` to denote abstract instance attributes and abstract class attributes respectively. (Analogous to `abc.abstractmethod` denoting abstract methods.) The downstream scientific ecosystem is making heavy use of abstract base classes (e.g. all the ABCs in Diffrax) and these have turned out to be a really useful feature. See this docstring for more details. Right now these are an undocumented internal-only feature, but we could plausibly spin these out into their own library.
Tweaks
- `equinox.nn.Conv` should now be compatible with disabled rank promotion. (#308; thanks @lockwo!)
- `equinox.internal.loop` should now be compatible with `jax.experimental.xmap`. (patrick-kidger/diffrax#246)
- Normalisation layers should now be tolerant of floating-point inaccuracies that occasionally produce negative variances. (#314; thanks @anh-tong!)
- The BERT example now has fixed dropout behaviour. (#316; thanks @j5b!)
- Some doc fixes. (#303; thanks @RaderJason!)
Removed
- Everything in `equinox.experimental.*` has been removed. See the new stateful functionality described above.
New Contributors
- @RaderJason made their first contribution in #303
- @lockwo made their first contribution in #311
- @anh-tong made their first contribution in #314
Full Changelog: v0.10.2...v0.10.3
Equinox v0.10.2
This release has lots of examples and bugfixes from several new contributors!
Features
- `eqx.nn.{Linear, MLP}` now support the string `"scalar"` for their input and output sizes, to produce an array of shape `()` rather than an array of shape `(1,)`.
- Added `equinox.internal.scan` for a checkpointed scan implementation. (It'd be interesting to see this used for an optimally-checkpointed scan-over-layers in an LLM?)
Documentation
- Much nicer examples! Big thanks to:
  - @Artur-Galstyan for contributing a CNN-on-MNIST example;
  - @j5b for contributing a BERT example;
  - @Benjamin-Walker for contributing a U-Net example.
Bugfixes
- `eqx.filter_closure_convert` and `eqx.internal.while_loop` now work with tree-math.
- Improved numerical stability of `MultiheadAttention`, and fixed it producing NaNs in the fully-masked case. (Thanks @j5b!)
- Fixed the (deprecated, but still present) `deterministic=True` being ignored in `MultiheadAttention`. (Thanks @mk-0!)
- `__new__` can now be overridden in subclasses of `eqx.Module`. (Thanks @ASEM000!)
Misc
- Now using ruff and pyright. (No longer using flake8 or isort.)
- Modules are now private-by-default, e.g. `equinox._jit` instead of `equinox.jit`. If you're broken by this change then you should make sure to import from the public interface: e.g. `equinox.filter_jit` instead of `equinox._jit.filter_jit`.
- `equinox.internal.while_loop(..., kind="checkpointed")` now supports readable buffers.
- `eqx.filter_vmap` now supports all-`None`s in `in_axes`. (Thanks @RaderJason!)
Full Changelog: v0.10.1...v0.10.2
Equinox v0.10.1
The usual post-release hotfix.
See the v0.10.0 release notes for the interesting recent changes.
Changes in this release
- Fixed a number of English typos in strings and error messages.
- Fixed a couple of type annotations. (Thanks @dhirschfeld in #262)
- Removed spurious use of `typing_extensions`.
- Fixed `eqx.filter_{vmap, pmap}(in_axes=dict(...), ...)` crashing when used alongside default arguments.
Full Changelog: v0.10.0...v0.10.1