
Fix use_cache for xla fsdp #30353

Merged · 2 commits merged into huggingface:main on Apr 23, 2024

Conversation

@alanwaketan (Contributor) commented on Apr 19, 2024:

What does this PR do?

use_cache cannot be used together with gradient checkpointing. In PyTorch/XLA, we have to rely on our own gradient checkpointing function instead of the upstream one. At some point transformers regressed and no longer recognizes our gradient checkpointing. This PR fixes that.

Fixes #30155
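
As a rough illustration of the problem (not the actual diff in this PR): the trainer only disables use_cache when it knows gradient checkpointing is active, and the PyTorch/XLA path uses its own checkpoint wrapper rather than the upstream flag, so the check has to account for that case too. The helper name below is hypothetical and used only for illustration.

```python
# Hedged sketch, assuming a trainer-side hook that prepares the model for XLA FSDP.
# `prepare_model_for_xla_fsdp` is a hypothetical name; the real change lives in the
# transformers Trainer.

def prepare_model_for_xla_fsdp(model, xla_grad_ckpt_enabled: bool):
    if xla_grad_ckpt_enabled and getattr(model.config, "use_cache", False):
        # use_cache is incompatible with any form of gradient checkpointing,
        # including PyTorch/XLA's own checkpoint function, which the upstream
        # `gradient_checkpointing` flag does not track.
        model.config.use_cache = False
    return model
```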

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr @amyeroberts


@alanwaketan (Contributor, Author) commented:

I don't know how to fix the issue with the use_cache parameter in the modeling code. Is that widely used?

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator) left a comment:

Hi @alanwaketan, thanks for handling this!

I think this is OK - it's consistent with the logic in the transformers modeling code. @muellerzr, can you confirm it's fine for the trainer?

I don't know how to fix the issue with the use_cache parameter in the modeling code. Is that widely used?

@alanwaketan Could you specify which modeling code you're referring to?

Re the failing tests - there was a fix pushed on main; rebasing should resolve them.

@alanwaketan (Contributor, Author) commented on Apr 22, 2024:

Thanks, @amyeroberts. In most of the modeling code, I see that use_cache is passed as a parameter:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L949

And then there is this check in the modeling code:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L967-L971

Since we currently still cannot use the upstream gradient checkpointing directly, we cannot rely on the gradient_checkpointing flag and reuse this logic. And I also don't want to add a new flag and modify all the modeling code. lol
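
For context, the check being referenced follows roughly this pattern (paraphrased from the linked modeling_llama.py lines; the standalone function `resolve_use_cache` is only an illustrative framing, not the actual method):

```python
import logging

logger = logging.getLogger(__name__)


def resolve_use_cache(gradient_checkpointing: bool, training: bool, use_cache: bool) -> bool:
    """Paraphrase of the guard in the modeling code: caching is turned off
    whenever gradient checkpointing is active during training."""
    if gradient_checkpointing and training and use_cache:
        logger.warning(
            "`use_cache=True` is incompatible with gradient checkpointing. "
            "Setting `use_cache=False`."
        )
        return False
    return use_cache
```

Since the XLA path wraps layers with its own checkpoint function rather than setting the model's gradient_checkpointing flag, this guard never fires there, which is why the fix has to live in the trainer instead.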

@muellerzr (Contributor) left a comment:

Agreed, this seems fine to me. Thanks!

@alanwaketan (Contributor, Author) commented:

Thanks, @muellerzr. Can someone help me merge this?

@amyeroberts (Collaborator) commented:

I can merge. Thanks again for fixing this!

@amyeroberts merged commit 12c39e5 into huggingface:main on Apr 23, 2024
21 checks passed
@alanwaketan (Contributor, Author) commented:

Thanks, @amyeroberts


Successfully merging this pull request may close these issues:

  • Transformers 4.39.x gemma 2b/7b lora tuning on TPU example error

4 participants