Skip to content

Commit 036ca94

Browse files
[Bugfix] handle alignment of arguments in convert_sparse_cross_attention_mask_to_dense (#12347)
Signed-off-by: Travis Johnson <[email protected]> Signed-off-by: Wallas Santos <[email protected]> Co-authored-by: Wallas Santos <[email protected]>
1 parent ef001d9 commit 036ca94

File tree

2 files changed

+222
-4
lines changed

2 files changed

+222
-4
lines changed

tests/models/encoder_decoder/vision_language/test_mllama.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from typing import List, Optional, Tuple, Type, overload
22

33
import pytest
4+
import torch
45
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
56
BatchEncoding)
67

8+
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
79
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
810
global_force_attn_backend_context_manager)
11+
from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID,
12+
MllamaForConditionalGeneration)
913
from vllm.multimodal.image import rescale_image_size
1014
from vllm.sequence import SampleLogprobs
1115

@@ -33,6 +37,29 @@
3337
"meta-llama/Llama-3.2-11B-Vision-Instruct",
3438
]
3539

40+
# Indices for inputs
41+
TEXT_ONLY = '0'
42+
IMAGE_AT_BEG = '1'
43+
IMAGE_AT_MIDDLE = '2'
44+
TWO_IMAGES = '3'
45+
46+
# Input tokenized
47+
prompt_data = {
48+
# Tell me a story
49+
TEXT_ONLY: [41551, 757, 264, 3446],
50+
# <|image|> What's the content of this image
51+
IMAGE_AT_BEG:
52+
[MLLAMA_IMAGE_TOKEN_ID, 3639, 596, 279, 2262, 315, 420, 2217, 220],
53+
# Hello <|image|>What' the content of this image
54+
IMAGE_AT_MIDDLE:
55+
[9906, 220, MLLAMA_IMAGE_TOKEN_ID, 3923, 6, 279, 2262, 315, 420, 2217],
56+
#<|image|>Is there a duck in this image?<|image|>What's the animal in this image? # noqa: E501
57+
TWO_IMAGES: [
58+
MLLAMA_IMAGE_TOKEN_ID, 3957, 1070, 264, 37085, 304, 420, 2217, 30,
59+
MLLAMA_IMAGE_TOKEN_ID, 3923, 596, 279, 10065, 304, 420, 2217, 30
60+
]
61+
}
62+
3663

3764
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
3865
Optional[SampleLogprobs]],
@@ -365,3 +392,184 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
365392
num_logprobs=num_logprobs,
366393
tensor_parallel_size=1,
367394
)
395+
396+
397+
@large_gpu_test(min_gb=48)
398+
@pytest.mark.core_model
399+
@pytest.mark.parametrize("model", models)
400+
@pytest.mark.parametrize("dtype", ["bfloat16"])
401+
@pytest.mark.parametrize("max_tokens", [128])
402+
@pytest.mark.parametrize("num_logprobs", [5])
403+
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
404+
def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
405+
num_logprobs, attn_backend: _Backend) -> None:
406+
407+
stop_sign = image_assets[0].pil_image
408+
409+
with global_force_attn_backend_context_manager(attn_backend), vllm_runner(
410+
model,
411+
dtype=dtype,
412+
max_model_len=4096,
413+
max_num_seqs=2,
414+
tensor_parallel_size=1,
415+
enforce_eager=True,
416+
limit_mm_per_prompt={"image":
417+
_LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
418+
419+
# Regression tests for https://github.com/vllm-project/vllm/issues/10648
420+
421+
# Number of image tags is greater than the number of images provided
422+
prompt = "<|begin_of_text|><|image|><|image|> Compare the two images" # noqa: E501
423+
image = stop_sign
424+
with pytest.raises(ValueError):
425+
vllm_model.generate_greedy_logprobs([prompt],
426+
max_tokens,
427+
num_logprobs,
428+
images=[image])
429+
430+
# Batch of a text-only and image request that requires cross-attention
431+
prompts = [
432+
"What is the capital of spain?",
433+
"Text before the image...<|image|>What is in the image?", # noqa: E501
434+
]
435+
images = [
436+
None,
437+
[stop_sign],
438+
]
439+
vllm_model.generate_greedy_logprobs(prompts,
440+
max_tokens,
441+
num_logprobs,
442+
images=images)
443+
444+
# Test the reverse order too for good measure
445+
prompts = [
446+
"<|begin_of_text|>Text before the image...<|image|>What is in the image?", # noqa: E501
447+
"<|begin_of_text|>Hello!",
448+
]
449+
images = [
450+
[stop_sign],
451+
None,
452+
]
453+
vllm_model.generate_greedy_logprobs(prompts,
454+
max_tokens,
455+
num_logprobs,
456+
images=images)
457+
458+
459+
@pytest.mark.core_model
460+
@pytest.mark.parametrize(
461+
"input_indices_and_output",
462+
# inputs, (cross_attention_mask, kv_range_for_decode)
463+
[([TEXT_ONLY], (None, None)), ([IMAGE_AT_BEG], (None, None)),
464+
([TEXT_ONLY, IMAGE_AT_BEG], (None, None)),
465+
([IMAGE_AT_MIDDLE], ((10, 12), [[0, 6]])),
466+
([TEXT_ONLY, IMAGE_AT_MIDDLE], ((14, 12), [[0, 6]])),
467+
([TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE],
468+
((23, 24), [[0, 6], [6, 12]])),
469+
([IMAGE_AT_MIDDLE, TEXT_ONLY], ((14, 12), [[0, 6]])),
470+
([TWO_IMAGES], ((18, 12), [[6, 12]])),
471+
([TEXT_ONLY, TWO_IMAGES], ((22, 12), [[6, 12]]))])
472+
def test_get_cross_attention_mask(input_indices_and_output) -> None:
473+
474+
input_indices, expected_output = input_indices_and_output
475+
476+
sequences = [torch.tensor(prompt_data[i]) for i in input_indices]
477+
num_tiles = [[2, 2] if i != TEXT_ONLY else [] for i in input_indices
478+
if i != TEXT_ONLY]
479+
input = torch.cat(sequences)
480+
481+
seq_lens = [len(s) for s in sequences]
482+
483+
attn_data = FlashAttentionMetadata(
484+
seq_lens=seq_lens,
485+
# Dummy values
486+
enable_kv_scales_calculation=False,
487+
num_prefills=0,
488+
num_prefill_tokens=0,
489+
num_decode_tokens=0,
490+
slot_mapping=0,
491+
multi_modal_placeholder_index_maps=None,
492+
seq_lens_tensor=0,
493+
max_prefill_seq_len=0,
494+
max_decode_seq_len=0,
495+
context_lens_tensor=None,
496+
block_tables=None,
497+
use_cuda_graph=False,
498+
)
499+
500+
dummy: dict[str, str] = {}
501+
502+
cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\
503+
.get_cross_attention_mask(dummy,
504+
input,
505+
attn_data,
506+
num_tiles=num_tiles,
507+
num_tokens_per_tile=3,
508+
dtype=torch.bfloat16)
509+
510+
expected_cross_attention_mask, expected_kv_range_for_decode = \
511+
expected_output
512+
513+
assert kv_range_for_decode == expected_kv_range_for_decode
514+
if expected_cross_attention_mask is not None:
515+
assert cross_attention_mask is not None
516+
assert cross_attention_mask.shape == expected_cross_attention_mask
517+
else:
518+
assert cross_attention_mask is None
519+
520+
521+
@pytest.mark.core_model
522+
@pytest.mark.parametrize(
523+
"input_indices",
524+
[[TEXT_ONLY], [IMAGE_AT_BEG], [TEXT_ONLY, IMAGE_AT_BEG], [IMAGE_AT_MIDDLE],
525+
[TEXT_ONLY, IMAGE_AT_MIDDLE], [TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE],
526+
[IMAGE_AT_MIDDLE, TEXT_ONLY], [TWO_IMAGES], [TEXT_ONLY, TWO_IMAGES]])
527+
def test_get_full_text_row_masked_out_mask(input_indices) -> None:
528+
529+
sequences = [torch.tensor(prompt_data[i]) for i in input_indices]
530+
531+
seq_lens = [len(s) for s in sequences]
532+
533+
num_prefill_tokens = sum(seq_lens)
534+
535+
# TEXT_ONLY is zero, so it will be masked out,
536+
# other instances should not be.
537+
encoder_seq_lens = [int(i) for i in input_indices]
538+
539+
attn_data = FlashAttentionMetadata(
540+
seq_lens=seq_lens,
541+
encoder_seq_lens=encoder_seq_lens,
542+
num_prefill_tokens=num_prefill_tokens,
543+
# Dummy values
544+
enable_kv_scales_calculation=False,
545+
num_prefills=0,
546+
num_decode_tokens=0,
547+
slot_mapping=0,
548+
multi_modal_placeholder_index_maps=None,
549+
seq_lens_tensor=0,
550+
max_prefill_seq_len=0,
551+
max_decode_seq_len=0,
552+
context_lens_tensor=None,
553+
block_tables=None,
554+
use_cuda_graph=False,
555+
)
556+
557+
dummy: dict[str, str] = {}
558+
559+
full_text_row_masked_out_mask = MllamaForConditionalGeneration\
560+
.get_full_text_row_masked_out_mask(dummy,
561+
attn_data,
562+
torch.get_default_device())
563+
564+
full_text_row_masked_out_mask = full_text_row_masked_out_mask.squeeze()
565+
full_text_row_masked_out_mask = full_text_row_masked_out_mask.tolist()
566+
567+
idx = 0
568+
assert len(full_text_row_masked_out_mask) == num_prefill_tokens
569+
for i, seq_len in enumerate(seq_lens):
570+
must_be_masked = input_indices[i] != TEXT_ONLY
571+
for _ in range(seq_len):
572+
assert full_text_row_masked_out_mask[idx] == must_be_masked, \
573+
f"full_text_row_masked_out_mask[{idx}] must be " \
574+
f"'{must_be_masked}' "
575+
idx += 1

vllm/model_executor/models/mllama.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,14 +1485,23 @@ def convert_sparse_cross_attention_mask_to_dense(
14851485
total_length = sum(lengths)
14861486
total_tiles = sum([sum(tiles) for tiles in num_tiles])
14871487
dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
1488-
# A list of ranges, range[i] = [start, end] means
1489-
# if the i-th sample has N tiles in total, the tiles[start, end]
1490-
# will be used for cross-attention decoding.
1488+
# A list of ranges, range[i] = [start, end] means that the i-th image will
1489+
# use tiles[start, end] for cross-attention decoding.
14911490
tile_range_for_decode = []
14921491

14931492
seq_start = 0
14941493
tile_start = 0
1495-
for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
1494+
1495+
# sparse_mask has an [] entry for each sequence that does not have images,
1496+
# but num_tiles does not have these entries...
1497+
num_tiles_idx = 0
1498+
for masks, length in zip(sparse_mask, lengths):
1499+
if len(masks) == 0:
1500+
# Text only
1501+
continue
1502+
1503+
tiles = num_tiles[num_tiles_idx]
1504+
num_tiles_idx += 1
14961505
ts, td = -1, 0
14971506
for mask, tile in zip(masks, tiles):
14981507
if len(mask) != 2:
@@ -1512,6 +1521,7 @@ def convert_sparse_cross_attention_mask_to_dense(
15121521
assert td != 0
15131522
tile_range_for_decode.append((ts, ts + td))
15141523
seq_start += length
1524+
assert num_tiles_idx == len(num_tiles)
15151525

15161526
return dense_mask, tile_range_for_decode
15171527

0 commit comments

Comments
 (0)