
Releases: patrick-kidger/equinox

Equinox v0.11.4

14 Apr 13:04

Features

  • Added eqx.filter_shard. This lowers to jax.lax.with_sharding_constraint as a single way to transfer data, or reshard data, both inside and outside of JIT! (No more jax.device_put.) In addition, the parallelism example has been updated to use this simpler new functionality. (Thanks @homerjed and @dlwh! #688, #691)

  • Added eqx.filter_{jacfwd,jacrev,hessian}. These do what you expect! (Thanks @lockwo! #677)

  • Added eqx.nn.RotaryPositionalEmbedding. This is designed to be used in conjunction with the existing eqx.nn.MultiheadAttention. (Thanks @Artur-Galstyan! #568)

  • Added support for padding='VALID', padding='SAME', padding='SAME_LOWER' to the convolutional layers: eqx.nn.{Conv, ...}. (Thanks @ChenAo-Phys! #658)

  • Added support for padding_mode='ZEROS', padding_mode='REFLECT', padding_mode='REPLICATE', padding_mode='CIRCULAR' to the convolutional layers: eqx.nn.{Conv, ...}. (Thanks @ChenAo-Phys! #658)

  • Added a dtype argument to eqx.nn.{MultiheadAttention, Linear, Conv, ...} for specifying the dtype of their parameters. In addition, eqx.nn.BatchNorm will now also use its dtype argument to determine the dtype of its weights and bias, not just the dtype of its moving statistics. (Thanks @Artur-Galstyan and @AakashKumarNain! #680, #689)

Compatibility

  • eqx.error_if is now compatible with JAX 0.4.26, which changed JAX's own reporting of error messages slightly. (Thanks @hawkinsp! #670)

  • Added a warning for code like the following:

    class MyModule(eqx.Module):
        fn: Callable

        def __init__(self, ...):
            self.fn = jax.vmap(some_fn)

    This is an easy source of bugs. (The vmap'd function is not a PyTree, so it will not propagate anything in the PyTree structure of some_fn.)

Technical internal stuff

  • eqx.internal.while_loop(..., kind="checkpointed") will now only propagate forward JVP tracers for those outputs which are perturbed due to the input to the loop being perturbed. (Rather than all of them.) This change just means that later calls to a nondifferentiable operation, like jax.pure_callback or eqx.internal.nondifferentiable, will no longer crash at trace time. (See patrick-kidger/diffrax#396.)

  • eqx.internal.while_loop(..., kind="bounded") will now handle certain vmap+grad combinations without crashing. (It seems like JAX is adding some spurious batch tracers.) (See patrick-kidger/optimistix#48 (comment))

  • The transpose rule for eqx.internal.create_vprim now understands symbolic zeros, fixing a crash for grad-of-vmap-of-<lineax.linear_solve that we only use some outputs from>. (See patrick-kidger/optimistix#48.)

  • The type annotation for the input of any converter function used in eqx.field(converter=...) will now be used as the type annotation in any dataclass-autogenerated __init__ functions. In particular this should mean such functions are now compatible with runtime type checkers like beartype. (jaxtyping users, you were already covered: this checks the assigned annotations instead.)

Full Changelog: v0.11.3...v0.11.4

Equinox v0.11.3

10 Jan 21:26

Features

  • Added equinox.nn.RMSNorm.
  • Added equinox.nn.WeightNorm.
  • equinox.tree_deserialise_leaves now treats jax.ShapeDtypeStructs in the same way as arrays. This makes it possible to avoid instantiating the initial model parameters only to throw them away again, by using equinox.filter_eval_shape:
    model = eqx.filter_eval_shape(Model, ...hyperparameters...)
    model = eqx.tree_deserialise_leaves(load_path, model)
    (#259)

Bugfixes

  • equinox.internal.noinline no longer initialises the JAX backend on use.
  • equinox.filter_jit(...).lower(..., some_kwarg=...) no longer crashes (#625, #627)
  • The state of equinox.nn.BatchNorm now uses the default floating point dtype, rather than always using float32.
  • equinox.nn.MultiheadAttention should now perform the softmax in float32 even when the input is of lower dtype. (This is important for numerical stability.)

Refactor

  • All the layers in equinox.nn.{Linear, MLP, ...} now standardise on accepting extra **kwargs and not calling super().__init__. The intention is that these layers be treated as final, i.e. not subclassable. (Previously things were inconsistent: some did this and some did not.)
  • Should now be compatible with JAX_NUMPY_DTYPE_PROMOTION=strict and JAX_NUMPY_RANK_PROMOTION=raise, and this is checked in tests.
  • Better error message when no kwargs passed to filter_grad (Thanks @knyazer! #589)

Internal features

These are undocumented internal features, that may be changed at any time.

  • Added EQX_GETKEY_SEED for use with equinox.internal.GetKey.
  • equinox.internal.while_loop now has its runtime errors removed. This should help with compatibility with TPUs. (#628)

Full Changelog: v0.11.2...v0.11.3

Equinox v0.11.2

13 Nov 18:28

Features

  • Added eqx.filter_jit(..., donate="all-except-first") and eqx.filter_jit(..., donate="warn-except-first"). This offers a way to donate all arguments except the first one. (If you have multiple such arguments then just pack them together into a tuple in the first argument.) This aims to be a low-overhead easy way to handle buffer donation.
  • Added eqx.debug.{assert_max_traces, get_num_traces}, which aim to provide a friendly way of asserting that a JIT'd function is not recompiled -- and if it is, which argument changed to cause the recompilation.
  • eqx.tree_pprint and eqx.tree_pformat now handle PyTorch tensors and jax.ShapeDtypeStructs.
  • eqx.tree_equal now has new arguments:
    • typematch=True: this will require that every leaf have precisely the same type as each other, i.e. right now the requirement is essentially leaf == leaf2; with this flag it becomes type(leaf) == type(leaf2) and leaf == leaf2.
    • rtol and atol: setting these to nonzero values allows for checking that inexact (floating or complex) arrays are allclose, rather than exactly equal.
    • The expectation is that these will be useful in unit tests, e.g. to write checks of the form assert eqx.tree_equal(output, expected_output, typematch=True, rtol=1e-5, atol=1e-5).

Bugfixes

  • Previously, a learnt activation function for eqx.nn.MLP would use the exact same learnt weights for every neuron in every layer. Now, a separate copy of the activation function is used in each location.
  • Subclasses of eqx.Module should now have their __init__ signatures correctly reported by downstream tooling, e.g. automated doc generators, some IDEs. (Thanks @danielward27! #573)

Typing

  • eqx.filter_value_and_grad now declares that it preserves the return type of its function (Thanks @ConnorBaker! #557)

Documentation

  • Fix missing index argument in docstring example for StateIndex (Thanks @edwardwli! #556)
  • Fixed broken link in eqx.Enumeration docstrings (Thanks @LouisDesdoigts! #579)
  • Fixed missing shape specification in one of the tricks. (Thanks @homerjed! #582)

Other

  • Improved a few IPython tracebacks with appropriate __tracebackhide__ = True assignments.
  • Subclassed eqx.Enumerations can now override the message associated with their parent Enumeration: this now produces a warning rather than an error.
  • Documented the EQX_ON_ERROR_BREAKPOINT_FRAMES config variable, which is used to work around a JAX bug when setting EQX_ON_ERROR=breakpoint.
  • Can now monkey-patch the methods of an eqx.Module, e.g.
    class Foo(eqx.Module):
        def f(self): ...
    
    Foo.f = some_transform(Foo.f)
    The anticipated use-case for this is to make things easier for typecheckers; see #584.
  • eqx.debug.store_dce now supports non-arrays in its argument.
  • eqx.Enumeration.where(traced_pred, x, x) will now statically return x without tracing. This is occasionally useful to better propagate information at compile time.

Internal features (not officially supported, advanced use only)

  • Added eqx.internal.GetKey. This generates a random JAX PRNG key when called, and crucially has a nice __repr__ reporting what the seed value is. This should not be used in normal JAX code! This is intended as a convenience for tests, so that the random seed appears in the debug printout of a failed test.
  • Added eqx.internal.MaybeBuffer to indicate that an argument of an eqx.internal.{while_loop,scan} might be wrapped in a buffer.
  • Added eqx.internal.buffer_at_set to support buffer.at[...].set(..., pred=...) whilst being agnostic to whether buffer is a JAX array or one of our while loop buffers.

Full Changelog: v0.11.1...v0.11.2

Equinox v0.11.1

13 Oct 02:17

This is a minor bugfix release.

Bugfixes

  • Checkpointed while loops (eqx.internal.while_loop(..., kind="checkpointed")) now perform a more careful analysis of which arguments need to be differentiated. (#548) This fix is the primary reason for this release -- it unlocks some efficiency improvements when solving SDEs in Diffrax: patrick-kidger/diffrax#320
  • Fixed Abstract{Class,}Var misbehaving around multiple inheritance. (#544)
  • Better compatibility with the beartype library. In a few cases this was throwing some spurious errors to do with forward references. (#543)

Documentation

  • Fix scan-over-layers example in docs. (Thanks @mcbal! #542)

Other

  • Static type checkers should now use Equinox's type hints correctly. (Specifically, we now have the py.typed marker file. Thanks @vidhanio! #547)
  • Added the EQX_ON_ERROR_BREAKPOINT_FRAMES environment variable, to work around JAX bug google/jax#16732 when using EQX_ON_ERROR=breakpoint. This new variable sets the number of stack frames you can access via the u debugger command, when the on-error debugger is triggered. Set this to a small enough number, e.g. EQX_ON_ERROR_BREAKPOINT_FRAMES=1, and it should fix unusual trace-time errors when using EQX_ON_ERROR=breakpoint.

Full Changelog: v0.11.0...v0.11.1

Equinox v0.11.0

29 Sep 22:36

Better errors

Equinox now includes several additional checks to guard against various bugs. If you have a new error, then this is probably an indication that your code always had a silent bug, and should be updated.

  • eqx.nn.LayerNorm now correctly validates the shape of its input. This was a common cause of silent bugs. (Thanks @dlwh for pointing this one out!)
  • Equinox now prints out a warning if you supply both __init__ and __post_init__ -- the former actually overwrites the latter. (This is normal Python dataclass behaviour, but probably unexpected.)
  • Equinox now prevents you from assigning Module attributes with a bound method of your current instance, e.g.
    class Model(eqx.Module):
        foo: Callable
    
        def __init__(self):
            self.foo = self.bar
    
        def bar(self):
            ...
    Otherwise, you end up with two different copies of your model! One at self, the other at self.foo.__self__. (The latter being in the bound method.)
  • eqx.tree_at now gives a better error message if you try to use it to update something that isn't a PyTree leaf. (Thanks @LouisDesdoigts!)
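
To illustrate the __init__ / __post_init__ point above, here is a small sketch using plain Python dataclasses rather than Equinox itself, since the note attributes the behaviour to normal dataclass semantics. The class names and the calls list are hypothetical.

```python
from dataclasses import dataclass

calls = []  # records which __post_init__ hooks actually ran


@dataclass
class OnlyPostInit:
    x: int

    def __post_init__(self):
        calls.append("OnlyPostInit.__post_init__")


@dataclass
class BothDefined:
    x: int

    def __init__(self, x):
        # A user-written __init__ is kept instead of the generated one...
        self.x = x

    def __post_init__(self):
        # ...so nothing ever calls this hook.
        calls.append("BothDefined.__post_init__")


OnlyPostInit(1)
BothDefined(1)
print(calls)  # ['OnlyPostInit.__post_init__']
```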

API changes

These should all be very minor.

  • Breaking change: eqx.nn.StateIndex now takes the initial value, rather than a function that returns the initial value.
  • Breaking change: If using eqx.field(converter=...), then conversion now happens before __post_init__, rather than after it.
  • Prefer eqx.nn.make_with_state over eqx.nn.State. The latter will continue to work, but the former is more memory-efficient. (It deletes the original copy of the initial state.)
  • Prefer eqx.nn.inference_mode over eqx.tree_inference. The latter will continue to exist for backward compatibility. These are the same function, this is really just a matter of moving it into the eqx.nn namespace where it always belonged.

Sharing layers

Equinox now supports sharing a layer between multiple parts of your model! This has probably been our longest-requested feature -- in large part because of how intractable it seemed. Equinox models are PyTrees, not PyDAGs, so how exactly are we supposed to have two different parts of our model point at the same layer?

The answer turned out to be the following -- in this example, we're reusing the embedding weight matrix between the initial embedding layer, and the final readout layer, of a language model.

class LanguageModel(eqx.Module):
    shared: eqx.nn.Shared

    def __init__(self):
        embedding = eqx.nn.Embedding(...)
        linear = eqx.nn.Linear(...)
        # These two weights will now be tied together.
        where = lambda embed_and_lin: embed_and_lin[1].weight
        get = lambda embed_and_lin: embed_and_lin[0].weight
        self.shared = eqx.nn.Shared((embedding, linear), where, get)

    def __call__(self, tokens):
        # Expand back out so we can evaluate these layers.
        embedding, linear = self.shared()
        assert embedding.weight is linear.weight  # same parameter!
        # Now go ahead and evaluate your language model.
        ...

Here, eqx.nn.Shared(...) simply removes all of the nodes at where, so that we don't have two separate copies. Then when it is called at self.shared(), it puts them back again. Note that this isn't a copy and doesn't incur any additional memory overhead; this all happens at the Python level, not the XLA level.

(The curious may like to take a look at the implementation in equinox/nn/_shared.py, which turned out to be very simple.)

On a meta level, I'd like to comment that I'm quite proud of having gotten this one in! It means that Equinox now supports both stateful layers and shared layers, which have always been the two pieces that seemed out of reach when using something as simple as PyTrees to represent models. But it turns out that PyTrees really are all you need. :D

Other changes

Documentation

  • Many documentation fixes courtesy of @colehaus and @Artur-Galstyan!
  • Added two new examples to the documentation. Thank you to @ahmed-alllam for both of them!
    • Deep convolutional GAN
    • Vision Transformer
  • Added an FAQ entry on comparisons between Equinox and PyTorch/Keras/Julia/Flax. It's a common enough question that it should probably have had an answer before now.
  • Added an FAQ entry on debugging recompilation.

Features

  • Added eqx.filter_checkpoint, which as you might expect is a filtered version of jax.checkpoint. (Thanks @dlwh!)
  • Added eqx.Module.__check_init__. This is run in a similar fashion to __post_init__; see the documentation. This can be used to check that invariants of your module hold after initialisation.
  • Added support for vmap'ing stateful layers, by adding eqx.nn.State.{substate, update}. This offers a way to subset or update a State object, so that only the parts of it that need to be vmap'd are passed in. See the stateful documentation for an example of how to do this.
  • Runtime errors should now produce much more readable results, without any of the terrifying INTERNAL: Generated function failed: CpuCallback error stuff! This clean-up of the runtime error message is done by eqx.filter_jit, so that will need to be your top-level way of JIT'ing your computation.
  • Added eqx.nn.StatefulLayer -- this is (only!) for use with eqx.nn.Sequential, to indicate that the layer should be called with x, state, and not just x. If you would like a custom stateful layer to be compatible with Sequential then go ahead and subclass this, and potentially implement the is_stateful method. (Thanks @paganpasta!)
  • The forward pass of each eqx.nn.* layer is now wrapped in a jax.named_scope, for better debugging experience. (Thanks @ahmed-alllam!)
  • eqx.module_update_wrapper no longer requires a second argument; it will look at the __wrapped__ attribute of its first argument.
  • Added eqx.internal.closure_to_pytree, for... you guessed it, turning function closures into PyTrees. The closed-over variables are treated as the subnodes in the PyTree. This will operate recursively so that closed-over closures will themselves become PyTrees, etc. Note that closed-over global variables are not included.

Bugfixes

  • eqx.tree_{serialise,deserialise}_leaves now correctly handle unusual NumPy scalars, like bfloat16. (Thanks @colehaus!)
  • eqx.field(metadata=...) arguments no longer result in the static/converter arguments being ignored. (Thanks @mjo22!)
  • eqx.filter_custom_vjp now supports residuals that are not arrays. (The residuals are the pytree that is passed between the forward and backward pass.)
  • eqx.{AbstractVar,AbstractClassVar} should now support overridden generics in subclasses. That is, something like this:
    class Foo(eqx.Module):
        x: eqx.AbstractVar[list[str]]
    
    class Bar(Foo):
        x: list[str]
    should no longer raise spurious errors under certain conditions.
  • eqx.internal.while_loop now supports using custom (non-Equinox) pytrees in the state.
  • eqx.tree_check no longer raises some false positives.
  • Equinox modules now support __init_subclass__ with additional class creation kwargs. (Thanks @ASEM000, @Roger-luo!)

Full Changelog: v0.10.11...v0.11.0

Equinox v0.10.11

26 Jul 00:11

New features

  • Equinox now offers true runtime errors! This is available as equinox.error_if. This is something new under the JAX sun: these are raised eagerly during the execution, they work on TPU, and if you set the environment variable EQX_ON_ERROR=breakpoint, then they'll even drop you into a debugger as soon as you hit an error. (These are basically a strict improvement over jax.experimental.checkify, which doesn't offer many of these advantages.)

  • Added a suite of debugging tools:

    • equinox.debug.announce_transform: prints to stdout when it is transformed via jvp/vmap etc; very useful for keeping track of how many times a particular operation is getting transformed or compiled, when trying to minimise your compilation times.
    • equinox.debug.backward_nan: for debugging NaNs that only arise on the backward pass.
    • equinox.debug.breakpoint_if: opens a breakpoint if a condition is satisfied.
    • equinox.debug.{store_dce, inspect_dce}: used for checking whether certain variables are removed via the dead-code-elimination pass of the XLA compiler.
  • equinox.filter_jvp now supports keyword arguments (which are treated as not differentiated).

Bugfixes

  • Nested filter_jvps will no longer materialise symbolic zero tangents. (#422)

Documentation

  • The marvellous Levanter library is now linked to in the documentation!

Full Changelog: v0.10.10...v0.10.11

Equinox v0.10.10

11 Jul 15:27

Performance improvements

These are the real highlight of this release.

  • equinox.internal.{while_loop, scan} now use new symbolic zero functionality, which may result in runtime speedups (and slight increases in compile times) as they can now skip calculating gradients for some quantities.
  • equinox.internal.{while_loop, scan}(..., buffers=...) now do their best to work around an XLA bug (google/jax#10197). This can reduce computational cost from quadratic scaling to linear scaling.
  • equinox.internal.{while_loop, scan} now include several optimisations for the common case in which every step is checkpointed. (#415)

Features

  • equinox.filter_custom_{jvp,vjp} now support symbolic zeros.

    Previously, None was passed to represent symbolic zero tangent/cotangents for anything that wasn't a floating-point array -- but all floating-point arrays always had materialised tangent/cotangents.

    With this release, None may also sometimes be passed as the tangent of floating-point arrays. In this case it represents a zero tangent/cotangent, and moreover this zero is "symbolic" -- that is to say it is known to be zero at compile time, which may allow you to write more-efficient custom JVP/VJP rules. (The canonical example is the inverse function theorem -- this involves a linear solve, parts of which you can skip if you know parts of it are zero.)

    In addition, filter_custom_vjp now takes another argument, perturbed, indicating whether a value actually needs cotangents calculated for it. You can skip calculating cotangents for anything that is not perturbed.

    For more information see jax.custom_jvp.defjvp(..., symbolic_zeros=True) and jax.custom_vjp.defvjp(..., symbolic_zeros=True), which provide the underlying behaviour that is being forwarded.

    Note that this is provided through a new API: filter_custom_jvp.def_jvp instead of filter_custom_jvp.defjvp, and filter_custom_vjp.{def_fwd, def_bwd} instead of filter_custom_vjp.defvjp. The old API will continue to exhibit the previous behaviour, for backward compatibility.

Misc

  • Apply functools.wraps to Module methods to preserve docstrings (Thanks @bowlingmh! #409)
  • Enumerations now perform their checks at compile time if possible. This sometimes makes it possible to get more efficient code, by special-casing on these values or eliding branches. (#417)

Full Changelog: v0.10.6...v0.10.10

(Why no v0.10.{7,8,9}? We had a bit of a rocky release this time around, and these got yanked for having bugs. Thanks to everyone who reported issues so quickly! Things look like they're stable now...)

Equinox v0.10.6

14 Jun 17:26
05c5673

Features

  • Added eqx.field: this supports converter=... and static=.... The former is an extension to dataclasses that applies that conversion function when the field is assigned. The latter supersedes the old eqx.static_field. (#390)
  • Added eqx.Enumeration, which are JAX-compatible Enums. (Moved from eqx.internal.Enumeration.) (#392)
  • Added eqx.clear_caches to clear internal caches and reduce memory usage. (#380)
  • Added eqx.nn.BatchNorm(..., dtype=...) (Thanks @Benjamin-Walker! #384)
  • Inside eqx.internal.while_loop: buffers now support buffer.at[index].add(...) etc. (Thanks @packquickly! #395)

Changes

  • Updated typing->collections.abc where appropriate; Tuple->tuple etc. (#385)

Bugfixes

  • eqx.module_update_wrapper no longer assigns __wrapped__. (#381)

Full Changelog: v0.10.5...v0.10.6

Equinox v0.10.5

01 Jun 18:45

Quite a small release.

Bugfixes

  • Fixed modules initialising twice (#369; this bug was introduced in the last couple of Equinox versions.)

Documentation

  • Fix docstring typos in MLP.__init__. (Thanks @schmrlng! #366)
  • Added example for serialisation of hyperparameters (Thanks @bytbox! #374)

Misc

  • Add equinox.internal.eval_full (like equinox.internal.eval_{zeros, empty}) (Thanks @RaderJason! #367)
  • Added JAX-compatible enums: equinox.internal.Enumeration (#375)
  • The minimum supported Python version has been bumped to 3.9 (#379)

Full Changelog: v0.10.4...v0.10.5

Equinox v0.10.4

22 May 05:31

Features

  • eqx.nn.{LayerNorm, GroupNorm} can now accept a call-time state argument that they thread through unchanged. This means that they have the same API as eqx.nn.BatchNorm, so that they may be used interchangeably.
  • eqx.Modules now work with the new jax.tree_util.tree_flatten_with_path API. (#363)
  • eqx.nn.MLP now supports use_bias and use_final_bias arguments. (Thanks @jlperla! #358)
  • Added eqx.tree_check to assert that a pytree does not contain duplicate elements, and does not contain any reference cycles. This may be useful to call on your models prior to training, to check that they are well-formed. (#355)
  • Added eqx.tree_flatten_one_level to flatten a pytree by one level only. (#355)

Internal (semi-undocumented / unstable) Features

  • eqx.internal.{error_if, branched_error_if, debug_backward_nans} now have TPU support! This means that they now support all backends, and are (to my knowledge) the single best option for adding runtime checks to JAX programs. In addition they now eagerly will raise errors at trace-time if the predicate is a raw Python True. (#351)
  • eqx.internal.scan now supports buffers and checkpoints arguments for finer-grained control over its autodiff. (#349)
  • Added eqx.internal.scan_trick, which can be used to minimise compilation time by wrapping nearby function invocations into a single scan. See this PR against Diffrax for an example.

Bugfixes

  • Remove implicit rank promotion in eqx.nn.ConvTranspose (Thanks @khdlr! #335)
  • eqx.static_field()s were sometimes being put in leaves; this is now fixed. (This issue existed in v0.10.3 only.) (#338)
  • eqx.filter_custom_jvp will no longer raise the occasional spurious leaked tracer error. (When using traced non-floating arrays.) (#349)
  • Fixed crash when using zero-sized arrays inside eqx.internal.while_loop(..., kind='checkpointed') (#331)

Other

  • Now using pyproject.toml to handle everything (no more setup.py, .flake8 etc!)
  • Added example docs for autoparallel APIs (link)
  • eqx.internal.while_loop should now have a slightly faster compile time. (#353)

Full Changelog: v0.10.3...v0.10.4