Llama: fix custom 4D masks, v2 #30348
Conversation
38fcb65 to 69ce14b (Compare)
As a solution, I added additional
Please review soon - I need this for my paper code. It's been broken for quite a long time now. The LONG tests look OK.
if attention_mask is not None and attention_mask.dim() == 4:
    # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape
    # as the causal mask (i.e. [..., seq_len, full_len])
    mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
    offset = cache_position[0]
    if attention_mask.shape[-2] == offset + sequence_length:
        mask_slice = mask_slice[..., offset:, :]
    causal_mask = mask_slice
else:
    if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
        target_length = self.config.max_position_embeddings
    else:  # dynamic cache
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, torch.Tensor)
            else past_seen_tokens + sequence_length + 1
        )
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
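For context on the two shapes the inline comment above refers to, here is a minimal sketch (shapes and values are illustrative, not taken from the PR): a full mask covers cached plus new tokens in both of its last dimensions, while a partial mask only has rows for the new tokens.

import torch

batch_size, num_heads = 1, 1
past_len, seq_len = 3, 2              # 3 cached tokens, 2 new tokens
full_len = past_len + seq_len

# full 4D mask: one row per position, cached and new -> [..., full_len, full_len]
full_mask = torch.ones(batch_size, num_heads, full_len, full_len)

# partial 4D mask: rows only for the new tokens -> [..., seq_len, full_len]
partial_mask = torch.ones(batch_size, num_heads, seq_len, full_len)

# the branch above slices the full variant down to its last seq_len rows,
# so both variants end up with shape [batch, heads, seq_len, full_len]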
this is very messy.
IMO if a 4D mask is passed, then we should not even touch it.
Previously we used to invert it; now that we no longer do that in the process, we should not add extra logic and should just return the 4D mask. Fine with me to add a check to make sure the min is -inf, the max is 0, and the shape is as expected, but let's not increase the surface area and complicate things.
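A rough sketch of the pass-through-with-sanity-check idea (the helper name and signature are hypothetical, not part of this PR):

def _passthrough_4d_mask(attention_mask, min_dtype, expected_shape):
    # hypothetical helper: if the caller already built a 4D additive mask, return it untouched
    # after a cheap sanity check on its value range and shape
    if attention_mask is not None and attention_mask.dim() == 4:
        if tuple(attention_mask.shape) != tuple(expected_shape):
            raise ValueError(f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}")
        if attention_mask.max() > 0 or attention_mask.min() < min_dtype:
            raise ValueError("4D attention_mask values are expected to lie in [min_dtype, 0]")
        return attention_mask
    return None  # fall back to the regular causal-mask construction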
I simplified the code:
- removed inversion for custom 4D masks,
- removed mask slicing for custom 4D masks,
- removed the test case where we pass a 4D attention mask with the full sequence length (i.e. [..., full_len, full_len]),
- updated the relevant tests.
I am concerned that:
- the part of the code that checks mask shapes sits in _ignore_causal_mask_sdpa()
- the advanced 4D mask tests sit in tests/models/llama/test_modeling_llama.py
- something may break them again outside of the Llama code, e.g. in src/transformers/modeling_attn_mask_utils.py
Could you please ensure that the Mask4DTestHard class is always tested if something changes in modeling_attn_mask_utils.py, even if the Llama code is not otherwise affected?
LGTM, let's just not add extra checks and variables in the forward that are only there for the static cache.
is_static_cache = isinstance(past_key_values, StaticCache) or isinstance(
    getattr(self.layers[0].self_attn, "past_key_value", None), StaticCache
)
Cc @gante: in this release I think we can already stop using any of the logic that relies on the DynamicCache, in favor of relying on the cache positions. WDYT?
@ArthurZucker what do you mean by "logic that does rely on the DynamicCache"? The past_seen_tokens = past_key_values.get_seq_length() line?
Yes, and the max_length as well! I just don't want to add a new variable that is not general and is only used in _update_causal_mask.
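For illustration, a hedged sketch of what "relying on the cache positions" could mean here (my reading, not code from the PR): derive the lengths from cache_position instead of querying the cache object.

# instead of past_seen_tokens = past_key_values.get_seq_length(), assuming cache_position
# holds the absolute positions of the tokens in the current forward pass:
past_seen_tokens = int(cache_position[0])       # tokens already present in the cache
mask_length = int(cache_position[-1]) + 1       # total length the causal mask must cover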
@@ -989,7 +992,10 @@ def forward(
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
        static_cache_max_length = self.layers[0].self_attn.past_key_value.get_max_length() if is_static_cache else None
not super sure the static_cache_max_length should be extracted here, would rather hide everything in _update_causal_mask. That's the only nit
fixed it in 69c84cc - please check
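For reference, a simplified sketch of the direction taken (signature trimmed, details omitted): resolve the static-cache max length inside _update_causal_mask rather than in forward.

from transformers.cache_utils import StaticCache

def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens):
    # look up the static cache locally instead of threading an extra variable through forward()
    past_key_value = getattr(self.layers[0].self_attn, "past_key_value", None)
    static_cache_max_length = (
        past_key_value.get_max_length() if isinstance(past_key_value, StaticCache) else None
    )
    # ... the rest of the causal-mask construction is unchanged ...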
A lot better thanks
LGTM, thank you for working on it 💪
I'm assuming all slow tests are passing for the models changed in this PR :) (if you haven't run them, please do)
all CI tests are green, SLOW tests were OK on my side yesterday
I noticed that Mistral model support for 4D masks stayed broken after these fixes. So I added similar lines to
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I added it. Is there anything left to do before the merge is possible?
Let's remove unrelated changes!
        is_static_cache = isinstance(past_key_values, StaticCache) or isinstance(
            getattr(self.layers[0].self_attn, "past_key_value", None), StaticCache
        )
        if use_cache:  # kept for BC (cache positions)
-           if not isinstance(past_key_values, StaticCache):
+           if not is_static_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()

        if cache_position is None:
-           if isinstance(past_key_values, StaticCache):
+           if is_static_cache:
I feel like these are unrelated and can be addressed in another PR
I know you are trying to check the correct shapes, but it's a lot of code that's being added for that 😢
When StaticCache is used, the past_key_values argument can easily be None, so the isinstance(past_key_values, StaticCache) condition just doesn't work. As I understand it, once the StaticCache is initialized, there is no need to pass it in the past_key_values argument. That's why getattr(self.layers[0].self_attn, "past_key_value", None) is necessary.
I'd love to use fewer lines of code for that, but I can't confidently do so since I don't quite understand the story behind the _ignore_causal_mask_sdpa() addition.
For instance, I could propose if (not is_static_cache) and AttentionMaskConverter._ignore_causal_mask_sdpa(...) - would that work?
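A small illustration of the failure mode described above (the attribute path mirrors the snippets in this thread and is otherwise an assumption):

# once a StaticCache has been set up on the attention layers, callers typically pass
# past_key_values=None on subsequent forward calls
past_key_values = None
layer_cache = getattr(self.layers[0].self_attn, "past_key_value", None)

isinstance(past_key_values, StaticCache)   # False, even though a static cache is in use
isinstance(layer_cache, StaticCache)       # True, which is why the getattr fallback is needed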
This should be fixed in another PR! 😉
        static_cache_max_length = (
            self.layers[0].self_attn.past_key_value.get_max_length()
            if isinstance(getattr(self.layers[0].self_attn, "past_key_value", None), StaticCache)
            else None
        )

        if AttentionMaskConverter._ignore_causal_mask_sdpa(
-           attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+           attention_mask,
+           inputs_embeds=input_tensor,
+           past_key_values_length=past_seen_tokens,
+           static_cache_max_length=static_cache_max_length,
        ):
same here
_ignore_causal_mask_sdpa() returns True unless it recognizes that the mask shape is valid. But to do so, it needs to know the static_cache_max_length.
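To make the dependency concrete, a hedged sketch of the kind of shape check being discussed (a standalone helper for illustration, not the actual _ignore_causal_mask_sdpa implementation):

def _is_expected_4d_mask_shape(attention_mask, key_value_length, static_cache_max_length=None):
    # with a dynamic cache the mask's last dimension should match the key/value length;
    # with a static cache it should match the fixed cache size instead
    if attention_mask is None or attention_mask.dim() != 4:
        return False
    expected_last_dim = (
        static_cache_max_length if static_cache_max_length is not None else key_value_length
    )
    return attention_mask.shape[-1] == expected_last_dim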
Mmmm, again we should just check the attention_mask shape: if it is 4D, just return it.
sorry, but without these changes the fixes and tests will not work. I looked for related PRs; all I found was #30476, but it does not fix the relevant parts of the code.
54a6ec8 to c64188f (Compare)
I tried to follow Arthur's advice to streamline the path for the 4D masks and it seems to work. The relevant tests do pass.
Let's remove the elif, no? Looks good to me on the modeling changes with that removed!
elif attention_mask.dim() == 4:
    # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
    # cache. In that case, the 4D attention mask attends to the newest tokens only.
    if attention_mask.shape[-2] < cache_position[0] + sequence_length:
        logger.warning_once(
            "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
            "transformers v4.42.0"
        )
        offset = cache_position[0]
    else:
        offset = 0
    mask_shape = attention_mask.shape
    mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
    causal_mask[
        : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
    ] = mask_slice
agreed and removed
I combined the two tests from common, which were very similar, and added tolerance - now Mixtral passes OK.
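For what it's worth, the tolerance-based comparison amounts to something like this sketch (tensor names and tolerance values are placeholders, not the test's actual ones):

import torch

# logits computed with the custom 4D mask vs. the reference path should agree up to numerical
# noise; a small tolerance absorbs dtype- and kernel-dependent differences (e.g. for Mixtral)
torch.testing.assert_close(logits_with_4d_mask, logits_reference, atol=1e-4, rtol=1e-4)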
Thanks!
A few small nits, but looks good to me!
Happy that 4d support is kept 🔥
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py (outdated, resolved)
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
this test is the same as the Llama one; let's use "Copied from".
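For reference, the "Copied from" convention mentioned here is a comment placed above the copied class so the consistency scripts keep the two versions in sync; it looks roughly like this (the exact source path below is my assumption):

# Copied from tests.models.llama.test_modeling_llama.Mask4DTestHard with Llama->Mistral
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
    ...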
Comparing the test code for Llama and Mistral I see that:
- a specific model is used in setUp()
- the test examples are text-based (which may not lead to the same token count with different tokenizers)
- the Llama version has extra methods to test StaticCache
Potential solutions:
0) do nothing
1) get away from text-based examples and use token_ids for transferability across models
1a) copy the tests from Llama to Mistral (I need a hint on the auto-copying syntax)
1b) move the Mask4DTestHard class methods to testing_common
Also, are you OK with the choice of Mistral for the test? I added it because it uses a slightly different code path compared to Llama.
yes no worries
self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")

for model_class in self.all_generative_model_classes:
    if not model_class._supports_cache_class:
Brittle, as one does not guarantee the other, but that's alright.
these lines came from @gante - maybe he has comments?
no worries it's alright to keep it!
Just waiting for the commits to be pushed to check the
Nice!
This is an attempt to rebase #29930, started initially by @gante.
Fixes the issue raised by @poedator in #29753 (comment).
The causal mask is now of shape [..., seq_len, full_len], as opposed to [..., full_len, full_len]. This means custom 4D attention masks are now the whole causal mask, so we don't need a sliced copy -- we can copy the whole thing :)
This PR also expands the support of custom 4D attention masks: we can pass either the full mask ([..., full_len, full_len]) or the partial mask ([..., seq_len, full_len]).
As of 18.04.24 it is not passing the 4D mask tests because of the _ignore_causal_mask_sdpa() method, most recently edited in #30317 (merged today):
- tests/models/llama/test_modeling_llama.py::Mask4DTestHard::test_partial_stacked_causal_mask - ValueError: Incorrect 4D attention_mask shape: (1, 1, 12, 12); expected: (1, 1, 9, 12).
  Apparently _ignore_causal_mask_sdpa() expects that attention_mask.shape[-2] == query_length. This may only be true if the input_ids are contiguous, which is not always the case in some intended 4D mask applications.
- tests/models/llama/test_modeling_llama.py::Mask4DTestHard::test_stacked_causal_mask_static_cache - ValueError: Incorrect 4D attention_mask shape: (1, 1, 12, 16); expected: (1, 1, 12, 12).
  In this test _ignore_causal_mask_sdpa() expects that attention_mask.shape[-1] == key_value_length, which is set by the past seen tokens. However, in the test I make this dimension equal to the static cache size, so that the mask always has the same shape and the whole graph may be compiled and reused.
I hesitate to make edits to _ignore_causal_mask_sdpa() because there may be some greater context.
Summoning @younesbelkada @ArthurZucker @gante to help.