
Framework-agnostic split_state_dict_into_shards helper #1938

Merged (19 commits) on Feb 22, 2024

Conversation

@Wauplin (Contributor) commented Dec 22, 2023

(feedback welcome 🙏)

PR based on an idea from @LysandreJik. The goal is to have a framework-agnostic helper to split a state dict into shards that can be reused in transformers, peft, accelerate, etc. The scope of this method is yet to be defined. At the moment, it takes a state_dict, a "max_size_per_shard", and a filename as input, and returns an index plus a "filename => tensors" mapping.

>>> import json
>>> import os
>>> from typing import Dict

>>> import torch
>>> from safetensors.torch import save_file as safe_save_file
>>> from huggingface_hub import split_torch_state_dict_into_shards

>>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
...     state_dict_split = split_torch_state_dict_into_shards(state_dict)
...     for filename, tensors in state_dict_split.filename_to_tensors.items():
...         shard = {name: state_dict[name] for name in tensors}
...         safe_save_file(
...             shard,
...             os.path.join(save_directory, filename),
...             metadata={"format": "pt"},
...         )
...     if state_dict_split.is_sharded:
...         index = {
...             "metadata": state_dict_split.metadata,
...             "weight_map": state_dict_split.tensor_to_filename,
...         }
...         with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
...             f.write(json.dumps(index, indent=2))

Currently in PR:

  • take state_dict + threshold as input
  • take filename as an input pattern (e.g. "model{suffix}.safetensors", "tf_model{suffix}.h5", "pytorch_model{suffix}.bin")
  • group tensors into shards
  • respect storage id if tensors have to be saved together (still to be done)
  • build index with metadata (total size) + weight_map
  • return shards (a list of state_dicts) + index (a JSON-able dict)
  • support for torch, tensorflow, numpy
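The "build index" step above can be sketched in isolation. This is a hypothetical `build_index` helper (not the PR's actual code): it takes per-shard tensor sizes plus the filename pattern, and follows the transformers convention of dropping the numeric suffix when there is a single shard:

```python
from typing import Dict, List


def build_index(shards: List[Dict[str, int]], filename_pattern: str) -> dict:
    """Build a hypothetical `*.index.json` payload.

    `shards` is a list of dicts mapping tensor name -> size in bytes,
    one dict per shard. Illustrative sketch only.
    """
    weight_map: Dict[str, str] = {}
    total_size = 0
    n = len(shards)
    for i, shard in enumerate(shards, start=1):
        # Single shard: no "-00001-of-00001" suffix, per the transformers convention.
        suffix = "" if n == 1 else f"-{i:05d}-of-{n:05d}"
        filename = filename_pattern.format(suffix=suffix)
        for name, size in shard.items():
            weight_map[name] = filename
            total_size += size
    return {"metadata": {"total_size": total_size}, "weight_map": weight_map}


index = build_index(
    [{"embed.weight": 400}, {"lm_head.weight": 200}],
    "model{suffix}.safetensors",
)
print(index["weight_map"]["lm_head.weight"])  # model-00002-of-00002.safetensors
```

The resulting dict matches the shape written to `model.safetensors.index.json` in the example above.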

Currently not in PR:

  • add framework to index (e.g. "pt")
  • provide filename for the index => how?
  • save tensors to files => do we want to provide a helper for that? (Especially to save the index in the correct file + weights in a consistent way.)
  • deserialize/load sharded model (will most probably never be done in huggingface_hub)

The current implementation is inspired by the torch implementation (see here). It supports torch, tensorflow, and numpy. This PR is still in draft, so nothing is set in stone. In particular, depending on the scope we want, inputs and outputs can be adapted to be as user-friendly as possible (while staying flexible).
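One subtlety the torch-inspired implementation has to handle (the "respect storage id" item above) is that tensors sharing underlying memory, such as tied weights, must land in the same shard. A framework-neutral sketch of detecting such groups, with `get_storage_id` left as a caller-supplied hook (hypothetical names, not the PR's code):

```python
from typing import Any, Callable, Dict, List


def group_shared_storage(
    tensors: Dict[str, Any], get_storage_id: Callable[[Any], Any]
) -> List[List[str]]:
    """Group tensor names by storage id; groups with more than one name
    share memory and must be kept in the same shard."""
    groups: Dict[Any, List[str]] = {}
    for name, tensor in tensors.items():
        groups.setdefault(get_storage_id(tensor), []).append(name)
    return [names for names in groups.values() if len(names) > 1]


# Toy example using plain lists and `id` as a stand-in storage id;
# a torch implementation would use the tensor's storage pointer instead.
weights = [1.0, 2.0]
state_dict = {"embed.weight": weights, "lm_head.weight": weights, "bias": [0.0]}
print(group_shared_storage(state_dict, id))  # [['embed.weight', 'lm_head.weight']]
```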

Ping @amyeroberts @ArthurZucker @muellerzr on this (and please ping others if relevant).

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr (Contributor) left a comment

This will be great to have inside huggingface_hub to make it available for everyone rather than kept separate. I think for the most part the implementations look similar to what we have in accelerate here.

cc @SunMarc: could you check for any differences or things we may struggle with?

On our end, once this gets merged and we're sure it's 1-to-1 with what we're doing, we'll deprecate the util in Accelerate and rely on the Hub's version (especially relevant since huggingface_hub is now a required dependency of Accelerate).

@Wauplin (Contributor, Author) commented Jan 4, 2024

Linking internal slack thread discussing it (cc @pcuenca).

@LysandreJik (Member):

Pretty excited by this PR! The current skeleton looks good to me. I'm wondering if it wouldn't make sense to make the addition of TensorT custom classes simpler by defining it as a class with the necessary overrideable methods, namely:

  • getting the storage IDs
  • getting the tensor sizes
  • anything else that might be worthwhile

and then we'd start with the definition of torch/tf/numpy methods but adding extras will therefore be super simple as long as those three methods are implemented.
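Such a class-based design might look like the following sketch (hypothetical names; `TensorHandler` and the toy `BytesHandler` are not part of the PR):

```python
from abc import ABC, abstractmethod
from typing import Any, Tuple


class TensorHandler(ABC):
    """Hypothetical base class: one subclass per framework, overriding
    only the hooks the sharding algorithm needs."""

    @abstractmethod
    def get_storage_id(self, tensor: Any) -> Tuple:
        """Id shared by tensors backed by the same memory."""

    @abstractmethod
    def get_tensor_size(self, tensor: Any) -> int:
        """Tensor size in bytes."""


class BytesHandler(TensorHandler):
    """Toy concrete handler treating `bytes` objects as tensors,
    just to show how little a subclass has to implement."""

    def get_storage_id(self, tensor: bytes) -> Tuple:
        return (id(tensor),)

    def get_tensor_size(self, tensor: bytes) -> int:
        return len(tensor)


print(BytesHandler().get_tensor_size(b"abcd"))  # 4
```

A torch or tensorflow subclass would override the same two methods with framework-specific logic (storage pointers, dtype-aware size computation).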

@Wauplin (Contributor, Author) commented Jan 24, 2024

Thanks for the review and the idea @LysandreJik! Will have a look at how I could make the framework-specific stuff simpler 👍

@Wauplin changed the title from "[RfC] Draft for a framework-agnostic split_state_dict_into_shards helper" to "Framework-agnostic split_state_dict_into_shards helper" on Feb 14, 2024
@Wauplin (Contributor, Author) commented Feb 14, 2024

@LysandreJik I switched to a more functional programming design (which should be easier to test and maintain, as you suggested). I realized that we don't need to test whether a tensor is from tensorflow, numpy, or torch each time we handle a new one. Instead, I'm defining one method per framework, and it's at the user's discretion to use the correct one. I don't see a situation where a user doesn't know which type of tensor they are using. WDYT of the current design?
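The functional design described here can be sketched as a shared core parameterized by a size callback, plus one thin entry point per framework (hypothetical function names, not the PR's actual API):

```python
from typing import Any, Callable, Dict, List


def _split_state_dict(
    state_dict: Dict[str, Any],
    get_size: Callable[[Any], int],
    max_shard_size: int,
) -> List[Dict[str, Any]]:
    """Shared core: greedily fill shards up to max_shard_size bytes."""
    shards: List[Dict[str, Any]] = []
    current: Dict[str, Any] = {}
    current_size = 0
    for name, tensor in state_dict.items():
        size = get_size(tensor)
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = tensor
        current_size += size
    if current:
        shards.append(current)
    return shards


def split_numpy_state_dict(state_dict, max_shard_size):
    # numpy entry point: size is just `array.nbytes`
    return _split_state_dict(state_dict, lambda a: a.nbytes, max_shard_size)


def split_bytes_state_dict(state_dict, max_shard_size):
    # toy "framework" for illustration: tensors are raw bytes
    return _split_state_dict(state_dict, len, max_shard_size)
```

A real torch wrapper would additionally compute sizes from dtype and shape, and group tensors sharing storage before splitting.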

If that's fine, I'll clean this up, add some tests and document it a bit.

codecov bot commented Feb 14, 2024

Codecov Report

Attention: 48 lines in your changes are missing coverage. Please review.

Comparison is base (d01206d) 82.22% compared to head (586b8d8) 80.29%.
Report is 2 commits behind head on main.

❗ Current head 586b8d8 differs from pull request most recent head f9c5057. Consider uploading reports for the commit f9c5057 to get more accurate results.

Files Patch % Lines
src/huggingface_hub/serialization/_torch.py 30.76% 36 Missing ⚠️
src/huggingface_hub/serialization/_tensorflow.py 80.00% 4 Missing ⚠️
src/huggingface_hub/utils/_runtime.py 73.33% 4 Missing ⚠️
src/huggingface_hub/serialization/_base.py 96.96% 2 Missing ⚠️
src/huggingface_hub/serialization/_numpy.py 77.77% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1938      +/-   ##
==========================================
- Coverage   82.22%   80.29%   -1.93%     
==========================================
  Files          66       71       +5     
  Lines        8309     8461     +152     
==========================================
- Hits         6832     6794      -38     
- Misses       1477     1667     +190     


@LysandreJik (Member):

Yes, the current design looks great to me!

@Wauplin marked this pull request as ready for review on February 16, 2024 16:08
@Wauplin (Contributor, Author) commented Feb 16, 2024

Thanks @LysandreJik, PR is now ready to be reviewed. I have added an example on how to use it with torch. We could add a "save" method for each framework that saves the state dict to files but that will be done in a follow-up PR.

@pcuenca (Member) left a comment

Looks very clean to me :)

(Resolved review threads on src/huggingface_hub/serialization/_base.py, _numpy.py, and _tensorflow.py.)
@Wauplin (Contributor, Author) commented Feb 19, 2024

Thanks for the thorough review @pcuenca! I have addressed or replied to all of your comments :)

@LysandreJik (Member):

cc @mfuntowicz as well as we discussed it a while ago

@LysandreJik (Member) left a comment

Looks good to me! It would be awesome to have docs for this somewhere.

Left a few suggestions and comments.

(Resolved review threads on src/huggingface_hub/serialization/_base.py, _torch.py, and tests/test_serialization.py.)
@Wauplin (Contributor, Author) commented Feb 22, 2024

Thanks for the thorough review! Made the suggested changes and now waiting for the CI to complete before merging this stuff :)

Added them to the reference package under a "serialization" page that is meant to grow when adding the "save tensors" part. Let's start with that and reassess :) https://moon-ci-docs.huggingface.co/docs/huggingface_hub/pr_1938/en/package_reference/serialization

@Wauplin merged commit ae3c4a0 into main on Feb 22, 2024
16 checks passed
@Wauplin deleted the add-helper-to-shard-model branch on February 22, 2024 15:24
@julien-c (Member) left a comment

I'm late to the party, but is the filename convention the same as the one we've been using in the existing sharding? (I don't remember whether the existing sharding is implemented in transformers or in safetensors.)

@Wauplin (Contributor, Author) commented Feb 23, 2024

It ends up with "model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors", etc., which is the convention defined in transformers (not in safetensors itself).
See https://huggingface.co/HuggingFaceH4/zephyr-7b-beta/tree/main for example.
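The zero-padded naming can be reproduced with a simple format string (illustrative only):

```python
# Zero-padded shard names, matching the transformers convention shown above
n_shards = 2
filenames = [
    f"model-{i:05d}-of-{n_shards:05d}.safetensors" for i in range(1, n_shards + 1)
]
print(filenames)
# ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
```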
