[nnx] transforms refactor #3927

Merged
merged 1 commit into from
May 22, 2024

cgarciae (Collaborator) commented May 16, 2024

What does this PR do?

  • Adds cond.
  • Fixes an issue with scan not caching properly.
  • Fixes an issue where scan produced different keys at every step for keys not covered by split_rngs.
  • Refactors the code for jit and scan (TODO: port the other transforms to the new simplified style).
  • Refactors the implementation of fork to make it easier to use.
  • RngCount now has a tag attribute (same as RngKey), which enables filtering counts as well (needed for scan).
  • Fixes tracer leakage issues in transforms.
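
To make the tag bullet above more concrete, here is a minimal pure-Python sketch of filtering RNG state by tag. The RngKey and RngCount classes below are simplified stand-ins written for this illustration, not the nnx classes, and select_by_tag is a hypothetical helper rather than part of the library.

```python
from dataclasses import dataclass


# Simplified stand-ins for illustration only; these are NOT the nnx classes.
@dataclass(frozen=True)
class RngKey:
  tag: str
  value: int  # placeholder for a JAX PRNG key


@dataclass(frozen=True)
class RngCount:
  tag: str
  value: int  # number of keys drawn so far from this stream


def select_by_tag(variables, tags):
  """Hypothetical helper: partition variables into (selected, rest) by tag."""
  selected = [v for v in variables if v.tag in tags]
  rest = [v for v in variables if v.tag not in tags]
  return selected, rest


variables = [
  RngKey("params", 0), RngCount("params", 3),
  RngKey("dropout", 1), RngCount("dropout", 7),
]
split, broadcast = select_by_tag(variables, {"dropout"})
```

Because both stand-in classes carry a tag, a key and its matching count end up in the same partition, which is presumably why scan needs the count to be filterable too.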

superbobry (Member) left a comment

I did my best to read through, but this PR seems to have too many moving parts to really follow what's going on.

I recommend

  • adding more details to the PR description: what the bugs were, why they were important, what the high-level idea of the fix is, etc.;
  • doing smaller more scoped PRs going forward.

If you can, ask one other person to review as well.


for ref in self.refmap:
  if isinstance(ref, Variable):
    ref.raw_value = None
Member

Does it make sense to add a clear() method to Variable doing this?

Collaborator Author

I've removed this code, found a better way.

flax/experimental/nnx/nnx/graph.py (outdated, resolved)
@@ -173,7 +160,7 @@ def fork(
   state: State,
   split_filter: filterlib.Filter,
   split_pattern: SplitPattern,
-) -> tuple[State, State]:
+) -> tuple[State, State, State, State]:
Member

I would recommend making this a typing.NamedTuple or a dataclass so that the caller doesn't have to remember what each State component corresponds to.
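
A minimal sketch of what this suggestion could look like, using typing.NamedTuple. The class name ForkedStates, all four field names, and the fork_sketch function are hypothetical (the excerpt does not say what each State component holds), and a plain dict stands in for State.

```python
from typing import Any, NamedTuple

State = dict[str, Any]  # stand-in for nnx.State (assumption for this sketch)


class ForkedStates(NamedTuple):
  # All four field names are hypothetical; the PR excerpt does not name them.
  split_keys: State
  split_counts: State
  broadcast_keys: State
  broadcast_counts: State


def fork_sketch(keys: State, counts: State, split_names: set[str]) -> ForkedStates:
  """Toy partition of keys and counts into split and broadcast groups."""
  return ForkedStates(
    split_keys={k: v for k, v in keys.items() if k in split_names},
    split_counts={k: v for k, v in counts.items() if k in split_names},
    broadcast_keys={k: v for k, v in keys.items() if k not in split_names},
    broadcast_counts={k: v for k, v in counts.items() if k not in split_names},
  )


forked = fork_sketch(
  {"params": 0, "dropout": 1}, {"params": 3, "dropout": 7}, {"dropout"}
)
# Fields are accessed by name, but positional unpacking still works.
split_keys, split_counts, broadcast_keys, broadcast_counts = forked
```

A NamedTuple stays compatible with callers that unpack the result positionally, while new callers can read forked.split_keys instead of remembering an index.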

Collaborator Author

done

flax/experimental/nnx/nnx/transforms.py (outdated, resolved)

argnums = options.argnums[0] if len(options.argnums) == 1 else options.argnums
# rebuild diff_state from substates in args
diff_state = State({})
Member

Would this be the same as

diff_state = State({i: _args[i] for i in diff_args})

?

Collaborator Author

maybe like this:

diff_state = State({i: _args[i].raw_mapping for i in diff_args})
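
The trade-off in this thread can be sketched with plain dicts standing in for State. This assumes the original code filled the empty container in a loop, which the excerpt does not show; the argument values and positions below are made up for illustration.

```python
# Plain dicts stand in for nnx.State; all values here are made up.
_args = ("x", {"w": 1.0}, "y", {"b": 2.0})
diff_args = (1, 3)  # hypothetical positions of the differentiable arguments

# Incremental style: create an empty container, then fill it.
diff_state_loop = {}
for i in diff_args:
  diff_state_loop[i] = _args[i]

# Comprehension style, as suggested in the review.
diff_state_comp = {i: _args[i] for i in diff_args}
```

Under that assumption the two forms build the same mapping; the comprehension just states it in one expression.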

reduce_axes=reduce_axes,
)(*_args, f, ctx, graphdef, non_diff_state, has_aux, diff_args)

updates: State
Member

Unused?

Collaborator Author

It is used to force the type annotation on the updates definitions below.

Member

I don't see any updates definitions below. Which lines are you referring to?
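
For context on the pattern under discussion, here is a small standalone sketch of a bare variable annotation in Python. The function names and values are hypothetical and unrelated to the PR's code. A bare annotation declares a type for later assignments but binds no value.

```python
def choose_updates(has_aux: bool) -> dict[str, int]:
  # Bare annotation: tells a type checker the variable's type, binds nothing.
  updates: dict[str, int]
  if has_aux:
    updates = {"aux": 1}
  else:
    updates = {}
  return updates


def annotation_only() -> bool:
  value: int  # annotated but never assigned
  try:
    value  # reading it fails because no value was ever bound
  except UnboundLocalError:
    return True
  return False
```

If no assignment to the annotated name follows, the annotation has nothing to apply to, which appears to be the reviewer's concern here.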

@@ -104,7 +105,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
     self.din = din
     self.dout = dout

-  @nnx.jit
+  # @nnx.jit
Member

Leftover debugging code?

Collaborator Author

yes, thanks! removed

copybara-service bot merged commit 7542b28 into main on May 22, 2024
21 checks passed
copybara-service bot deleted the nnx-cond branch on May 22, 2024 09:26
chiamp pushed a commit to chiamp/flax that referenced this pull request on May 22, 2024
cgarciae mentioned this pull request on May 29, 2024