Make Gemma work with torch.compile #30775
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
@@ -104,15 +104,16 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         self.register_buffer("inv_freq", None, persistent=False)
```
I hope there is no specific (important enough) reason to do so.
There are 😅 precision issues. There is no flag to keep buffers in float32 precision, and I think it makes a difference in terms of how inv_freq is computed. That is what needs to be checked.
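To make the precision point concrete, here is a small standalone sketch (my own illustration, not code from this PR; dim and base are arbitrary) showing that round-tripping inv_freq through half precision changes its values:

```python
import torch

# Arbitrary illustrative values; real models take these from the config.
dim, base = 256, 10000

# inv_freq as computed in float32
inv_freq_fp32 = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

# What happens if the buffer is later cast to half precision (e.g. by model.half())
inv_freq_fp16 = inv_freq_fp32.to(torch.float16).to(torch.float32)

# The round trip is lossy, and the error propagates into the cos/sin tables.
max_rel_err = ((inv_freq_fp32 - inv_freq_fp16).abs() / inv_freq_fp32).max()
print(max_rel_err)  # small but non-zero relative error
```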
Do you mean we want to keep it always in fp32, but if it is a buffer that we give values at init, it could be changed to another dtype at a later stage of from_pretrained?
If this is the case, I would suggest we first reach an agreement that the current implementation of self.inv_freq doesn't actually work with torch.compile.
It would be nice if you could try the provided code snippet, and if you want some evidence from the test test_torch_compile_fullgraph, I can work on that too.
Then we can discuss what would be a better fix.
This seems strange, as we do test torch.compile with Gemma, no? Is this a torch version issue?
AFAIK Gemma supported compile from day 0, and I could use a static cache prior to this PR.
> There are 😅 precision issues. There is no flag to keep buffers in float32 precision, and I think it makes a difference in terms of how inv_freq is computed. That is what needs to be checked.
I see
As mentioned earlier on Slack (1 or 2 weeks ago), this would likely pass without actually compiling the stuff.
but then we run forward with it
I would like to try your code and know your environment. In any case, I am sure there is something wrong on the current main branch, and a reproducible code snippet is provided in the description.
It doesn't capture the issue. My original message: we really need to compile the forward pass and actually run it. This means we could still run (even if we wrap the model with torch.compile) without compilation ever being triggered.
Also, the test only runs on CPU, so we probably miss the CUDA cases, as the errors I provided are CUDA-related.
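For context on how a compile test can pass without exercising anything, a minimal sketch (my own toy example, not the test under discussion): torch.compile is lazy, so nothing is traced or compiled until the wrapped callable is actually run.

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

compiled = torch.compile(Toy(), fullgraph=True)
# Nothing has been compiled yet; a test that stops here passes trivially.
out = compiled(torch.randn(4))  # tracing/compilation only happens on this first call
```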
Taking the code snippet in the PR description (on main), I get the error shown there.
Regarding the test: a dummy model is used in test_torch_compile_fullgraph.
BTW, see transformers/src/transformers/models/llama/modeling_llama.py, lines 103 to 104 at 5ad960f.
Let's add a test similar to test_compile_static_cache from Llama, and now we should use # Copied from llama for the rotary embedding, no?
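A rough sketch of what such a test could look like (not the actual test added in this PR; the checkpoint, prompt, and generation settings are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_compile_static_cache_sketch():
    model_id = "google/gemma-2b"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
    inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to("cuda")

    # Reference generation with the default (dynamic) cache.
    ref = model.generate(**inputs, max_new_tokens=20, do_sample=False)

    # Static cache + compiled forward, the combination this PR targets.
    model.generation_config.cache_implementation = "static"
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    out = model.generate(**inputs, max_new_tokens=20, do_sample=False)

    assert tokenizer.batch_decode(ref) == tokenizer.batch_decode(out)
```

If I remember correctly, the actual Llama test compares against hard-coded expected completions rather than a second generate call, but the overall shape is similar.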
(force-pushed from 4ff5b70 to 47c740c)
Done
(force-pushed from 3924ace to a63c3ba)
Adding a test in this commit; see the comments I left there for some explanations.
Thanks!
In order to make absolutely sure compile works, we need a test on the generations for Gemma.
I will add.
All addressed, just the copies to be updated.
```diff
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
```
As long as the dtype is float32 here, this works!
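A tiny check of that point (illustrative dim/base values, my own sketch):

```python
import torch

dim, base = 256, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq.dtype)  # torch.float32: the explicit .float() keeps the whole expression in fp32

# Note: nn.Module.half() / .to(dtype) would still cast a floating-point buffer,
# which is the scenario the precision discussion above is concerned with.
```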
OK, I will check again what you shared in the Slack DM and apply it. Thanks for the review!
Confirmed it's float32 even if I set torch_dtype=torch.float16 in from_pretrained.
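A sketch of that check (assuming a Gemma checkpoint and the usual attribute path to the rotary embedding; adjust the path if the module layout differs):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.float16)
# The buffer is non-persistent and not loaded from the checkpoint, so it keeps
# the float32 dtype it was created with in __init__.
print(model.model.layers[0].self_attn.rotary_emb.inv_freq.dtype)  # torch.float32
```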
Well, after applying fix-copies.
Merge as the slow CI failures are on main too.
What does this PR do?
Currently on main, Gemma can't work with torch.compile (with static cache, of course). This PR fixes it.
If the change is approved, I will apply it to a few more models to pass the copy check.
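For quick reference, the core of the change is to build inv_freq eagerly in __init__ instead of registering a None buffer that gets filled in later; a condensed, self-contained sketch of the idea (simplified, not the full class from the PR):

```python
import torch
import torch.nn as nn

class GemmaRotaryEmbeddingSketch(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Computed eagerly in float32 instead of register_buffer("inv_freq", None, ...),
        # so the buffer is never a None that has to be materialized inside forward.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

rope = GemmaRotaryEmbeddingSketch(dim=256)
print(rope.inv_freq.shape, rope.inv_freq.dtype)  # torch.Size([128]) torch.float32
```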
Short error log (on A10, torch 2.3 + cu112)
To reproduce and full error log:
code snippet
Full error log