Fix image post-processing for OWLv2 #30686

Merged: 9 commits, May 9, 2024
13 changes: 12 additions & 1 deletion src/transformers/models/owlv2/image_processing_owlv2.py
@@ -481,7 +481,6 @@ def preprocess(
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)

# Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_object_detection
def post_process_object_detection(
self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None
):
@@ -525,6 +524,18 @@ def post_process_object_detection(
else:
img_h, img_w = target_sizes.unbind(1)

# rescale coordinates
width_ratio = 1
height_ratio = 1

if img_w < img_h:
width_ratio = img_w / img_h
elif img_h < img_w:
height_ratio = img_h / img_w

img_w = img_w / width_ratio
img_h = img_h / height_ratio

scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
boxes = boxes * scale_fct[:, None, :]

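For context, the rescaling step added above compensates for OWLv2 padding every input to a square: the model's boxes are normalized with respect to that padded square, so both axes have to be scaled by the longer edge of the original image. Below is a minimal standalone sketch of the same idea; `rescale_boxes` and the sample values are illustrative and not part of the library.

```python
import torch

def rescale_boxes(boxes: torch.Tensor, img_h: int, img_w: int) -> torch.Tensor:
    # `boxes` are normalized (0..1, xyxy) with respect to the padded square input.
    width_ratio, height_ratio = 1.0, 1.0
    if img_w < img_h:
        width_ratio = img_w / img_h   # portrait: padding was added on the right
    elif img_h < img_w:
        height_ratio = img_h / img_w  # landscape: padding was added at the bottom
    # Dividing by the ratio maps the shorter side onto the padded square's edge,
    # so every coordinate ends up scaled by max(img_h, img_w).
    scale = torch.tensor([img_w / width_ratio, img_h / height_ratio] * 2)
    return boxes * scale

boxes = torch.tensor([[0.25, 0.05, 0.47, 0.55]])   # hypothetical prediction
print(rescale_boxes(boxes, img_h=480, img_w=640))  # ~[[160., 32., 300.8, 352.]]
```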
64 changes: 17 additions & 47 deletions src/transformers/models/owlv2/modeling_owlv2.py
@@ -1542,9 +1542,7 @@ def image_guided_detection(
>>> import requests
>>> from PIL import Image
>>> import torch
>>> import numpy as np
>>> from transformers import AutoProcessor, Owlv2ForObjectDetection
>>> from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

>>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
>>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
@@ -1559,20 +1557,7 @@ def image_guided_detection(
>>> with torch.no_grad():
... outputs = model.image_guided_detection(**inputs)

>>> # Note: boxes need to be visualized on the padded, unnormalized image
>>> # hence we'll set the target image sizes (height, width) based on that

>>> def get_preprocessed_image(pixel_values):
... pixel_values = pixel_values.squeeze().numpy()
... unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
... unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
... unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
... unnormalized_image = Image.fromarray(unnormalized_image)
... return unnormalized_image

>>> unnormalized_image = get_preprocessed_image(inputs.pixel_values)

>>> target_sizes = torch.Tensor([unnormalized_image.size[::-1]])
>>> target_sizes = torch.Tensor([image.size[::-1]])

>>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
>>> results = processor.post_process_image_guided_detection(
@@ -1583,19 +1568,19 @@ def image_guided_detection(
>>> for box, score in zip(boxes, scores):
... box = [round(i, 2) for i in box.tolist()]
... print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
Detected similar object with confidence 0.938 at location [490.96, 109.89, 821.09, 536.11]
Detected similar object with confidence 0.959 at location [8.67, 721.29, 928.68, 732.78]
Detected similar object with confidence 0.902 at location [4.27, 720.02, 941.45, 761.59]
Detected similar object with confidence 0.985 at location [265.46, -58.9, 1009.04, 365.66]
Detected similar object with confidence 1.0 at location [9.79, 28.69, 937.31, 941.64]
Detected similar object with confidence 0.998 at location [869.97, 58.28, 923.23, 978.1]
Detected similar object with confidence 0.985 at location [309.23, 21.07, 371.61, 932.02]
Detected similar object with confidence 0.947 at location [27.93, 859.45, 969.75, 915.44]
Detected similar object with confidence 0.996 at location [785.82, 41.38, 880.26, 966.37]
Detected similar object with confidence 0.998 at location [5.08, 721.17, 925.93, 998.41]
Detected similar object with confidence 0.969 at location [6.7, 898.1, 921.75, 949.51]
Detected similar object with confidence 0.966 at location [47.16, 927.29, 981.99, 942.14]
Detected similar object with confidence 0.924 at location [46.4, 936.13, 953.02, 950.78]
Detected similar object with confidence 0.938 at location [327.31, 54.94, 547.39, 268.06]
Detected similar object with confidence 0.959 at location [5.78, 360.65, 619.12, 366.39]
Detected similar object with confidence 0.902 at location [2.85, 360.01, 627.63, 380.8]
Detected similar object with confidence 0.985 at location [176.98, -29.45, 672.69, 182.83]
Detected similar object with confidence 1.0 at location [6.53, 14.35, 624.87, 470.82]
Detected similar object with confidence 0.998 at location [579.98, 29.14, 615.49, 489.05]
Detected similar object with confidence 0.985 at location [206.15, 10.53, 247.74, 466.01]
Detected similar object with confidence 0.947 at location [18.62, 429.72, 646.5, 457.72]
Detected similar object with confidence 0.996 at location [523.88, 20.69, 586.84, 483.18]
Detected similar object with confidence 0.998 at location [3.39, 360.59, 617.29, 499.21]
Detected similar object with confidence 0.969 at location [4.47, 449.05, 614.5, 474.76]
Detected similar object with confidence 0.966 at location [31.44, 463.65, 654.66, 471.07]
Detected similar object with confidence 0.924 at location [30.93, 468.07, 635.35, 475.39]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1667,10 +1652,8 @@ def forward(
```python
>>> import requests
>>> from PIL import Image
>>> import numpy as np
>>> import torch
>>> from transformers import AutoProcessor, Owlv2ForObjectDetection
>>> from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

>>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
>>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
@@ -1684,20 +1667,7 @@ def forward(
>>> with torch.no_grad():
... outputs = model(**inputs)

>>> # Note: boxes need to be visualized on the padded, unnormalized image
>>> # hence we'll set the target image sizes (height, width) based on that

>>> def get_preprocessed_image(pixel_values):
... pixel_values = pixel_values.squeeze().numpy()
... unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
... unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
... unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
... unnormalized_image = Image.fromarray(unnormalized_image)
... return unnormalized_image

>>> unnormalized_image = get_preprocessed_image(inputs.pixel_values)

>>> target_sizes = torch.Tensor([unnormalized_image.size[::-1]])
>>> target_sizes = torch.Tensor([image.size[::-1]])
>>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
>>> results = processor.post_process_object_detection(
... outputs=outputs, threshold=0.2, target_sizes=target_sizes
@@ -1710,8 +1680,8 @@ def forward(
>>> for box, score, label in zip(boxes, scores, labels):
... box = [round(i, 2) for i in box.tolist()]
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
Detected a photo of a cat with confidence 0.614 at location [512.5, 35.08, 963.48, 557.02]
Detected a photo of a cat with confidence 0.665 at location [10.13, 77.94, 489.93, 709.69]
Detected a photo of a cat with confidence 0.614 at location [341.67, 23.39, 642.32, 371.35]
Detected a photo of a cat with confidence 0.665 at location [6.75, 51.96, 326.62, 473.13]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
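The deleted `get_preprocessed_image` helper is no longer needed for post-processing, because boxes are now reported directly in original-image coordinates. If you still want to inspect the padded, unnormalized tensor the model actually consumes (for debugging, say), a standalone version of the removed snippet is shown below; it is adapted from the old docstring and is not part of this PR.

```python
import numpy as np
from PIL import Image
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

def get_preprocessed_image(pixel_values):
    # Undo the CLIP normalization and turn the (C, H, W) tensor back into a PIL image.
    pixel_values = pixel_values.squeeze().numpy()
    unnormalized = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
    unnormalized = (unnormalized * 255).astype(np.uint8)
    unnormalized = np.moveaxis(unnormalized, 0, -1)  # (C, H, W) -> (H, W, C)
    return Image.fromarray(unnormalized)
```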
25 changes: 23 additions & 2 deletions tests/models/owlv2/test_image_processor_owlv2.py
@@ -17,15 +17,18 @@
import unittest

from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_vision_available
from transformers.utils import is_torch_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_vision_available():
from PIL import Image

from transformers import Owlv2ImageProcessor
from transformers import AutoProcessor, Owlv2ForObjectDetection, Owlv2ImageProcessor

if is_torch_available():
import torch


class Owlv2ImageProcessingTester(unittest.TestCase):
@@ -120,6 +123,24 @@ def test_image_processor_integration_test(self):
mean_value = round(pixel_values.mean().item(), 4)
self.assertEqual(mean_value, 0.2353)

@slow
def test_image_processor_integration_test_resize(self):
checkpoint = "google/owlv2-base-patch16-ensemble"
processor = AutoProcessor.from_pretrained(checkpoint)
model = Owlv2ForObjectDetection.from_pretrained(checkpoint)

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = processor(text=["cat"], images=image, return_tensors="pt")

with torch.no_grad():
outputs = model(**inputs)
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, threshold=0.2, target_sizes=target_sizes)[0]

boxes = results["boxes"].tolist()
self.assertEqual(boxes[0], [341.66656494140625, 23.38756561279297, 642.321044921875, 371.3482971191406])
self.assertEqual(boxes[1], [6.753320693969727, 51.96149826049805, 326.61810302734375, 473.12982177734375])

@unittest.skip("OWLv2 doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
def test_call_numpy_4_channels(self):
pass
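As a quick manual check of the new behaviour outside the test suite, the returned boxes should now track the original image's width and height rather than the padded square. A rough sketch, reusing the same checkpoint and fixture image as the test above:

```python
import torch
from PIL import Image
from transformers import AutoProcessor, Owlv2ForObjectDetection

checkpoint = "google/owlv2-base-patch16-ensemble"
processor = AutoProcessor.from_pretrained(checkpoint)
model = Owlv2ForObjectDetection.from_pretrained(checkpoint)

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = processor(text=["cat"], images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
results = processor.post_process_object_detection(
    outputs, threshold=0.2, target_sizes=target_sizes
)[0]

width, height = image.size
print(f"image size (w, h): ({width}, {height})")
for box, score in zip(results["boxes"].tolist(), results["scores"].tolist()):
    # With the fix, coordinates sit on the original image scale (roughly within
    # 0..width and 0..height, modulo small overshoots from the model itself)
    # instead of the padded-square scale used before.
    print([round(v, 2) for v in box], round(score, 3))
```

The new test itself is marked `@slow`, so it only runs when slow tests are enabled (e.g. with `RUN_SLOW=1` in the usual transformers test setup).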