-from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
-                    Union)
+from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
+                    TypedDict, Union)

 import torch
 import torch.nn as nn
@@ -9,7 +9,6 @@
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -87,8 +86,8 @@ def __init__(
     def forward(
         self,
         pixel_values: torch.Tensor,
-        pixel_mask: Optional[torch.BoolTensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]:
+        pixel_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         patch_attention_mask = self._create_patch_attention_mask(pixel_mask)

         vit_oup = self.vision_model(
@@ -100,7 +99,8 @@ def forward(

         return vit_oup, image_atts

-    def _create_patch_attention_mask(self, pixel_mask):
+    def _create_patch_attention_mask(
+            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
         if pixel_mask is None:
             return None

@@ -115,7 +115,8 @@ def _create_patch_attention_mask(self, pixel_mask):
         )
         return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

-    def _create_image_attention_mask(self, patch_attention_mask):
+    def _create_image_attention_mask(
+            self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
         if patch_attention_mask is None:
             return None

@@ -125,13 +126,13 @@ def _create_image_attention_mask(self, patch_attention_mask):

 class FFN(nn.Module):

-    def __init__(self, embed_dim, ff_dim, output_dim):
+    def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
         super().__init__()
         self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
         self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
         self.act = get_act_fn("gelu_new")

-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.linear_in(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states, _ = self.linear_out(hidden_states)
@@ -140,7 +141,7 @@ def forward(self, hidden_states):

 class CrossAttention(nn.Module):

-    def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
+    def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
@@ -149,12 +150,16 @@ def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):

         self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
         self.linear = nn.Linear(embed_dim, embed_dim)
-        self.dropout = nn.Dropout(drop_out_rate)

         self.layer_norm = nn.LayerNorm(embed_dim)
         self.ln_kv = nn.LayerNorm(kv_dim)

-    def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
+    def forward(
+        self,
+        x: torch.Tensor,
+        hidden_states: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         normed_hidden_states = self.layer_norm(hidden_states)
         query = self.q_proj(normed_hidden_states).permute(1, 0, 2)

@@ -169,11 +174,7 @@ def forward(self, x, hidden_states, attn_mask=None, add_residual=False):

         attn_output = attn_output.permute(1, 0, 2)

-        if add_residual:
-            attn_output = hidden_states + self.dropout(
-                self.linear(attn_output))
-        else:
-            attn_output = self.dropout(self.linear(attn_output))
+        attn_output = self.linear(attn_output)

         return attn_output

@@ -201,14 +202,14 @@ class AriaProjector(nn.Module):

     def __init__(
         self,
-        patch_to_query_dict,
-        embed_dim,
-        num_heads,
-        kv_dim,
-        ff_dim,
-        output_dim,
-        norm_layer=nn.LayerNorm,
-    ):
+        patch_to_query_dict: dict[int, int],
+        embed_dim: int,
+        num_heads: int,
+        kv_dim: int,
+        ff_dim: int,
+        output_dim: int,
+        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
+    ) -> None:
         super().__init__()
         self.patch_to_query_dict = patch_to_query_dict
         self.embed_dim = embed_dim
@@ -224,7 +225,11 @@ def __init__(
         self.ln_ffn = norm_layer(embed_dim)
         self.ffn = FFN(embed_dim, ff_dim, output_dim)

-    def forward(self, x, attn_mask=None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         bs = x.shape[0]
         queries = self.query.unsqueeze(0).repeat(bs, 1, 1)

@@ -442,12 +447,17 @@ def build_mm_projector(config: PretrainedConfig):
     )


-def get_max_aria_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config()
-    return max(hf_config.projector_patch_to_query_dict.values())
+class AriaMultiModalProcessor(BaseMultiModalProcessor):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}

+    def _get_num_image_tokens(self) -> int:
+        hf_config = self.ctx.get_hf_config()
+        return max(hf_config.projector_patch_to_query_dict.values())

-class AriaMultiModalProcessor(BaseMultiModalProcessor):
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        return {"image": self._get_num_image_tokens()}

     def _get_mm_fields_config(
         self,
@@ -468,13 +478,13 @@ def _get_prompt_replacements(
         hf_config = self.ctx.get_hf_config()
         image_token_id = hf_config.image_token_index

-        max_image_tokens = get_max_aria_image_tokens(self.ctx)
+        num_image_tokens = self._get_num_image_tokens()

         return [
             PromptReplacement(
                 modality="image",
                 target=[image_token_id],
-                replacement=[image_token_id] * max_image_tokens,
+                replacement=[image_token_id] * num_image_tokens,
             )
         ]

@@ -504,7 +514,6 @@ def _get_dummy_mm_inputs(
         )


-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
 class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     """
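For orientation only, a minimal self-contained sketch of the processor-side pattern this diff adopts: the per-image token budget is computed by methods on the multimodal processor (`get_supported_mm_limits`, `_get_num_image_tokens`, `get_mm_max_tokens_per_item`) instead of the removed `get_max_aria_image_tokens` registry hook. `StubHfConfig`, `StubContext`, and the example dictionary values are hypothetical stand-ins rather than vLLM APIs; only the method bodies mirror the diff.

```python
from dataclasses import dataclass, field
from typing import Dict, Mapping, Optional


@dataclass
class StubHfConfig:
    # Hypothetical stand-in for the HF Aria config; only the
    # projector_patch_to_query_dict attribute mirrors the real one.
    projector_patch_to_query_dict: Dict[int, int] = field(
        default_factory=lambda: {1225: 128, 4900: 256})


@dataclass
class StubContext:
    # Hypothetical stand-in for the processor's ctx object.
    hf_config: StubHfConfig

    def get_hf_config(self) -> StubHfConfig:
        return self.hf_config


class AriaProcessorSketch:
    """Simplified mirror of the methods added to AriaMultiModalProcessor."""

    def __init__(self, ctx: StubContext) -> None:
        self.ctx = ctx

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # None means no fixed cap on the number of images per prompt.
        return {"image": None}

    def _get_num_image_tokens(self) -> int:
        # Largest query count across the projector's patch-to-query mapping.
        hf_config = self.ctx.get_hf_config()
        return max(hf_config.projector_patch_to_query_dict.values())

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        return {"image": self._get_num_image_tokens()}


processor = AriaProcessorSketch(StubContext(StubHfConfig()))
assert processor.get_mm_max_tokens_per_item() == {"image": 256}
```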