Skip to content

Commit

Permalink
Merge pull request #3919 from google:nnx-iternodes-docs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634675098
  • Loading branch information
Flax Authors committed May 17, 2024
2 parents cf51da3 + ed49ab3 commit f2b2ea3
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 21 deletions.
2 changes: 1 addition & 1 deletion flax/experimental/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .nnx.graph import pop as pop
from .nnx.graph import state as state
from .nnx.graph import graphdef as graphdef
from .nnx.graph import iter_nodes as iter_nodes
from .nnx.graph import iter_graph as iter_graph
from .nnx.nn import initializers as initializers
from .nnx.nn.activations import celu as celu
from .nnx.nn.activations import elu as elu
Expand Down
52 changes: 40 additions & 12 deletions flax/experimental/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,24 +1061,52 @@ def clone(node: Node) -> Node:
return merge(graphdef, state)


def iter_nodes(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
def iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
"""Iterates over all nested nodes and leaves of a graph node, including the current node.
``iter_graph`` creates a generator that yields path and value pairs, where
the path is a tuple of strings or integers representing the path to the value from the
root. Repeated nodes are visited only once. Leaves include static values.
Example::
>>> from flax.experimental import nnx
>>> import jax.numpy as jnp
...
>>> class Linear(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.din, self.dout = din, dout
... self.w = nnx.Param(jax.random.uniform(rngs.next(), (din, dout)))
... self.b = nnx.Param(jnp.zeros((dout,)))
...
>>> module = Linear(3, 4, rngs=nnx.Rngs(0))
>>> graph = [module, module]
...
>>> for path, module in nnx.iter_graph(graph):
... print(path, type(module).__name__)
...
(0, 'b') Param
(0, 'din') int
(0, 'dout') int
(0, 'w') Param
(0,) Linear
() list
"""
visited: set[int] = set()
path_parts: PathParts = ()
yield from _iter_nodes(node, visited, path_parts)
yield from _iter_graph(node, visited, path_parts)


def _iter_nodes(
def _iter_graph(
node: tp.Any, visited: set[int], path_parts: PathParts
) -> tp.Iterator[tuple[PathParts, tp.Any]]:
if not is_node(node):
return
if id(node) in visited:
return
visited.add(id(node))
node_impl = get_node_impl(node)
node_dict = node_impl.node_dict(node)
for key, value in node_dict.items():
yield from _iter_nodes(value, visited, (*path_parts, key))
if is_node(node):
if id(node) in visited:
return
visited.add(id(node))
node_dict = get_node_impl(node).node_dict(node)
for key, value in node_dict.items():
yield from _iter_graph(value, visited, (*path_parts, key))

yield path_parts, node


Expand Down
12 changes: 5 additions & 7 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
tuple_reduce = lambda xs, x: xs + (x,)
tuple_init = lambda: ()


@tp.runtime_checkable
class _HasSetup(tp.Protocol):
def setup(self) -> None:
...
def setup(self) -> None: ...


class ModuleMeta(ObjectMeta):
Expand Down Expand Up @@ -138,15 +138,15 @@ def init(self: M) -> M:
"""

def _init_context(accessor: DelayedAccessor, *args, **kwargs):
for _, value in graph.iter_nodes(self):
for _, value in graph.iter_graph(self):
if isinstance(value, Object):
value._object__state._initializing = True

method = accessor(self)
try:
out = method(*args, **kwargs)
finally:
for _, value in graph.iter_nodes(self):
for _, value in graph.iter_graph(self):
if isinstance(value, Object):
value._object__state._initializing = False

Expand Down Expand Up @@ -190,7 +190,7 @@ def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
('linear',) Linear
() Block
"""
for path, value in graph.iter_nodes(self):
for path, value in graph.iter_graph(self):
if isinstance(value, Module):
yield path, value

Expand Down Expand Up @@ -376,5 +376,3 @@ def first_from(*args: tp.Optional[A], error_msg: str) -> A:
if arg is not None:
return arg
raise ValueError(error_msg)


2 changes: 1 addition & 1 deletion flax/experimental/nnx/nnx/rnglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def split_key(key: tp.Any) -> jax.Array:

def backup_keys(node: tp.Any, /):
streams: list[RngStream] = []
for _, stream in graph.iter_nodes(node):
for _, stream in graph.iter_graph(node):
if isinstance(stream, RngStream):
stream.key_backups.append(RngKeyBackup(stream.key.value))
streams.append(stream)
Expand Down

0 comments on commit f2b2ea3

Please sign in to comment.