
Equinox v0.11.4

@github-actions github-actions released this 14 Apr 13:04
· 17 commits to main since this release

Features

  • Added eqx.filter_shard. This lowers to jax.lax.with_sharding_constraint as a single way to transfer data, or reshard data, both inside and outside of JIT! (No more jax.device_put.) In addition, the parallelism example has been updated to use this simpler new functionality. (Thanks @homerjed and @dlwh! #688, #691)

  • Added eqx.filter_{jacfwd,jacrev,hessian}. These do what you expect! (Thanks @lockwo! #677)

  • Added eqx.nn.RotaryPositionalEmbedding. This is designed to be used in conjunction with the existing eqx.nn.MultiheadAttention. (Thanks @Artur-Galstyan! #568)

  • Added support for padding='VALID', padding='SAME', padding='SAME_LOWER' to the convolutional layers: eqx.nn.{Conv, ...}. (Thanks @ChenAo-Phys! #658)

  • Added support for padding_mode='ZEROS', padding_mode='REFLECT', padding_mode='REPLICATE', padding_mode='CIRCULAR' to the convolutional layers: eqx.nn.{Conv, ...}. (Thanks @ChenAo-Phys! #658)

  • Added a dtype argument to eqx.nn.{MultiheadAttention, Linear, Conv, ...} for specifying the dtype of their parameters. In addition, eqx.nn.BatchNorm will now also use its dtype argument to determine the dtype of its weights and bias, not just the dtype of its moving statistics. (Thanks @Artur-Galstyan and @AakashKumarNain! #680, #689)

Compatibility

  • eqx.error_if is now compatible with JAX 0.4.26, which changed JAX's own reporting of error messages slightly. (Thanks @hawkinsp! #670)

  • Added a warning that is raised for code like:

    class MyModule(eqx.Module):
        fn: Callable

        def __init__(self, ...):
            self.fn = jax.vmap(some_fn)

    This pattern is an easy source of bugs. (The vmap'd function is not a PyTree, so it will not propagate anything in the PyTree structure of some_fn.)

Technical internal stuff

  • eqx.internal.while_loop(..., kind="checkpointed") will now only propagate forward JVP tracers for those outputs which are perturbed due to the input to the loop being perturbed. (Rather than all of them.) This change just means that later calls to a nondifferentiable operation, like jax.pure_callback or eqx.internal.nondifferentiable, will no longer crash at trace time. (See patrick-kidger/diffrax#396.)

  • eqx.internal.while_loop(..., kind="bounded") will now handle certain vmap+grad combinations without crashing. (It seems like JAX is adding some spurious batch tracers.) (See this comment on patrick-kidger/optimistix#48.)

  • The transpose rule for eqx.internal.create_vprim now understands symbolic zeros, fixing a crash for grad-of-vmap-of-<lineax.linear_solve that we only use some outputs from>. (See patrick-kidger/optimistix#48.)

  • The type annotation for the input of any converter function used in eqx.field(converter=...) will now be used as the type annotation in any dataclass-autogenerated __init__ functions. In particular this should mean such functions are now compatible with runtime type checkers like beartype. (jaxtyping users, you were already covered: this checks the assigned annotations instead.)

New Contributors

Full Changelog: v0.11.3...v0.11.4