Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LayoutLMv2 init issue and doctest #30278

Merged
merged 3 commits into from
Apr 23, 2024
Merged

Fix LayoutLMv2 init issue and doctest #30278

merged 3 commits into from
Apr 23, 2024

Conversation

ydshieh
Copy link
Collaborator

@ydshieh ydshieh commented Apr 16, 2024

What does this PR do?

(I think I have done all I can do for doctest - there are 2 other failures where I ping other people)

LayoutLMv2Model.visual_segment_embedding's initialization (when it is not in a checkpoint) won't be reproducible even if seed is set. This is because it is defined

self.visual_segment_embedding = nn.Parameter(nn.Embedding(1, config.hidden_size).weight[0])

and we have _fast_init used in loading +_init_weights is not designed to deal with nn.Parameter.

This PR fixes the reproducibility issue by initializing visual_segment_embedding explicitly.

Remark: such places requiring this fix is very few - so we don't destroy the purpose of having _fast_init.

# For `nn.Parameter`, we need to give it an initialized weight in order to keep reproducibility.
# Otherwise, since `_fast_init` is used in `modeling_utils.py`, the layer `nn.Embedding` is not initialized,
# and `_init_weights` is not designed to deal with `nn.Parameter`.
self.visual_segment_embedding.data.normal_(mean=0.0, std=self.config.initializer_range)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the main fix

@ydshieh ydshieh requested a review from amyeroberts April 16, 2024 18:48
@ydshieh
Copy link
Collaborator Author

ydshieh commented Apr 16, 2024

The failure in tests_pr_documentation_tests is just missing detectron2. I can install it for that job.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -697,6 +697,10 @@ def __init__(self, config):
self.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size)
if self.has_visual_segment_embedding:
self.visual_segment_embedding = nn.Parameter(nn.Embedding(1, config.hidden_size).weight[0])
# For `nn.Parameter`, we need to give it an initialized weight in order to keep reproducibility.
# Otherwise, since `_fast_init` is used in `modeling_utils.py`, the layer `nn.Embedding` is not initialized,
# and `_init_weights` is not designed to deal with `nn.Parameter`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why we can't have this logic in _init_weights ?

Copy link
Collaborator Author

@ydshieh ydshieh Apr 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The definition of _init_weights looks like def _init_weights(self, module):, where we check things like

if isinstance(module, nn.Embedding)

elif isinstance(module, (nn.Linear, nn.Conv1d)

etc. (using layer type to determine the init method)

We don't check if it is nn.Parameter as it is too generic. For example, we also have

self.q_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size))
self.v_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size))

(where they are initialized with zeros , and not treated in _init_weights.

However, here it kind special:

self.visual_segment_embedding = nn.Parameter(nn.Embedding(1, config.hidden_size).weight[0])

it was nn.Embedding but wrappered into nn.Parameter. I guess it is because we don't want to have bias (?)
But since it is nn.Parameter, it is not treated in nn.Parameter.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some models (say, VivitPreTrainedModel), we have isinstance(module, nn.Parameter) in _init_weights

         elif isinstance(module, nn.Parameter):
            module.data.normal_(mean=0.0, std=self.config.initializer_range)

But for LayoutLM, it has multiple nn.Parameter and they don't mean to be initialized in the same way.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gentilly ping @amyeroberts 😃

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amyeroberts Let me know if you have further question.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sorry for being late on this. I see. My understanding of the init weights behaviour is that if a weight has been initialized then it's marked as such and skipped it it matches other patterns later on (correct me if I'm wrong here). Would it be possible then to instead something similar to what we see in e.g. wav2vec2's _init_weights where we check if the module is LayoutLMv2Model and then init this param?

I realise it's not a huge difference having this in this init than _init_weights, but there's an increasing amount of logic e.g. with quantization which assumes this behaviour is controlled within the _init_weights module

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the insightful suggestion! I will give it a shot and let you know.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amyeroberts works well! Once you approve the changes, I will update the >>> set_seed(x) to make the doctest passing.

@ydshieh ydshieh requested review from amyeroberts and removed request for amyeroberts April 17, 2024 09:47
@ydshieh ydshieh marked this pull request as draft April 22, 2024 15:20
@@ -822,7 +829,7 @@ def forward(
>>> import torch
>>> from datasets import load_dataset

>>> set_seed(88)
>>> set_seed(0)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all these and below have to be updated

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

turns out they work on the CI runners. No need to update

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing and iterating!

@ydshieh ydshieh marked this pull request as ready for review April 23, 2024 13:24
@ydshieh ydshieh merged commit 416fdba into main Apr 23, 2024
17 of 19 checks passed
@ydshieh ydshieh deleted the fix_layoutlmv2_init branch April 23, 2024 13:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants