Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FlatState a Mapping instead of a dict #3928

Merged
merged 1 commit into from
Jun 4, 2024

Conversation

NeilGirdhar
Copy link
Contributor

@NeilGirdhar NeilGirdhar commented May 19, 2024

  • Add nnx.traversals.{flatten_mapping, unflatten_dict}

    • These are modified from traverse_util.{flatten_dict, unflatten_dict}
    • They are annotated so that any future changes to the function (e.g., changing it back to work with dicts only) triggers a type error.
  • Minor tweaks to imports:

    • import from collections.abc instead of typing since the latter imports are deprecated.
    • Import from flax.typing instead of flax.core.scope since the latter has been moved.
    • Add from __future__ import annotations so that the annotations work on Python 3.9.
  • Minor tweaks to code:

    • Annotate some private functions to make them easier to understand.
    • When printing a type, print its qualname since that's a bit easier to read (str instead of <class 'str'>).

Fixes #3879

@NeilGirdhar
Copy link
Contributor Author

@cgarciae (Done, as you requested.)

@NeilGirdhar NeilGirdhar force-pushed the use_mapping branch 3 times, most recently from fa0436e to 6776c47 Compare May 20, 2024 06:33
@codecov-commenter
Copy link

codecov-commenter commented May 20, 2024

Codecov Report

Attention: Patch coverage is 0% with 102 lines in your changes are missing coverage. Please review.

Project coverage is 0.00%. Comparing base (2c7d7cd) to head (ecb7cf3).
Report is 62 commits behind head on main.

Files Patch % Lines
flax/nnx/nnx/traversals.py 0.00% 58 Missing ⚠️
flax/nnx/tests/test_traversals.py 0.00% 38 Missing ⚠️
flax/nnx/nnx/state.py 0.00% 5 Missing ⚠️
flax/nnx/__init__.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main   #3928       +/-   ##
==========================================
- Coverage   60.43%   0.00%   -60.44%     
==========================================
  Files         105     105               
  Lines       13263   13328       +65     
==========================================
- Hits         8015       0     -8015     
- Misses       5248   13328     +8080     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

flax/experimental/nnx/__init__.py Outdated Show resolved Hide resolved
@cgarciae
Copy link
Collaborator

@NeilGirdhar looks great! Thanks for doing this.
Approved but left a comment.

@cgarciae
Copy link
Collaborator

Seems there are some git conflicts.

@NeilGirdhar
Copy link
Contributor Author

Rebased to main.

* Add nnx.traversals.{flatten_mapping, unflatten_dict}
    * These are modified from traverse_util.{flatten_dict,
      unflatten_dict}
    * They are annotated so that any future changes to the function
      (e.g., changing it back to work with dicts only) triggers a type
      error.

* Minor tweaks to imports:
    * import from collections.abc instead of typing since the latter
      imports are deprecated.
    * Import from flax.typing instead of flax.core.scope since the
      latter has been moved.
    * Add from __future__ import annotations so that the annotations
      work on Python 3.9.

* Minor tweaks to code:
    * Annotate some private functions to make them easier to understand.
    * When printing a type, print its __qualname__ since that's a bit
      easier to read (str instead of <class 'str'>).

Fixes google#3879
@copybara-service copybara-service bot merged commit bff079b into google:main Jun 4, 2024
19 checks passed
@NeilGirdhar NeilGirdhar deleted the use_mapping branch June 4, 2024 22:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

No way to call nnx.State.from_flat_path
3 participants