-
Notifications
You must be signed in to change notification settings - Fork 601
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] add compat #3921
[nnx] add compat #3921
Conversation
17e7fc2
to
3dc8838
Compare
c137583
to
24002bc
Compare
e4af67e
to
0da1116
Compare
return x @ self.w + self.b[None] | ||
|
||
@dataclasses.dataclass | ||
class Foo(nnx.compat.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class Foo(nnx.compat.Module): | |
class Foo(compat.Module): |
Change all nnx.compat
references to compat
for consistency?
>>> import jax.numpy as jnp | ||
... | ||
>>> class Linear(nnc.Module): | ||
... def __init__(self, dout, rngs: nnx.Rngs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For clarification:
- if we define an
__init__
method, we will be able to instantiate module parameters using the.init
method? - if we define a
setup
method or wrap__call__
withcompact
, we will be able to instantiate the module parameters by calling the module on a sample input and invoking shape inference? - the module parameters that are instantiated are bound to the module so they can be dot-accessed, which is different from Linen where they are returned separately as a variable dict?
- Instead of defining an
__init__
method, can we define asetup
method or wrap__call__
withcompact
to use the.init
method as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the confusion here, this init
method is the current init
we have for nnx.Module
but we are just moving it out to compat.Module
, however its still need to create refactor the method so it follows the Linen API as closely as possible in a subsequent PR.
What does this PR do?
Adds
nnx.compat
module with the goal of making it possible to port Linen codebases to an NNX system with as few changes as possible. It would contain the following functionality:Module
: inherits fromnnx.Module
and adds methods fromlinen.Module
.compact
: allows defining submodules inlinewrappers
: some types that simplify NNX <-> Linen interop.