Llama: fix custom 4D masks, v2 #30348

Merged: 16 commits, May 13, 2024

Conversation

poedator
Contributor

@poedator poedator commented Apr 19, 2024

This is an attempt to rebase #29930, started initially by @gante.

Fixes the issue raised by @poedator in #29753 (comment).

The causal mask is now of shape [..., seq_len, full_len], as opposed to [..., full_len, full_len]. This means a custom 4D attention mask now matches the whole causal mask, so we don't need a sliced copy -- we can copy the whole thing :)

This PR also expands support for custom 4D attention masks: we can pass either the full mask ([..., full_len, full_len]) or the partial mask ([..., seq_len, full_len]).
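For illustration, a rough sketch of how such a custom 4D mask could be built and passed for packed sequences (illustrative only - the helper below is not part of this PR; it assumes the usual 1 = attend / 0 = blocked convention and the no-cache case where seq_len == full_len):

import torch

def build_packed_causal_mask(segment_ids: torch.Tensor) -> torch.Tensor:
    # segment_ids: [seq_len] ints marking which packed sequence each token belongs to,
    # e.g. torch.tensor([0, 0, 0, 1, 1, 2]) for three sequences packed into one row
    seq_len = segment_ids.shape[0]
    same_segment = segment_ids[:, None] == segment_ids[None, :]          # block-diagonal structure
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # causal (lower-triangular) part
    # 1 = attend, 0 = blocked; final shape [1, 1, seq_len, seq_len]
    return (same_segment & causal).long()[None, None]

# hypothetical usage with a causal LM that accepts 4D masks:
# mask_4d = build_packed_causal_mask(segment_ids)
# outputs = model(input_ids, attention_mask=mask_4d)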


As of 18.04.24, this branch is not passing the 4D mask tests because of the _ignore_causal_mask_sdpa() method, most recently edited in #30317 (merged today).

tests/models/llama/test_modeling_llama.py::Mask4DTestHard::test_partial_stacked_causal_mask - ValueError: Incorrect 4D attention_mask shape: (1, 1, 12, 12); expected: (1, 1, 9, 12).
Apparently _ignore_causal_mask_sdpa() expects that attention_mask.shape[-2] == query_length. This may only be true if the input_ids are contiguous, which is not the case in some intended 4D mask applications.

tests/models/llama/test_modeling_llama.py::Mask4DTestHard::test_stacked_causal_mask_static_cache - ValueError: Incorrect 4D attention_mask shape: (1, 1, 12, 16); expected: (1, 1, 12, 12).
In this test, _ignore_causal_mask_sdpa() expects that attention_mask.shape[-1] == key_value_length, which is set by the past seen tokens. However, I make this dimension equal to the static cache size, so that the mask always has the same shape and the whole graph may be compiled and reused.
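For clarity, the shape bookkeeping in these two failures as I read it (numbers taken from the error messages above):

# Failure 1 (partial stacked causal mask): the test passes the full 12x12 custom mask,
# while _ignore_causal_mask_sdpa() expects the mask rows to equal the query length (9),
# since only 9 (non-contiguous) positions are scored in that forward pass.
mask_passed_1   = (1, 1, 12, 12)
mask_expected_1 = (1, 1, 9, 12)   # rows == query_length, cols == key_value_length

# Failure 2 (stacked causal mask + static cache): the test pads the last mask dimension
# to the static cache size (16) so that the mask shape stays constant and the graph can
# be compiled once, while the check expects cols == past_seen_tokens + query_length (12).
mask_passed_2   = (1, 1, 12, 16)
mask_expected_2 = (1, 1, 12, 12)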

I hesitate to make edits to _ignore_causal_mask_sdpa() because there may be some greater context.

Summoning @younesbelkada @ArthurZucker @gante to help

@poedator poedator changed the title from 4d fix 2 to Llama: fix custom 4D masks, v2 on Apr 19, 2024
@poedator poedator force-pushed the 4d_fix_2 branch 2 times, most recently from 38fcb65 to 69ce14b on April 20, 2024 23:02
@poedator
Contributor Author

poedator commented Apr 20, 2024

As a solution, I added additional expected shapes to _ignore_causal_mask_sdpa() and improved the StaticCache detection code.
Note: it is inconvenient to have StaticCache attached to the layer.self_attn objects while the other Caches are model-level objects. Perhaps there could be a model-level handle to avoid reaching into the layers.

Please review soon - I need this for my paper code. It's been broken for quite a while now.

The LONG tests look OK.

@poedator poedator marked this pull request as ready for review April 20, 2024 23:45
Comment on lines 1017 to 1023
if attention_mask is not None and attention_mask.dim() == 4:
    # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape
    # as the causal mask (i.e. [..., seq_len, full_len])
    mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
    offset = cache_position[0]
    if attention_mask.shape[-2] == offset + sequence_length:
        mask_slice = mask_slice[..., offset:, :]
    causal_mask = mask_slice
else:
    if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
        target_length = self.config.max_position_embeddings
    else:  # dynamic cache
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, torch.Tensor)
            else past_seen_tokens + sequence_length + 1
        )
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
Collaborator

This is very messy.
IMO, if a 4D mask is passed, then we should not even touch it.
Previously we used to invert it; now that we don't, we should not add extra logic and should just return the 4D mask. Fine with me to add a check to make sure the min is -inf, the max is 0, and the shape is as expected, but let's not increase the surface and complicate things.
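Something along these lines, as a sketch of the kind of check meant here (illustrative only, not code from this PR; the helper name and signature are made up):

import torch

def _check_custom_4d_mask(mask: torch.Tensor, expected_shape: tuple) -> torch.Tensor:
    # validate an already-additive (inverted) 4D mask and return it untouched
    if tuple(mask.shape) != expected_shape:
        raise ValueError(f"Incorrect 4D attention_mask shape: {tuple(mask.shape)}; expected: {expected_shape}.")
    if mask.max().item() > 0:
        raise ValueError("Additive 4D mask values must be <= 0 (0 = attend, large negative / -inf = blocked).")
    return mask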

Contributor Author

I simplified the code:

  • removed inversion for custom 4D masks
  • removed mask slicing for custom 4D masks
  • removed the test case where we pass a 4D attention mask with the full sequence length (i.e. [..., full_len, full_len])
  • updated the relevant tests

I am concerned that

  • part of the code related to checking mask shapes sits in _ignore_causal_mask_sdpa()
  • the advanced 4D mask tests sit in tests/models/llama/test_modeling_llama.py - something outside of the Llama code, such as src/transformers/modeling_attn_mask_utils.py, may break them again. Could you please ensure that the Mask4DTestHard class is always run when something changes in modeling_attn_mask_utils.py, even if the Llama code is otherwise not affected?

Collaborator

@ArthurZucker ArthurZucker left a comment

LGTM, let's just not add extra checks and variables that are only here for the static cache in the forward.

Comment on lines 977 to 979
is_static_cache = isinstance(past_key_values, StaticCache) or isinstance(
    getattr(self.layers[0].self_attn, "past_key_value", None), StaticCache
)
Collaborator

Cc @gante: in this release I think we can already drop the logic that relies on the DynamicCache, in favor of relying on the cache positions, WDYT?

Member

@ArthurZucker what do you mean by "logic that does rely on the DynamicCache"? The past_seen_tokens = past_key_values.get_seq_length() line?

Collaborator

Yes, and the max_length as well! I just don't want to add a new variable that is not general and is only used in _update_causal_mask.

@@ -989,7 +992,10 @@ def forward(
if position_ids is None:
    position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
static_cache_max_length = self.layers[0].self_attn.past_key_value.get_max_length() if is_static_cache else None
Collaborator

Not super sure the static_cache_max_length should be extracted here, would rather hide everything in _update_causal_mask. That's the only nit.

Contributor Author

@poedator poedator Apr 22, 2024

Fixed it in 69c84cc - please check.

Collaborator

A lot better, thanks!

Member

@gante gante left a comment

LGTM, thank you for working on it 💪

I'm assuming all slow tests are passing for the models changed in this PR :) (if you haven't run them, please do)

@poedator
Contributor Author

poedator commented Apr 23, 2024

All CI tests are green; the SLOW tests were OK on my side yesterday.

@poedator
Contributor Author

I noticed that Mistral's support for 4D masks stayed broken after these fixes, so I added similar lines to src/transformers/modeling_attn_mask_utils.py::_prepare_4d_causal_attention_mask_for_sdpa().

@poedator
Contributor Author

I added the Mask4DTestHard tests (without the static cache part) to tests/models/mistral/test_modeling_mistral.py to ensure that 4D masks keep working in the models that use _prepare_4d_causal_attention_mask_for_sdpa(). These new tests would fail without the fixes from commit d488f35 just above.
I ran the SLOW tests for ./tests/models/mistral/ - all fine.

Is there anything left to do before the merge is possible?
@gante @ArthurZucker

Collaborator

@ArthurZucker ArthurZucker left a comment

Let's remove unrelated changes!

Comment on lines 899 to 910
+ is_static_cache = isinstance(past_key_values, StaticCache) or isinstance(
+     getattr(self.layers[0].self_attn, "past_key_value", None), StaticCache
+ )
  if use_cache:  # kept for BC (cache positions)
-     if not isinstance(past_key_values, StaticCache):
+     if not is_static_cache:
          past_key_values = DynamicCache.from_legacy_cache(past_key_values)
          past_seen_tokens = past_key_values.get_seq_length()

  if cache_position is None:
-     if isinstance(past_key_values, StaticCache):
+     if is_static_cache:
Collaborator

I feel like these are unrelated and can be addressed in another PR.

Collaborator

I know you are trying to check the correct shapes, but it's a lot of code that's being added for that 😢

Contributor Author

@poedator poedator Apr 26, 2024

When StaticCache is used, the past_key_values argument can easily be None, so the isinstance(past_key_values, StaticCache) condition just doesn't work. As I understand it, once the StaticCache is initialized, there is no need to pass it in the past_key_values argument. That's why getattr(self.layers[0].self_attn, "past_key_value", None) is necessary.

I'd love to use fewer lines of code for that, but can't confidently do so since I don't quite understand the story behind the _ignore_causal_mask_sdpa() addition.

For instance, I could propose if (not is_static_cache) and AttentionMaskConverter._ignore_causal_mask_sdpa(...) - would that work?

Collaborator

This should be fixed in another PR! 😉

Comment on lines 1002 to 988
+ static_cache_max_length = (
+     self.layers[0].self_attn.past_key_value.get_max_length()
+     if isinstance(getattr(self.layers[0].self_attn, "past_key_value", None), StaticCache)
+     else None
+ )

  if AttentionMaskConverter._ignore_causal_mask_sdpa(
-     attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+     attention_mask,
+     inputs_embeds=input_tensor,
+     past_key_values_length=past_seen_tokens,
+     static_cache_max_length=static_cache_max_length,
  ):
Collaborator

same here

Contributor Author

_ignore_causal_mask_sdpa() returns True unless it recognizes that the mask shape is valid. But to do so, it needs to know the static_cache_max_length.
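Roughly, the shape logic I have in mind (a sketch of my reading, not the merged code; the helper name is made up):

def mask_last_dim_is_valid(attention_mask, sequence_length, past_seen_tokens, static_cache_max_length=None):
    # dynamic-cache case: the last dimension should equal the key/value length
    expected = {past_seen_tokens + sequence_length}
    if static_cache_max_length is not None:
        # static-cache case: the mask may be padded to the full cache size so the graph shape stays fixed
        expected.add(static_cache_max_length)
    return attention_mask.shape[-1] in expected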

Collaborator

Mmmm, again: we should just check the attention_mask shape, and if it is 4D, just return it.

@poedator
Contributor Author

Let's remove unrelated changes!

Sorry, but without these changes the fixes and tests will not work. I looked for related PRs; all I found was #30476, but it does not fix the relevant parts of the code.

@poedator poedator force-pushed the 4d_fix_2 branch 2 times, most recently from 54a6ec8 to c64188f on April 29, 2024 18:13
@poedator
Contributor Author

I tried to follow Arthur's advice to streamline the path for the 4D masks, and it seems to work. The relevant tests pass.
@ArthurZucker @gante, please review.

Collaborator

@ArthurZucker ArthurZucker left a comment

Let's remove the elif, no? Looks good to me with that removed from the modeling changes!

Comment on lines 1105 to 1120
elif attention_mask.dim() == 4:
    # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
    # cache. In that case, the 4D attention mask attends to the newest tokens only.
    if attention_mask.shape[-2] < cache_position[0] + sequence_length:
        logger.warning_once(
            "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
            "transformers v4.42.0"
        )
        offset = cache_position[0]
    else:
        offset = 0
    mask_shape = attention_mask.shape
    mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
    causal_mask[
        : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
    ] = mask_slice
Collaborator

Suggested change: remove this elif branch entirely.

Contributor Author

Agreed and removed.

@poedator
Contributor Author

poedator commented May 9, 2024

I combined the two tests from common, which were very similar, and added tolerance - now Mixtral passes it OK.
@ArthurZucker, @gante - please see if it is good to merge now.

Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks!
A few small nits, but looks good to me!
Happy that 4D support is kept 🔥


@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
Collaborator

This test is the same as Llama's - let's use # Copied from.

Contributor Author

@poedator poedator May 10, 2024

Comparing the test code for Llama and Mistral, I see that:

  • a specific model is used in setup()
  • the test examples are text-based (which may not lead to the same token count with different tokenizers)
  • the Llama version has extra methods to test StaticCache

Potential solutions:

  0) do nothing
  1) get away from text-based examples and use token_ids for transferability across models
    1a) copy the tests from Llama to Mistral (I need a hint on the auto-copying syntax)
    1b) move the Mask4DTestHard class methods to testing_common

Also, are you OK with the choice of Mistral for the test? I added it because it uses a slightly different code path compared to Llama.

Collaborator

Yes, no worries.

    self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")

for model_class in self.all_generative_model_classes:
    if not model_class._supports_cache_class:
Collaborator

Brittle, as one does not guarantee the other, but that's alright.

Contributor Author

These lines came from @gante - maybe he has comments?

Collaborator

No worries, it's alright to keep it!

@ArthurZucker
Collaborator

Just waiting for the commits to be pushed so I can check the done parts, and then we can merge.

Collaborator

@ArthurZucker ArthurZucker left a comment

Nice!
