Commit 62b7772

little refactor and add CI test
Signed-off-by: Travis Johnson <[email protected]>
1 parent f9582a1

2 files changed: +54 additions, -18 deletions


tests/models/encoder_decoder/vision_language/test_mllama.py

Lines changed: 23 additions & 0 deletions
@@ -691,3 +691,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
                 f"full_text_row_masked_out_mask[{idx}] must be " \
                 f"'{must_be_masked}' "
             idx += 1
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [
+    ([6404], [[4]], [6404]),
+    ([0, 6404], [[4]], [6404]),
+    ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
+    ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
+])
+def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
+                                         expected) -> None:
+
+    dummy = DummyModel()
+    num_tokens_per_tile = 1601
+    actual_encoder_seq_lens = MllamaForConditionalGeneration \
+        ._get_and_validate_encoder_lens(
+            dummy,
+            encoder_seq_lens,
+            num_tiles,
+            num_tokens_per_tile,
+        )
+    assert actual_encoder_seq_lens == expected, \
+        f"Expected {expected} but got {actual_encoder_seq_lens}"

vllm/model_executor/models/mllama.py

Lines changed: 31 additions & 18 deletions
@@ -1309,6 +1309,31 @@ def _parse_and_validate_image_input(self, **kwargs: object):
 
         raise AssertionError("This line should be unreachable.")
 
+    def _get_and_validate_encoder_lens(
+        self,
+        encoder_seq_lens: List[int],
+        num_tiles: List[List[int]],
+        num_tokens_per_tile: int,
+    ) -> List[int]:
+        # Get the actual number of encoder tokens for each sample.
+        # Because attn_metadata.encoder_seq_lens only counts the last
+        # group of images for each sample, which is used to cheat the
+        # block manager to allocate blocks for those images only.
+        # See input_processor_for_mllama() for more details.
+        actual_encoder_seq_lens = [
+            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+        ]
+
+        # remove 0 encoder len entries for text-only requests for these
+        # assertions
+        attn_metadata_lens = [len for len in encoder_seq_lens if len > 0]
+        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
+        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
+                                              attn_metadata_lens):
+            assert actual_len >= last_group_len
+
+        return actual_encoder_seq_lens
+
     def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                             attn_metadata: AttentionMetadata,
                             actual_encoder_seq_lens: List[int]):
@@ -1436,26 +1461,14 @@ def forward(
         else:
             skip_cross_attention = False
 
-            # Get the actual number of encoder tokens for each sample.
-            # Because attn_metadata.encoder_seq_lens only counts the last
-            # group of images for each sample, which is used to cheat the
-            # block manager to allocate blocks for those images only.
-            # See MllamaMultiModalProcessor for more details.
-            num_tiles_tensor = kwargs.pop("num_tiles")
-            num_tiles = [t.tolist() for t in num_tiles_tensor]
+            num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
             num_tokens_per_tile = calc_token_per_chunk(self.image_size)
-            actual_encoder_seq_lens = [
-                sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
-            ]
 
-            # remove 0 entries for text-only requests for these assertions
-            attn_metadata_lens = [
-                len for len in attn_metadata.encoder_seq_lens if len > 0
-            ]
-            assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
-            for actual_len, last_group_len in zip(actual_encoder_seq_lens,
-                                                  attn_metadata_lens):
-                assert actual_len >= last_group_len
+            actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
+                attn_metadata.encoder_seq_lens,
+                num_tiles,
+                num_tokens_per_tile,
+            )
 
             cross_attention_states = self.get_cross_attention_states(
                 image_inputs, attn_metadata, actual_encoder_seq_lens)
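To see why the `actual_len >= last_group_len` assertion in the extracted helper is safe: per the code comment, attn_metadata.encoder_seq_lens may record only the last image group of a sample, which can never exceed the sample's full tile total. A tiny worked example under that assumption (values are hypothetical, patterned after the test cases above):

# Hypothetical sample with two image groups of 4 and 1 tiles.
tiles_per_group = [4, 1]
NUM_TOKENS_PER_TILE = 1601  # matches the constant used in the test

# Full encoder length: every tile of every image group.
actual_len = sum(tiles_per_group) * NUM_TOKENS_PER_TILE     # 8005

# attn_metadata may account only for the last group (the "cheat the
# block manager" trick described in the code comment).
last_group_len = tiles_per_group[-1] * NUM_TOKENS_PER_TILE  # 1601

# A subset of a sample's tiles can never exceed the full set, so the
# helper's assertion holds.
assert actual_len >= last_group_len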
