
flax.linen.module.init still fails under dynamic type checking for nested modules #3756

Open
evangelos-ch opened this issue Mar 13, 2024 · 0 comments
Labels
Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment.

@evangelos-ch

Related issue: #3224

The snippet posted in that issue now works, but there still seems to be a failure mode when modules are nested and all of them are runtime type-checked.

Colab Link

```python
from jax import numpy as jnp
import flax.linen as nn
import jax
from beartype import beartype

from jaxtyping import jaxtyped

@jaxtyped(typechecker=beartype)
class MyModuleInternal(nn.Module):
    hidden_size: int = 2

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.hidden_size)(x)


@jaxtyped(typechecker=beartype)
class MyModule(nn.Module):
    hidden_dim: int

    def setup(self) -> None:
        self.internal_module = MyModuleInternal(self.hidden_dim)  # <-- failure here

    def __call__(self, x):
        x = self.internal_module(x)
        return x


model = MyModule(5)

params = model.init(
    rngs={"params": jax.random.PRNGKey(0)},
    x=jnp.ones((1, 1)),
)
```

This snippet fails with the following error:

```
---------------------------------------------------------------------------
BeartypeCallHintParamViolation            Traceback (most recent call last)
    [... skipping hidden 1 frame]

<@beartype(__main__.check_params) at 0x7d4c408e2830> in check_params(__beartype_get_violation, __beartype_conf, __beartype_object_137766497224064, __beartype_object_99821132912832, __beartype_object_99821132891488, __beartype_object_137766477140992, __beartype_func, *args, **kwargs)

BeartypeCallHintParamViolation: Method __main__.check_params() parameter parent="MyModule(
    # attributes
    hidden_dim = 5
)" violates type hint typing.Union[typing.Type[flax.linen.module.Module], flax.core.scope.Scope, typing.Type[flax.linen.module._Sentinel], NoneType]
```

Looking at nn.Module's _ParentType, the parent argument is indeed annotated as Type[nn.Module], i.e. a class, whereas what is actually passed in is an instance of nn.Module. This appears to be the same problem behind the previously reported instance of this issue in #3224: the PR that fixed it (#3371) changed the annotation from Type[Scope] to simply Scope, so that an instance rather than a class is expected.

cgarciae self-assigned this on Mar 14, 2024
chiamp added the Priority: P2 - no schedule label and removed the Priority: P3 - no schedule label on Mar 19, 2024