Commit c832bcb

Fix owlv2 code snippet (#27698)

* Fix code snippet
* Improve code snippet

1 parent 334a6d1

1 file changed (+56, −20)

src/transformers/models/owlv2/modeling_owlv2.py

@@ -1544,19 +1544,38 @@ def image_guided_detection(
 >>> import requests
 >>> from PIL import Image
 >>> import torch
+>>> import numpy as np
 >>> from transformers import AutoProcessor, Owlv2ForObjectDetection
+>>> from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

 >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
 >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
+
 >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
 >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg"
 >>> query_image = Image.open(requests.get(query_url, stream=True).raw)
 >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt")
+
+>>> # forward pass
 >>> with torch.no_grad():
 ...     outputs = model.image_guided_detection(**inputs)
->>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
->>> target_sizes = torch.Tensor([image.size[::-1]])
+
+>>> # Note: boxes need to be visualized on the padded, unnormalized image
+>>> # hence we'll set the target image sizes (height, width) based on that
+
+>>> def get_preprocessed_image(pixel_values):
+...     pixel_values = pixel_values.squeeze().numpy()
+...     unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
+...     unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
+...     unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
+...     unnormalized_image = Image.fromarray(unnormalized_image)
+...     return unnormalized_image
+
+>>> unnormalized_image = get_preprocessed_image(inputs.pixel_values)
+
+>>> target_sizes = torch.Tensor([unnormalized_image.size[::-1]])
+
 >>> # Convert outputs (bounding boxes and class logits) to COCO API
 >>> results = processor.post_process_image_guided_detection(
 ...     outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes
@@ -1566,19 +1585,19 @@ def image_guided_detection(
 >>> for box, score in zip(boxes, scores):
 ...     box = [round(i, 2) for i in box.tolist()]
 ...     print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
-Detected similar object with confidence 0.938 at location [327.31, 54.94, 547.39, 268.06]
-Detected similar object with confidence 0.959 at location [5.78, 360.65, 619.12, 366.39]
-Detected similar object with confidence 0.902 at location [2.85, 360.01, 627.63, 380.79]
-Detected similar object with confidence 0.985 at location [176.97, -29.45, 672.69, 182.83]
-Detected similar object with confidence 1.0 at location [6.53, 14.35, 624.87, 470.82]
-Detected similar object with confidence 0.998 at location [579.98, 29.14, 615.49, 489.05]
-Detected similar object with confidence 0.985 at location [206.15, 10.53, 247.74, 466.01]
-Detected similar object with confidence 0.947 at location [18.62, 429.72, 646.5, 457.72]
-Detected similar object with confidence 0.996 at location [523.88, 20.69, 586.84, 483.18]
-Detected similar object with confidence 0.998 at location [3.39, 360.59, 617.29, 499.21]
-Detected similar object with confidence 0.969 at location [4.47, 449.05, 614.5, 474.76]
-Detected similar object with confidence 0.966 at location [31.44, 463.65, 654.66, 471.07]
-Detected similar object with confidence 0.924 at location [30.93, 468.07, 635.35, 475.39]
+Detected similar object with confidence 0.938 at location [490.96, 109.89, 821.09, 536.11]
+Detected similar object with confidence 0.959 at location [8.67, 721.29, 928.68, 732.78]
+Detected similar object with confidence 0.902 at location [4.27, 720.02, 941.45, 761.59]
+Detected similar object with confidence 0.985 at location [265.46, -58.9, 1009.04, 365.66]
+Detected similar object with confidence 1.0 at location [9.79, 28.69, 937.31, 941.64]
+Detected similar object with confidence 0.998 at location [869.97, 58.28, 923.23, 978.1]
+Detected similar object with confidence 0.985 at location [309.23, 21.07, 371.61, 932.02]
+Detected similar object with confidence 0.947 at location [27.93, 859.45, 969.75, 915.44]
+Detected similar object with confidence 0.996 at location [785.82, 41.38, 880.26, 966.37]
+Detected similar object with confidence 0.998 at location [5.08, 721.17, 925.93, 998.41]
+Detected similar object with confidence 0.969 at location [6.7, 898.1, 921.75, 949.51]
+Detected similar object with confidence 0.966 at location [47.16, 927.29, 981.99, 942.14]
+Detected similar object with confidence 0.924 at location [46.4, 936.13, 953.02, 950.78]
 ```"""
 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
@@ -1650,8 +1669,10 @@ def forward(
 ```python
 >>> import requests
 >>> from PIL import Image
+>>> import numpy as np
 >>> import torch
 >>> from transformers import AutoProcessor, Owlv2ForObjectDetection
+>>> from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

 >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
 >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
@@ -1660,10 +1681,25 @@ def forward(
 >>> image = Image.open(requests.get(url, stream=True).raw)
 >>> texts = [["a photo of a cat", "a photo of a dog"]]
 >>> inputs = processor(text=texts, images=image, return_tensors="pt")
->>> outputs = model(**inputs)

->>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
->>> target_sizes = torch.Tensor([image.size[::-1]])
+>>> # forward pass
+>>> with torch.no_grad():
+...     outputs = model(**inputs)
+
+>>> # Note: boxes need to be visualized on the padded, unnormalized image
+>>> # hence we'll set the target image sizes (height, width) based on that
+
+>>> def get_preprocessed_image(pixel_values):
+...     pixel_values = pixel_values.squeeze().numpy()
+...     unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
+...     unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
+...     unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
+...     unnormalized_image = Image.fromarray(unnormalized_image)
+...     return unnormalized_image
+
+>>> unnormalized_image = get_preprocessed_image(inputs.pixel_values)
+
+>>> target_sizes = torch.Tensor([unnormalized_image.size[::-1]])
 >>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
 >>> results = processor.post_process_object_detection(
 ...     outputs=outputs, threshold=0.2, target_sizes=target_sizes
@@ -1676,8 +1712,8 @@ def forward(
 >>> for box, score, label in zip(boxes, scores, labels):
 ...     box = [round(i, 2) for i in box.tolist()]
 ...     print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
-Detected a photo of a cat with confidence 0.614 at location [341.67, 17.54, 642.32, 278.51]
-Detected a photo of a cat with confidence 0.665 at location [6.75, 38.97, 326.62, 354.85]
+Detected a photo of a cat with confidence 0.614 at location [512.5, 35.08, 963.48, 557.02]
+Detected a photo of a cat with confidence 0.665 at location [10.13, 77.94, 489.93, 709.69]
 ```"""
 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
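
For context: the fix works because OWLv2's processor pads the input image to a square and resizes it before the model runs, so the predicted boxes live in that padded frame, not in the original image's frame. The sketch below is a minimal standalone version of the corrected text-conditioned snippet from this diff, extended with an illustrative Pillow drawing step; the `ImageDraw` usage and the `owlv2_detections.png` filename are additions for illustration, not part of this commit.

```python
# Minimal, self-contained version of the corrected snippet.
# The ImageDraw step and output filename are illustrative additions.
import numpy as np
import requests
import torch
from PIL import Image, ImageDraw

from transformers import AutoProcessor, Owlv2ForObjectDetection
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(text=texts, images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)


def get_preprocessed_image(pixel_values):
    # Undo the CLIP mean/std normalization so the array matches the padded
    # image the model actually saw, up to uint8 rounding.
    pixel_values = pixel_values.squeeze().numpy()
    unnormalized = pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]
    unnormalized = unnormalized + np.array(OPENAI_CLIP_MEAN)[:, None, None]
    unnormalized = (unnormalized * 255).astype(np.uint8)
    return Image.fromarray(np.moveaxis(unnormalized, 0, -1))


unnormalized_image = get_preprocessed_image(inputs.pixel_values)

# Boxes are predicted in the padded image's coordinate frame, so target_sizes
# must be taken from the unnormalized image, not from the original one.
target_sizes = torch.Tensor([unnormalized_image.size[::-1]])
results = processor.post_process_object_detection(
    outputs=outputs, threshold=0.2, target_sizes=target_sizes
)

# Draw the surviving boxes on the padded image for a quick visual check.
draw = ImageDraw.Draw(unnormalized_image)
for box, score, label in zip(results[0]["boxes"], results[0]["scores"], results[0]["labels"]):
    draw.rectangle(box.tolist(), outline="red", width=3)
    print(f"{texts[0][label]}: {score.item():.3f} at {[round(v, 2) for v in box.tolist()]}")
unnormalized_image.save("owlv2_detections.png")
```

The same `get_preprocessed_image` helper applies unchanged to the image-guided detection snippet in the first hunk.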
