@@ -1309,6 +1309,31 @@ def _parse_and_validate_image_input(self, **kwargs: object):
13091309
13101310 raise AssertionError ("This line should be unreachable." )
13111311
1312+ def _get_and_validate_encoder_lens (
1313+ self ,
1314+ encoder_seq_lens : List [int ],
1315+ num_tiles : List [List [int ]],
1316+ num_tokens_per_tile : int ,
1317+ ) -> List [int ]:
1318+ # Get the actual number of encoder tokens for each sample.
1319+ # Because attn_metadata.encoder_seq_lens only counts the last
1320+ # group of images for each sample, which is used to cheat the
1321+ # block manager to allocate blocks for those images only.
1322+ # See input_processor_for_mllama() for more details.
1323+ actual_encoder_seq_lens = [
1324+ sum (num_tile ) * num_tokens_per_tile for num_tile in num_tiles
1325+ ]
1326+
1327+ # remove 0 encoder len entries for text-only requests for these
1328+ # assertions
1329+ attn_metadata_lens = [len for len in encoder_seq_lens if len > 0 ]
1330+ assert len (actual_encoder_seq_lens ) == len (attn_metadata_lens )
1331+ for actual_len , last_group_len in zip (actual_encoder_seq_lens ,
1332+ attn_metadata_lens ):
1333+ assert actual_len >= last_group_len
1334+
1335+ return actual_encoder_seq_lens
1336+
13121337 def flat_encoder_result (self , cross_attention_states : torch .Tensor ,
13131338 attn_metadata : AttentionMetadata ,
13141339 actual_encoder_seq_lens : List [int ]):
@@ -1436,26 +1461,14 @@ def forward(
14361461 else :
14371462 skip_cross_attention = False
14381463
1439- # Get the actual number of encoder tokens for each sample.
1440- # Because attn_metadata.encoder_seq_lens only counts the last
1441- # group of images for each sample, which is used to cheat the
1442- # block manager to allocate blocks for those images only.
1443- # See MllamaMultiModalProcessor for more details.
1444- num_tiles_tensor = kwargs .pop ("num_tiles" )
1445- num_tiles = [t .tolist () for t in num_tiles_tensor ]
1464+ num_tiles = [t .tolist () for t in kwargs .pop ("num_tiles" )]
14461465 num_tokens_per_tile = calc_token_per_chunk (self .image_size )
1447- actual_encoder_seq_lens = [
1448- sum (num_tile ) * num_tokens_per_tile for num_tile in num_tiles
1449- ]
14501466
1451- # remove 0 entries for text-only requests for these assertions
1452- attn_metadata_lens = [
1453- len for len in attn_metadata .encoder_seq_lens if len > 0
1454- ]
1455- assert len (actual_encoder_seq_lens ) == len (attn_metadata_lens )
1456- for actual_len , last_group_len in zip (actual_encoder_seq_lens ,
1457- attn_metadata_lens ):
1458- assert actual_len >= last_group_len
1467+ actual_encoder_seq_lens = self ._get_and_validate_encoder_lens (
1468+ attn_metadata .encoder_seq_lens ,
1469+ num_tiles ,
1470+ num_tokens_per_tile ,
1471+ )
14591472
14601473 cross_attention_states = self .get_cross_attention_states (
14611474 image_inputs , attn_metadata , actual_encoder_seq_lens )
0 commit comments