Fix use_cache for xla fsdp #30353
Conversation
I don't know how to fix the issue with the use_cache parameter in the modeling code. Is that widely used?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @alanwaketan, thanks for handling this!
I think this is OK - it's consistent with logic in transformers modeling code, @muellerzr to confirm it's fine for trainer.
I don't know how to fix the issue with the use_cache parameter in the modeling code. Is that widely used?
@alanwaketan Could you specify which modeling code you're referring to?
Re the failing tests - there was a fix pushed to main. Rebasing should resolve them.
Thanks, @amyeroberts. In most of the modeling code, I saw that use_cache is passed as a parameter, and then there is a check in the modeling code that disables it when gradient checkpointing is enabled. Since we currently still cannot use the upstream gradient checkpointing directly, we cannot rely on the gradient_checkpointing flag and reuse this logic. And I also don't want to add a new flag and modify all the modeling code. lol
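For context, here is a paraphrased, standalone sketch of the check being referred to (the helper name `resolve_use_cache` is hypothetical; the actual modeling files inline this logic in `forward()`):

```python
import logging

logger = logging.getLogger(__name__)


def resolve_use_cache(use_cache: bool, gradient_checkpointing: bool, training: bool) -> bool:
    """Paraphrase of the check many decoder models perform in forward():
    the KV cache is disabled whenever gradient checkpointing is active,
    because checkpointed layers are re-run in the backward pass and
    cannot return past key/value states.
    """
    if gradient_checkpointing and training and use_cache:
        logger.warning(
            "`use_cache=True` is incompatible with gradient checkpointing. "
            "Setting `use_cache=False`."
        )
        return False
    return use_cache
```

Because the XLA FSDP path wraps layers with its own checkpointing function, the model's `gradient_checkpointing` attribute stays `False` and this check never fires, which is why the cache has to be disabled elsewhere.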
Agreed, looks fine to me. Thanks!
Thanks, @muellerzr. Can someone help me merge this?
I can merge. Thanks again for fixing this!
Thanks, @amyeroberts
What does this PR do?
use_cache cannot be used with gradient checkpointing. In PyTorch/XLA, we have to rely on our own gradient checkpointing function instead of the upstream one. Somehow, transformers regressed and could no longer recognize our gradient checkpointing. This PR fixes it.
Fixes #30155
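As a rough illustration of the kind of change involved (a hypothetical sketch, not the actual diff; `maybe_disable_cache` and its arguments are illustrative, and `xla_fsdp_grad_ckpt` is the fsdp_config key used to request XLA gradient checkpointing):

```python
def maybe_disable_cache(model_config, fsdp_config: dict, is_fsdp_xla_enabled: bool) -> None:
    """Hypothetical helper: when gradient checkpointing is applied through
    PyTorch/XLA's own FSDP wrapper, the modeling code's own
    `gradient_checkpointing` check never fires, so the trainer has to turn
    the KV cache off itself before training starts.
    """
    if is_fsdp_xla_enabled and fsdp_config.get("xla_fsdp_grad_ckpt", False):
        model_config.use_cache = False
```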
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@muellerzr @amyeroberts