Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Token classification for Udop #30672

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/en/model_doc/udop.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,10 @@ to fine-tune UDOP on a custom dataset as well as inference. 🌎
## UdopEncoderModel

[[autodoc]] UdopEncoderModel
- forward


## UdopForTokenClassification

[[autodoc]] UdopForTokenClassification
- forward
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3574,6 +3574,7 @@
"UDOP_PRETRAINED_MODEL_ARCHIVE_LIST",
"UdopEncoderModel",
"UdopForConditionalGeneration",
"UdopForTokenClassification",
"UdopModel",
"UdopPreTrainedModel",
],
Expand Down Expand Up @@ -8147,6 +8148,7 @@
UDOP_PRETRAINED_MODEL_ARCHIVE_LIST,
UdopEncoderModel,
UdopForConditionalGeneration,
UdopForTokenClassification,
UdopModel,
UdopPreTrainedModel,
)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,7 @@
("roformer", "RoFormerForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"),
("t5", "T5ForTokenClassification"),
("udop", "UdopForTokenClassification"),
("umt5", "UMT5ForTokenClassification"),
("xlm", "XLMForTokenClassification"),
("xlm-roberta", "XLMRobertaForTokenClassification"),
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/udop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"UdopPreTrainedModel",
"UdopModel",
"UdopEncoderModel",
"UdopForTokenClassification",
]

if TYPE_CHECKING:
Expand Down Expand Up @@ -88,6 +89,7 @@
UDOP_PRETRAINED_MODEL_ARCHIVE_LIST,
UdopEncoderModel,
UdopForConditionalGeneration,
UdopForTokenClassification,
UdopModel,
UdopPreTrainedModel,
)
Expand Down
8 changes: 8 additions & 0 deletions src/transformers/models/udop/configuration_udop.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class UdopConfig(PretrainedConfig):
Size of the intermediate feed forward layer in each `UdopBlock`.
num_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder and decoder.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
num_decoder_layers (`int`, *optional*):
Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
num_heads (`int`, *optional*, defaults to 16):
Expand Down Expand Up @@ -84,6 +86,8 @@ class UdopConfig(PretrainedConfig):
The patch size used by the vision encoder.
num_channels (`int`, *optional*, defaults to 3):
The number of channels in the input images.
classifier_dropout (`float`, *optional*):
The dropout ratio for the classification head.
"""

model_type = "udop"
Expand All @@ -97,6 +101,7 @@ def __init__(
d_kv=64,
d_ff=4096,
num_layers=24,
hidden_dropout_prob=0.0,
num_decoder_layers=None,
num_heads=16,
relative_attention_num_buckets=32,
Expand All @@ -114,13 +119,15 @@ def __init__(
image_size=224,
patch_size=16,
num_channels=3,
classifier_dropout=None,
**kwargs,
):
self.vocab_size = vocab_size
self.d_model = d_model
self.d_kv = d_kv
self.d_ff = d_ff
self.num_layers = num_layers
self.hidden_dropout_prob = hidden_dropout_prob
self.num_decoder_layers = (
num_decoder_layers if num_decoder_layers is not None else self.num_layers
) # default = symmetry
Expand All @@ -132,6 +139,7 @@ def __init__(
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout

# UDOP attributes
self.max_2d_position_embeddings = max_2d_position_embeddings
Expand Down
136 changes: 136 additions & 0 deletions src/transformers/models/udop/modeling_udop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from transformers.modeling_outputs import (
Seq2SeqLMOutput,
Seq2SeqModelOutput,
TokenClassifierOutput,
)

from ...activations import ACT2FN
Expand Down Expand Up @@ -1933,6 +1934,34 @@ def _reorder_cache(self, past_key_values, beam_idx):
return reordered_decoder_past


# Copied from transformers.models.layoutlmv3.modeling_layoutlmv3.LayoutLMv3ClassificationHead with LayoutLMv3-> Udop
class UdopClassificationHead(nn.Module):
"""
Head for sentence-level classification tasks. Reference: RobertaClassificationHead
"""

def __init__(self, config, pool_feature=False):
super().__init__()
self.pool_feature = pool_feature
if pool_feature:
self.dense = nn.Linear(config.hidden_size * 3, config.hidden_size)
else:
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

def forward(self, x):
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x


@add_start_docstrings(
"The bare UDOP Model transformer outputting encoder's raw hidden-states without any specific head on top.",
UDOP_START_DOCSTRING,
Expand Down Expand Up @@ -2042,3 +2071,110 @@ def forward(
)

return encoder_outputs


@add_start_docstrings(
"""
Udop Model with a token classification head on top (a linear layer on top of the final hidden states) e.g.
for sequence labeling (information extraction) tasks such as [FUNSD](https://guillaumejaume.github.io/FUNSD/),
[SROIE](https://rrc.cvc.uab.es/?ch=13), [CORD](https://github.com/clovaai/cord) and
[Kleister-NDA](https://github.com/applicaai/kleister-nda).
""",
UDOP_START_DOCSTRING,
)
class UdopForTokenClassification(UdopPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels

self.udop = UdopEncoderModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

self.classifier = UdopClassificationHead(config, pool_feature=False)

self.init_weights()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
bbox: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[Tensor] = None,
decoder_attention_mask: Optional[Tensor] = None,
use_cache=True,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.LongTensor] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

Returns:

Examples:

```python
>>> from transformers import AutoProcessor, AutoModelForTokenClassification
>>> from datasets import load_dataset

>>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
>>> model = AutoModelForTokenClassification.from_pretrained("microsoft/udop-large", num_labels=7)

>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
>>> example = dataset[0]
>>> image = example["image"]
>>> words = example["tokens"]
>>> boxes = example["bboxes"]
>>> word_labels = example["ner_tags"]

>>> encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")

>>> outputs = model(**encoding)
>>> loss = outputs.loss
>>> logits = outputs.logits
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.udop(
input_ids,
bbox=bbox,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
pixel_values=pixel_values,
)
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]

seq_length = input_shape[1]
# only take the text part of the output representations
sequence_output = outputs[0][:, :seq_length]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()

loss = loss_fct(logits.view(-1), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -8721,6 +8721,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class UdopForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class UdopModel(metaclass=DummyObject):
_backends = ["torch"]

Expand Down