Skip to content

Commit 4c5d8e8

Browse files
authored
[Bugfix] Fix phi3v batch inference when images have different aspect ratio (#7392)
1 parent baa2402 commit 4c5d8e8

File tree

4 files changed

+25
-19
lines changed

4 files changed

+25
-19
lines changed

tests/models/test_phi3v.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,10 @@ def run_test(
8181

8282
inputs_per_image = [(
8383
[prompt for _ in size_factors],
84-
[rescale_image_size(image, factor) for factor in size_factors],
84+
[
85+
rescale_image_size(image, factor, transpose=idx)
86+
for idx, factor in enumerate(size_factors)
87+
],
8588
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
8689

8790
# NOTE: take care of the order. run vLLM first, and then run HF.

tests/tracing/test_tracing.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -114,5 +114,5 @@ def test_traces(trace_service):
114114
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
115115
e2e_time = metrics.finished_time - metrics.arrival_time
116116
assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
117-
assert attributes.get(SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER
118-
) == metrics.scheduler_time
117+
assert attributes.get(
118+
SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER) == metrics.scheduler_time

vllm/model_executor/models/phi3v.py

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -189,7 +189,7 @@ def hd_feature_transform(self, image_features, image_sizes):
189189
global_image_features_hd_newline = self.add_image_newline(
190190
global_image_features_hd)
191191

192-
all_image_embeddings = []
192+
batch_image_features_proj = []
193193
# need a for loop to process each image because of different image sizes
194194
# (patch arrangement is different for each image)
195195
for i, img_size in enumerate(image_sizes):
@@ -207,19 +207,17 @@ def hd_feature_transform(self, image_features, image_sizes):
207207
sub_image_features_hd)
208208

209209
# [sub features, separator, global features]
210-
all_image_embeddings.append(
211-
torch.cat([
212-
sub_image_features_hd_newline.squeeze(
213-
0), # (h_crop*12*(w_crop*12+1), 4096)
214-
self.glb_GN.squeeze(0),
215-
global_image_features_hd_newline[i],
216-
]))
217-
218-
image_features_proj = self.img_projection(
219-
torch.stack(all_image_embeddings).to(target_device, target_dtype)
220-
) # (num_images, (h_crop*12*(w_crop*12+1)+1), hidden_size)
221-
222-
return image_features_proj
210+
image_embeddings = torch.cat([
211+
sub_image_features_hd_newline.squeeze(
212+
0), # (h_crop*12*(w_crop*12+1), 4096)
213+
self.glb_GN.squeeze(0),
214+
global_image_features_hd_newline[i],
215+
])
216+
img_proj = self.img_projection(
217+
image_embeddings.to(target_device, target_dtype))
218+
batch_image_features_proj.append(img_proj)
219+
220+
return batch_image_features_proj
223221

224222
def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
225223
"""

vllm/multimodal/utils.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -90,8 +90,13 @@ def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
9090
return _load_image_from_bytes(base64.b64decode(image))
9191

9292

93-
def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
93+
def rescale_image_size(image: Image.Image,
94+
size_factor: float,
95+
transpose: int = -1) -> Image.Image:
9496
"""Rescale the dimensions of an image by a constant factor."""
9597
new_width = int(image.width * size_factor)
9698
new_height = int(image.height * size_factor)
97-
return image.resize((new_width, new_height))
99+
image = image.resize((new_width, new_height))
100+
if transpose >= 0:
101+
image = image.transpose(Image.Transpose(transpose))
102+
return image

0 commit comments

Comments
 (0)