
🚨 Add training compatibility for Musicgen-like models #29802

Merged: 98 commits merged into huggingface:main on Apr 25, 2024

Conversation

ylacombe
Contributor

@ylacombe ylacombe commented Mar 22, 2024

This PR aims to add training compatibility for Musicgen and Musicgen Melody.

The main difference from classic cross-entropy is that there are num_codebooks labels to predict per timestep instead of a single token per timestep. This shows up in the loss, which is the mean of the per-codebook cross-entropies.

A few additional insights:

  • The models don't have an EOS token id, so they generate until max_length.
  • The model actually predicts the codebooks in a delayed pattern (see the sketch after this list).
    - The first codebook channel is predicted without delay, but the further you go, the more delay there is (2nd codebook -> delayed by 1, 3rd codebook -> delayed by 2, etc.)
  • Training scripts will be shared as well.
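
A rough sketch of the delay pattern, using toy tensors rather than the actual Musicgen code (the real implementation handles this pattern internally; the shapes, names, and pad value below are illustrative only):

```python
import torch

num_codebooks, seq_len, pad_id = 4, 8, 0

# Undelayed codes: one row per codebook, one column per timestep.
codes = torch.randint(1, 16, (num_codebooks, seq_len))

# Codebook k is shifted right by k steps and padded with pad_id, so the
# first codebook is predicted without delay, the second with a delay of 1, etc.
delayed = torch.full((num_codebooks, seq_len + num_codebooks - 1), pad_id)
for k in range(num_codebooks):
    delayed[k, k : k + seq_len] = codes[k]

print(delayed)
```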

cc @sanchit-gandhi and @amyeroberts

@ylacombe ylacombe marked this pull request as ready for review March 22, 2024 12:50
@ylacombe ylacombe changed the title from "[WIP] Add training compatibility for Musicgen-like models" to "Add training compatibility for Musicgen-like models" Mar 22, 2024
@ylacombe ylacombe requested a review from sanchit-gandhi March 22, 2024 13:07
@arjunsinghrathore

Hi! Is it possible to fine-tune the Musicgen model currently? If so, is there anything I should keep in mind? It would be really helpful if you could share your thoughts. Thanks!

@LiuZH-19

Wonderful work! I'm currently attempting to fine-tune the Musicgen model using this code, but I haven't succeeded yet. Is the model ready for fine-tuning, and are there specific aspects I should be aware of? Any training tips or guidance you could provide would be greatly appreciated!

Thank you so much!

@ylacombe
Contributor Author

Hey @arjunsinghrathore and @LiuZH-19, I'll likely release some fine-tuning code next week or the week after!
Out of curiosity, may I ask what type of data you have?
Thanks!

Contributor

@sanchit-gandhi sanchit-gandhi left a comment


Fast tests look quite rigorous, but we can probably use the base one unless there's a specific embedding issue! Otherwise LGTM

Comment on lines 1390 to 1402
# per codebook cross-entropy
# -100 labels are ignored
labels = labels.masked_fill(labels == self.config.pad_token_id, -100)

mask = labels != -100

# per codebook cross-entropy
for codebook in range(self.config.num_codebooks):
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
codebook_mask = mask[..., codebook].contiguous().view(-1)
codebook_labels = labels[..., codebook].contiguous().view(-1)

loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
Contributor

The padded labels should be set to -100 outside of the modelling code, i.e. in the data collator. Then, we don't need any of this masking logic, since the CE loss masks out -100 values by default (see argument ignore_index)

Suggested change:

Before:

# per codebook cross-entropy
# -100 labels are ignored
labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
mask = labels != -100
# per codebook cross-entropy
for codebook in range(self.config.num_codebooks):
    codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
    codebook_mask = mask[..., codebook].contiguous().view(-1)
    codebook_labels = labels[..., codebook].contiguous().view(-1)
    loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])

After:

# per codebook cross-entropy
for codebook in range(self.config.num_codebooks):
    codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
    codebook_labels = labels[..., codebook].contiguous().view(-1)
    loss += loss_fct(codebook_logits, codebook_labels)
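
As a minimal illustration of the "set -100 in the data collator" approach suggested above (toy code, not the PR's actual collator; the function name, the (seq_len, num_codebooks) label layout, and the padding convention are assumptions):

```python
import torch

def collate_labels(label_list, pad_token_id):
    """Pad per-example (seq_len, num_codebooks) label tensors to a common length,
    filling with -100 so the cross-entropy loss ignores the padded positions."""
    max_len = max(labels.shape[0] for labels in label_list)
    num_codebooks = label_list[0].shape[1]
    batch = torch.full((len(label_list), max_len, num_codebooks), -100, dtype=torch.long)
    for i, labels in enumerate(label_list):
        batch[i, : labels.shape[0]] = labels
    # Also mask any pad tokens already present inside the labels themselves.
    return batch.masked_fill(batch == pad_token_id, -100)
```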


loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])

loss = loss / self.config.num_codebooks
Contributor

Reference from original AudioCraft code: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243

Interesting that they average over codebooks rather than taking a true average over all labels. Given that the sequence length for music generation is large (1500 tokens), the difference is going to be negligible.

Collaborator

Let's add this link in a comment above for any unsuspecting future code reader to have context
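
For intuition, a toy comparison (not from the PR) of the two averaging schemes discussed above; they only differ when the codebooks contain different numbers of unmasked labels:

```python
import torch
import torch.nn.functional as F

num_codebooks, seq_len, vocab = 4, 6, 10
logits = torch.randn(num_codebooks, seq_len, vocab)
labels = torch.randint(0, vocab, (num_codebooks, seq_len))
labels[1, -3:] = -100  # give codebook 1 fewer valid labels than the others

# (a) mean of per-codebook cross-entropies (the scheme used in this PR and,
#     per the link above, in AudioCraft)
loss_a = torch.stack(
    [F.cross_entropy(logits[k], labels[k], ignore_index=-100) for k in range(num_codebooks)]
).mean()

# (b) a single cross-entropy over all (codebook, timestep) positions at once
loss_b = F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1), ignore_index=-100)

print(loss_a, loss_b)  # close but not identical, since codebook 1 has fewer labels
```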

@@ -1340,15 +1343,22 @@ def forward(
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
Contributor

cf. this docstring, which explains how the -100 padding token should be set

Comment on lines 228 to 229
# Contrary to the initial method, we don't unfreeze frozen parameters.
# Otherwise, it'll mess with the frozen sinusoidal embeddings
Contributor

It'll change the weights of the pre-trained embeddings for sure, but the code should still run, right? Or is there an issue where the code won't run if we train the embeddings? Unless that's the case, I would just use the base check_training_gradient_checkpointing method for simplicity

Contributor Author

@ylacombe ylacombe Apr 16, 2024

The comments are not really clear, I'll make them clearer:

  1. The ConditionalGeneration models have an audio encoder that is never used outside of .generate; if we don't freeze it, we'll have issues because its trainable weights will never have seen any gradients.
  2. The CausalModel's sinusoidal embeddings are frozen and should stay frozen (they shouldn't have been turned into a Parameter in the first place); see the sketch below:
    self.weights.requires_grad = False
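
An illustrative sketch of the point in (2), not the transformers implementation (which, as described above, stores the sinusoidal table as an nn.Parameter with requires_grad = False): registering the table as a buffer is another way to keep it fixed and out of the trainable parameters entirely.

```python
import math
import torch
import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed sinusoidal position table; assumes an even embedding dimension."""

    def __init__(self, num_positions: int, dim: int):
        super().__init__()
        position = torch.arange(num_positions).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        table = torch.zeros(num_positions, dim)
        table[:, 0::2] = torch.sin(position * div_term)
        table[:, 1::2] = torch.cos(position * div_term)
        # A buffer is saved with the module but never appears in .parameters(),
        # so no training utility can accidentally unfreeze it.
        self.register_buffer("weights", table)

    def forward(self, position_ids: torch.LongTensor) -> torch.Tensor:
        return self.weights[position_ids]
```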

@ylacombe ylacombe requested a review from amyeroberts April 16, 2024 17:45
@ylacombe
Copy link
Contributor Author

Hey @amyeroberts, gentle ping to ask for a review! Many thanks for your help!

Collaborator

@amyeroberts amyeroberts left a comment

Thanks for adding this capability!

Mostly small comments. Only concern is backwards compatibility wrt the loss for musicgen

"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if (labels is not None) and (input_ids is None and inputs_embeds is None):
input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
Collaborator

It's a bit funny to do this on the input_ids - normally it's only done on the decoder_input_ids.

Contributor Author

Well, it's a bit confusing, but input_ids in MusicgenForCausalLM actually corresponds to the audio input_ids (i.e. the decoder input ids)!

Collaborator

Ah, ok! Makes sense :)
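
For readers unfamiliar with the pattern being discussed: when labels are passed without decoder inputs, the decoder inputs are built by shifting the labels one step to the right. Below is a generic BART-style sketch of such a helper; the actual Musicgen shift_tokens_right may differ (for instance in how it handles the codebook dimension, or in passing bos_token_id as the start token).

```python
import torch

def shift_tokens_right(labels: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    """Shift labels one position to the right and prepend the decoder start token."""
    shifted = labels.new_zeros(labels.shape)
    shifted[..., 1:] = labels[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    # -100 is only meaningful for the loss; replace it with a real token id.
    return shifted.masked_fill(shifted == -100, pad_token_id)
```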



  return Seq2SeqLMOutput(
-     loss=loss,
+     loss=decoder_outputs.loss,
Collaborator

hmmm, the problem with this is it's not backwards compatible - users are now going to get different values of loss than before

Contributor Author

I haven't seen any users using this or mentioning it, to be honest (besides, it's totally wrong!).
How should we best handle this? Maybe by adding a breaking flag to the PR title?

Collaborator

If it's completely wrong then I think it's OK to break. We should just add a 🚨 prefix to the PR title so it can be easily found when preparing the release notes

Comment on lines 2553 to 2561
def freeze_encoders(self, freeze_text_encoder=True):
    if freeze_text_encoder:
        for param in self.text_encoder.parameters():
            param.requires_grad = False
        self.text_encoder._requires_grad = False

    for param in self.audio_encoder.parameters():
        param.requires_grad = False
    self.audio_encoder._requires_grad = False
Collaborator

I'm not a fan of this structure. In general, we don't add freeze methods to our models and leave that to the user to handle - although I see audio models appear to be the exception!

It's tidier to split this up into freeze_text_encoder and freeze_audio_encoder and then either call them separately or add an additional freeze_audio_encoder argument.

Contributor Author

I've split these up!
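
Roughly, the split looks like the following sketch, based on the freeze_encoders body quoted above (not necessarily the exact merged code; the mixin wrapper is only there to make the snippet self-contained):

```python
class FreezeEncodersMixin:
    """Assumes the model defines self.text_encoder and self.audio_encoder."""

    def freeze_text_encoder(self):
        # Disable gradients for the text encoder so it is not updated during training.
        for param in self.text_encoder.parameters():
            param.requires_grad = False
        self.text_encoder._requires_grad = False

    def freeze_audio_encoder(self):
        # Disable gradients for the audio encoder, which is only used inside .generate.
        for param in self.audio_encoder.parameters():
            param.requires_grad = False
        self.audio_encoder._requires_grad = False
```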

@@ -2428,6 +2442,16 @@ def _maybe_initialize_input_ids_for_generation(
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id

def freeze_encoders(self, freeze_text_encoder=True):
Collaborator

Same comment here

@@ -2533,6 +2550,16 @@ def resize_token_embeddings(self, *args, **kwargs):
" model.decoder.resize_token_embeddings(...))"
)

def freeze_encoders(self, freeze_text_encoder=True):
Collaborator

Should have a docstring as it's a public method.

@@ -2189,6 +2288,26 @@ def test_eager_matches_sdpa_generate(self):

self.assertTrue(torch.allclose(res_eager, res_sdpa))

def test_requires_grad_with_frozen_encoders(self):
Collaborator

❤️

@ylacombe
Contributor Author

Many thanks for the review @amyeroberts, I've changed the code according to your comments!
The only thing left to address is the loss computation being a breaking change; let me know what you think of this.
Note that I don't believe many users actually used the loss computation as it was.

Collaborator

@amyeroberts amyeroberts left a comment

Awesome work - thanks for adding this feature!

"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if (labels is not None) and (input_ids is None and inputs_embeds is None):
input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, ok! Makes sense :)


return Seq2SeqLMOutput(
loss=loss,
loss=decoder_outputs.loss,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's completely wrong then I think it's OK to break. We should just add a 🚨 prefix to the PR title so it can be easily found when preparing the release notes

@ylacombe ylacombe changed the title from "Add training compatibility for Musicgen-like models" to "🚨 Add training compatibility for Musicgen-like models" Apr 25, 2024
@ylacombe ylacombe merged commit 90cb55b into huggingface:main Apr 25, 2024
23 checks passed