Enable dynamic resolution input for Swin Transformer and variants (#30656)

* add interpolation of positional encoding support to swin

* add style changes

* use default image processor and make size a dictionary

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove logits testing

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Refactor image size validation logic when interpolation is disabled

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove asserts in modeling

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add dynamic resolution input support to swinv2

* change size to ensure interpolation encoding path is triggered

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

* add dynamic resolution input to donut swin

* add dynamic resolution input to maskformer swin

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
the-neural-networker and amyeroberts committed May 17, 2024
1 parent b6eb708 commit 481a957
Showing 6 changed files with 291 additions and 20 deletions.
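In practice, the new `interpolate_pos_encoding` flag lets a Swin checkpoint pretrained at one resolution run on larger inputs. A minimal usage sketch follows; the checkpoint name, dummy image, and the 384x384 target size are illustrative assumptions, not part of the diff, and the `size` dictionary override mirrors the 'use default image processor and make size a dictionary' step in the commit message.

```python
import torch
from transformers import AutoImageProcessor, SwinModel

checkpoint = "microsoft/swin-tiny-patch4-window7-224"  # illustrative 224x224-pretrained checkpoint
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = SwinModel.from_pretrained(checkpoint)

# Dummy 480x640 RGB image; the processor resizes it to the requested size dictionary.
image = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8).numpy()
inputs = processor(image, size={"height": 384, "width": 384}, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)  # flag added by this commit
print(outputs.last_hidden_state.shape)
```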
63 changes: 58 additions & 5 deletions src/transformers/models/donut/modeling_donut_swin.py
@@ -166,10 +166,48 @@ def __init__(self, config, use_mask_token=False):
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
self,
pixel_values: Optional[torch.FloatTensor],
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> Tuple[torch.Tensor]:
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
_, num_channels, height, width = pixel_values.shape
embeddings, output_dimensions = self.patch_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size()

@@ -180,7 +218,10 @@ def forward(
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings

embeddings = self.dropout(embeddings)

@@ -219,14 +260,21 @@ def maybe_pad(self, pixel_values, height, width):
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values

def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
def forward(
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
@@ -849,6 +897,8 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -899,6 +949,7 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, DonutSwinModelOutput]:
r"""
@@ -921,7 +972,9 @@ def forward(
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))

embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
embedding_output, input_dimensions = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)

encoder_outputs = self.encoder(
embedding_output,
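The `interpolate_pos_encoding` method added above (and repeated in the MaskFormer and Swin backbones below) follows the DINO reference linked in its docstring: the learned patch position embeddings are reshaped to their original square grid and resized with bicubic interpolation to the new patch grid. A standalone sketch of that resizing step, assuming a square grid with one leading class-token slot (the function name and example sizes are illustrative, not from the diff):

```python
import math
import torch
import torch.nn as nn


def resize_position_embeddings(pos_embed: torch.Tensor, height: int, width: int, patch_size: int) -> torch.Tensor:
    # pos_embed: (1, 1 + grid*grid, dim); the first position is the class-token slot
    class_pos_embed, patch_pos_embed = pos_embed[:, :1], pos_embed[:, 1:]
    dim = pos_embed.shape[-1]
    grid = int(math.sqrt(patch_pos_embed.shape[1]))
    # small offset to avoid floating point error, as in the DINO implementation referenced in the diff
    h0, w0 = height // patch_size + 0.1, width // patch_size + 0.1
    patch_pos_embed = patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed, scale_factor=(h0 / grid, w0 / grid), mode="bicubic", align_corners=False
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)


# e.g. a 56x56 grid (224x224 pretraining, patch size 4) resized for a 384x384 input -> 96x96 grid
pos = torch.randn(1, 56 * 56 + 1, 96)
print(resize_position_embeddings(pos, 384, 384, patch_size=4).shape)  # torch.Size([1, 9217, 96])
```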
58 changes: 53 additions & 5 deletions src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -163,12 +163,50 @@ def __init__(self, config):
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, pixel_values):
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values, interpolate_pos_encoding):
_, num_channels, height, width = pixel_values.shape
embeddings, output_dimensions = self.patch_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)
embeddings = self.norm(embeddings)

if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings

embeddings = self.dropout(embeddings)

@@ -207,14 +245,21 @@ def maybe_pad(self, pixel_values, height, width):
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values

def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
def forward(
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
@@ -780,6 +825,7 @@ def forward(
head_mask=None,
output_attentions=None,
output_hidden_states=None,
interpolate_pos_encoding=False,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -798,7 +844,9 @@ def forward(
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))

embedding_output, input_dimensions = self.embeddings(pixel_values)
embedding_output, input_dimensions = self.embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)

encoder_outputs = self.encoder(
embedding_output,
75 changes: 70 additions & 5 deletions src/transformers/models/swin/modeling_swin.py
@@ -252,10 +252,48 @@ def __init__(self, config, use_mask_token=False):
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
self,
pixel_values: Optional[torch.FloatTensor],
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> Tuple[torch.Tensor]:
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
_, num_channels, height, width = pixel_values.shape
embeddings, output_dimensions = self.patch_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size()

@@ -266,7 +304,10 @@ def forward(
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings

embeddings = self.dropout(embeddings)

@@ -304,14 +345,21 @@ def maybe_pad(self, pixel_values, height, width):
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values

def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
def forward(
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
@@ -924,6 +972,8 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -981,6 +1031,7 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinModelOutput]:
r"""
@@ -1003,7 +1054,9 @@ def forward(
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))

embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
embedding_output, input_dimensions = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)

encoder_outputs = self.encoder(
embedding_output,
@@ -1074,6 +1127,7 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinMaskedImageModelingOutput]:
r"""
@@ -1113,6 +1167,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1156,6 +1211,14 @@ def forward(
"""
Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
the [CLS] token) e.g. for ImageNet.
<Tip>
Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.
</Tip>
""",
SWIN_START_DOCSTRING,
)
@@ -1188,6 +1251,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinImageClassifierOutput]:
r"""
@@ -1203,6 +1267,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

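As the Tip added to the classification docstring above notes, the model can now be fine-tuned or evaluated at a higher resolution than it was pretrained at. A sketch contrasting the two code paths (checkpoint and sizes are illustrative): without the flag, the stricter validation added to the patch embeddings rejects a non-matching input size, while `interpolate_pos_encoding=True` accepts it.

```python
import torch
from transformers import AutoImageProcessor, SwinForImageClassification

checkpoint = "microsoft/swin-tiny-patch4-window7-224"  # illustrative; pretrained at 224x224
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = SwinForImageClassification.from_pretrained(checkpoint)

image = torch.randint(0, 256, (500, 500, 3), dtype=torch.uint8).numpy()
inputs = processor(image, size={"height": 448, "width": 448}, return_tensors="pt")

try:
    model(**inputs)  # default: interpolate_pos_encoding=False
except ValueError as err:
    print(err)  # Input image size (448*448) doesn't match model (224*224).

with torch.no_grad():
    logits = model(**inputs, interpolate_pos_encoding=True).logits
print(logits.shape)  # (1, num_labels)
```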