Make EosTokenCriteria compatible with mps #30376
Conversation
Thanks for fixing! Maybe we can keep only the old method and get rid of the extra if conditions. Both methods are compile-compatible anyway :)
Yes, good point, I thought about that too. I opted to keep both because the new method is much more readable and intuitive, so keeping it doesn't really add much noise in my opinion. It also documents why we are using the long approach, and provides a hint (via the link to the GitHub issue) to remove that branch when we have full support of `torch.isin` on `mps`.
Thank you for the fix 🙏
I'm pro the if/else: it serves as a reminder to revisit the code in the future, once `torch.isin` is supported on `mps`.
I'm not sure we should be doing this - the lack of support is really a torch issue, not a transformers one. There are many operations which are (unfortunately) unsupported by the mps backend, and I wouldn't want to see a tonne of if/else statements across the library to support this. At the moment, …
I don't really mind this, as it should be mostly for …
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed internally on Slack - https://huggingface.slack.com/archives/C06SW4886G6/p1713567939087889 - happy to merge as it's currently impacting Llama 3, but we should create an issue tracking this to make sure it's removed in the future.
Thanks so much for the quick fix. Could this get a release?
Yes, will do this asap, just waiting to make sure we don't have other failures.
What does this PR do?
EOS termination was moved to a stopping criterion in #29459. It uses `torch.isin()`, which is not compatible with the `mps` device: pytorch/pytorch#77764 (comment). This PR uses the old method to detect whether any of the EOS tokens were found in the generated sequences.

Fixes #29459 (comment)
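To illustrate the workaround described above, here is a minimal sketch of the device-dependent branch. The function and variable names are hypothetical (not the PR's actual code): on `mps`, where `torch.isin` is unsupported (pytorch/pytorch#77764), it falls back to an explicit broadcasted comparison, which is what the "old method" amounts to.

```python
import torch


def eos_found(last_tokens: torch.Tensor, eos_token_ids: torch.Tensor) -> torch.Tensor:
    """Return a bool tensor marking which sequences ended in an EOS token.

    `last_tokens`: shape (batch_size,), the last generated token per sequence.
    `eos_token_ids`: shape (num_eos,), all tokens treated as EOS.
    """
    eos_token_ids = eos_token_ids.to(last_tokens.device)
    if last_tokens.device.type == "mps":
        # torch.isin is not implemented for the mps backend, so compare
        # each token against every EOS id and reduce over the EOS axis.
        # Shapes: (batch, 1) == (num_eos,) -> (batch, num_eos) -> (batch,)
        return (last_tokens.unsqueeze(-1) == eos_token_ids).any(dim=-1)
    # On other devices the readable one-liner works directly.
    return torch.isin(last_tokens, eos_token_ids)
```

Both branches are equivalent; the if/else exists only to route around the missing `mps` kernel and can be removed once upstream support lands.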