1 | 1 | from typing import List, Optional, Tuple, Type, overload |
2 | 2 |
3 | 3 | import pytest |
| 4 | +import torch |
4 | 5 | from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, |
5 | 6 | BatchEncoding) |
6 | 7 |
| 8 | +from vllm.attention.backends.flash_attn import FlashAttentionMetadata |
7 | 9 | from vllm.attention.selector import (_Backend, _cached_get_attn_backend, |
8 | 10 | global_force_attn_backend_context_manager) |
| 11 | +from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID, |
| 12 | + MllamaForConditionalGeneration) |
9 | 13 | from vllm.multimodal.image import rescale_image_size |
10 | 14 | from vllm.sequence import SampleLogprobs |
11 | 15 |
33 | 37 | "meta-llama/Llama-3.2-11B-Vision-Instruct", |
34 | 38 | ] |
35 | 39 |
| 40 | +# Indices for the tokenized inputs below. Their numeric values are also reused
| | +# as encoder_seq_lens in test_get_full_text_row_masked_out_mask.
| 41 | +TEXT_ONLY = '0' |
| 42 | +IMAGE_AT_BEG = '1' |
| 43 | +IMAGE_AT_MIDDLE = '2' |
| 44 | +TWO_IMAGES = '3' |
| 45 | + |
| 46 | +# Pre-tokenized prompts, keyed by the indices above
| 47 | +prompt_data = { |
| 48 | + # Tell me a story |
| 49 | + TEXT_ONLY: [41551, 757, 264, 3446], |
| 50 | + # <|image|> What's the content of this image |
| 51 | + IMAGE_AT_BEG: |
| 52 | + [MLLAMA_IMAGE_TOKEN_ID, 3639, 596, 279, 2262, 315, 420, 2217, 220], |
| 53 | + # Hello <|image|>What' the content of this image |
| 54 | + IMAGE_AT_MIDDLE: |
| 55 | + [9906, 220, MLLAMA_IMAGE_TOKEN_ID, 3923, 6, 279, 2262, 315, 420, 2217], |
| 56 | +    # <|image|>Is there a duck in this image?<|image|>What's the animal in this image? # noqa: E501
| 57 | + TWO_IMAGES: [ |
| 58 | + MLLAMA_IMAGE_TOKEN_ID, 3957, 1070, 264, 37085, 304, 420, 2217, 30, |
| 59 | + MLLAMA_IMAGE_TOKEN_ID, 3923, 596, 279, 10065, 304, 420, 2217, 30 |
| 60 | + ] |
| 61 | +} |
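| | +# The token IDs above are hard-coded so these unit tests never have to load a
| | +# tokenizer. A rough, unverified sketch of how they could be regenerated
| | +# (assuming the tokenizer of models[0], with MLLAMA_IMAGE_TOKEN_ID spliced in
| | +# wherever <|image|> appears):
| | +#
| | +#   tokenizer = AutoTokenizer.from_pretrained(models[0])
| | +#   tokenizer.encode("Tell me a story", add_special_tokens=False)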
| 62 | + |
36 | 63 |
37 | 64 | def vllm_to_hf_output(vllm_output: Tuple[List[int], str, |
38 | 65 | Optional[SampleLogprobs]], |
@@ -365,3 +392,184 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, |
365 | 392 | num_logprobs=num_logprobs, |
366 | 393 | tensor_parallel_size=1, |
367 | 394 | ) |
| 395 | + |
| 396 | + |
| 397 | +@large_gpu_test(min_gb=48) |
| 398 | +@pytest.mark.core_model |
| 399 | +@pytest.mark.parametrize("model", models) |
| 400 | +@pytest.mark.parametrize("dtype", ["bfloat16"]) |
| 401 | +@pytest.mark.parametrize("max_tokens", [128]) |
| 402 | +@pytest.mark.parametrize("num_logprobs", [5]) |
| 403 | +@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) |
| 404 | +def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, |
| 405 | + num_logprobs, attn_backend: _Backend) -> None: |
| 406 | + |
| 407 | + stop_sign = image_assets[0].pil_image |
| 408 | + |
| 409 | + with global_force_attn_backend_context_manager(attn_backend), vllm_runner( |
| 410 | + model, |
| 411 | + dtype=dtype, |
| 412 | + max_model_len=4096, |
| 413 | + max_num_seqs=2, |
| 414 | + tensor_parallel_size=1, |
| 415 | + enforce_eager=True, |
| 416 | + limit_mm_per_prompt={"image": |
| 417 | + _LIMIT_IMAGE_PER_PROMPT}) as vllm_model: |
| 418 | + |
| 419 | + # Regression tests for https://github.com/vllm-project/vllm/issues/10648 |
| 420 | + |
| 421 | + # Number of image tags is greater than the number of images provided |
| 422 | + prompt = "<|begin_of_text|><|image|><|image|> Compare the two images" # noqa: E501 |
| 423 | + image = stop_sign |
| 424 | + with pytest.raises(ValueError): |
| 425 | + vllm_model.generate_greedy_logprobs([prompt], |
| 426 | + max_tokens, |
| 427 | + num_logprobs, |
| 428 | + images=[image]) |
| 429 | + |
| 430 | + # Batch of a text-only and image request that requires cross-attention |
| 431 | + prompts = [ |
| 432 | +            "What is the capital of Spain?",
| 433 | + "Text before the image...<|image|>What is in the image?", # noqa: E501 |
| 434 | + ] |
| 435 | + images = [ |
| 436 | + None, |
| 437 | + [stop_sign], |
| 438 | + ] |
| 439 | + vllm_model.generate_greedy_logprobs(prompts, |
| 440 | + max_tokens, |
| 441 | + num_logprobs, |
| 442 | + images=images) |
| 443 | + |
| 444 | + # Test the reverse order too for good measure |
| 445 | + prompts = [ |
| 446 | + "<|begin_of_text|>Text before the image...<|image|>What is in the image?", # noqa: E501 |
| 447 | + "<|begin_of_text|>Hello!", |
| 448 | + ] |
| 449 | + images = [ |
| 450 | + [stop_sign], |
| 451 | + None, |
| 452 | + ] |
| 453 | + vllm_model.generate_greedy_logprobs(prompts, |
| 454 | + max_tokens, |
| 455 | + num_logprobs, |
| 456 | + images=images) |
| 457 | + |
| 458 | + |
| 459 | +@pytest.mark.core_model |
| 460 | +@pytest.mark.parametrize( |
| 461 | + "input_indices_and_output", |
| 462 | +    # input indices, (expected cross_attention_mask shape, kv_range_for_decode)
| 463 | + [([TEXT_ONLY], (None, None)), ([IMAGE_AT_BEG], (None, None)), |
| 464 | + ([TEXT_ONLY, IMAGE_AT_BEG], (None, None)), |
| 465 | + ([IMAGE_AT_MIDDLE], ((10, 12), [[0, 6]])), |
| 466 | + ([TEXT_ONLY, IMAGE_AT_MIDDLE], ((14, 12), [[0, 6]])), |
| 467 | + ([TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE], |
| 468 | + ((23, 24), [[0, 6], [6, 12]])), |
| 469 | + ([IMAGE_AT_MIDDLE, TEXT_ONLY], ((14, 12), [[0, 6]])), |
| 470 | + ([TWO_IMAGES], ((18, 12), [[6, 12]])), |
| 471 | + ([TEXT_ONLY, TWO_IMAGES], ((22, 12), [[6, 12]]))]) |
| 472 | +def test_get_cross_attention_mask(input_indices_and_output) -> None: |
| 473 | + |
| 474 | + input_indices, expected_output = input_indices_and_output |
| 475 | + |
| 476 | + sequences = [torch.tensor(prompt_data[i]) for i in input_indices] |
| 477 | +    # num_tiles entries are only built for sequences that contain images.
| 478 | +    num_tiles = [[2, 2] for i in input_indices if i != TEXT_ONLY]
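| | +    # num_tokens_per_tile=3 is passed below, so each entry accounts for
| | +    # sum([2, 2]) * 3 = 12 encoder tokens; the second dimension of every expected
| | +    # mask shape in the parametrization above is 12 per image-bearing sequence.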
| 479 | +    input_ids = torch.cat(sequences)
| 480 | + |
| 481 | + seq_lens = [len(s) for s in sequences] |
| 482 | + |
| 483 | + attn_data = FlashAttentionMetadata( |
| 484 | + seq_lens=seq_lens, |
| 485 | + # Dummy values |
| 486 | + enable_kv_scales_calculation=False, |
| 487 | + num_prefills=0, |
| 488 | + num_prefill_tokens=0, |
| 489 | + num_decode_tokens=0, |
| 490 | + slot_mapping=0, |
| 491 | + multi_modal_placeholder_index_maps=None, |
| 492 | + seq_lens_tensor=0, |
| 493 | + max_prefill_seq_len=0, |
| 494 | + max_decode_seq_len=0, |
| 495 | + context_lens_tensor=None, |
| 496 | + block_tables=None, |
| 497 | + use_cuda_graph=False, |
| 498 | + ) |
| 499 | + |
| 500 | +    dummy: dict[str, str] = {}  # stand-in for the unused `self` argument
| 501 | + |
| 502 | + cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\ |
| 503 | + .get_cross_attention_mask(dummy, |
| 504 | +                                  input_ids,
| 505 | + attn_data, |
| 506 | + num_tiles=num_tiles, |
| 507 | + num_tokens_per_tile=3, |
| 508 | + dtype=torch.bfloat16) |
| 509 | + |
| 510 | + expected_cross_attention_mask, expected_kv_range_for_decode = \ |
| 511 | + expected_output |
| 512 | + |
| 513 | + assert kv_range_for_decode == expected_kv_range_for_decode |
| 514 | + if expected_cross_attention_mask is not None: |
| 515 | + assert cross_attention_mask is not None |
| 516 | + assert cross_attention_mask.shape == expected_cross_attention_mask |
| 517 | + else: |
| 518 | + assert cross_attention_mask is None |
| 519 | + |
| 520 | + |
| 521 | +@pytest.mark.core_model |
| 522 | +@pytest.mark.parametrize( |
| 523 | + "input_indices", |
| 524 | + [[TEXT_ONLY], [IMAGE_AT_BEG], [TEXT_ONLY, IMAGE_AT_BEG], [IMAGE_AT_MIDDLE], |
| 525 | + [TEXT_ONLY, IMAGE_AT_MIDDLE], [TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE], |
| 526 | + [IMAGE_AT_MIDDLE, TEXT_ONLY], [TWO_IMAGES], [TEXT_ONLY, TWO_IMAGES]]) |
| 527 | +def test_get_full_text_row_masked_out_mask(input_indices) -> None: |
| 528 | + |
| 529 | + sequences = [torch.tensor(prompt_data[i]) for i in input_indices] |
| 530 | + |
| 531 | + seq_lens = [len(s) for s in sequences] |
| 532 | + |
| 533 | + num_prefill_tokens = sum(seq_lens) |
| 534 | + |
| 535 | +    # The numeric index values double as encoder sequence lengths: TEXT_ONLY is
| 536 | +    # 0, so its rows get masked out; the image-bearing inputs should not be.
| 537 | +    encoder_seq_lens = [int(i) for i in input_indices]
| 538 | + |
| 539 | + attn_data = FlashAttentionMetadata( |
| 540 | + seq_lens=seq_lens, |
| 541 | + encoder_seq_lens=encoder_seq_lens, |
| 542 | + num_prefill_tokens=num_prefill_tokens, |
| 543 | + # Dummy values |
| 544 | + enable_kv_scales_calculation=False, |
| 545 | + num_prefills=0, |
| 546 | + num_decode_tokens=0, |
| 547 | + slot_mapping=0, |
| 548 | + multi_modal_placeholder_index_maps=None, |
| 549 | + seq_lens_tensor=0, |
| 550 | + max_prefill_seq_len=0, |
| 551 | + max_decode_seq_len=0, |
| 552 | + context_lens_tensor=None, |
| 553 | + block_tables=None, |
| 554 | + use_cuda_graph=False, |
| 555 | + ) |
| 556 | + |
| 557 | +    dummy: dict[str, str] = {}  # stand-in for the unused `self` argument
| 558 | + |
| 559 | + full_text_row_masked_out_mask = MllamaForConditionalGeneration\ |
| 560 | + .get_full_text_row_masked_out_mask(dummy, |
| 561 | + attn_data, |
| 562 | + torch.get_default_device()) |
| 563 | + |
| 564 | + full_text_row_masked_out_mask = full_text_row_masked_out_mask.squeeze() |
| 565 | + full_text_row_masked_out_mask = full_text_row_masked_out_mask.tolist() |
| 566 | + |
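| | +    # Expected: one mask value per prefill token; rows of sequences that have
| | +    # encoder tokens are 1 (True), rows of text-only sequences are 0 (False).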
| 567 | + idx = 0 |
| 568 | + assert len(full_text_row_masked_out_mask) == num_prefill_tokens |
| 569 | + for i, seq_len in enumerate(seq_lens): |
| 570 | +        expected_mask = input_indices[i] != TEXT_ONLY
| 571 | +        for _ in range(seq_len):
| 572 | +            assert full_text_row_masked_out_mask[idx] == expected_mask, \
| 573 | +                f"full_text_row_masked_out_mask[{idx}] must be " \
| 574 | +                f"'{expected_mask}'"
| 575 | + idx += 1 |