3939 merge_multimodal_embeddings ,
4040 merge_multimodal_embeddings_from_map )
4141
42- _AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
43- _AUDIO_PLACEHOLDER_TOKEN = 128002
44- _AUDIO_TOKENS_PER_SECOND = 6.25
42+ _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
4543_MAX_ENCODER_BATCH_SIZE = 16
4644
4745
@@ -80,14 +78,15 @@ def get_hf_processor(
8078 sampling_rate : Optional [int ] = None ,
8179 ** kwargs : object ,
8280 ) -> ProcessorMixin :
81+ config = self .ctx .model_config .hf_config
8382 hf_processor = self .ctx .get_hf_processor (** kwargs )
8483
8584 # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
8685 # placeholder that will cause confusion with the actual end of turn
87- # token, thus we override placeholder with a reserved special
88- # token.
86+ # token, thus we override placeholder with a reserved token.
8987 hf_processor .audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
90- hf_processor .audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN
88+ hf_processor .audio_replacement_token_id = config .audio_token_index
89+
9190 return hf_processor
9291
9392 def get_feature_extractor (
@@ -274,7 +273,7 @@ def __init__(self, config: UltravoxConfig):
274273 else :
275274 self .act = get_act_fn (config .projector_act )
276275
277- dim_out = config .text_config . hidden_size
276+ dim_out = config .text_hidden_size
278277 self .linear_2 = nn .Linear (dim_mid , dim_out , bias = False )
279278
280279 # Ultravox v0.4.1 and below use layer_norm after the second linear layer
@@ -572,9 +571,14 @@ def get_input_embeddings(
572571 input_ids : torch .Tensor ,
573572 multimodal_embeddings : Optional [MultiModalEmbeddings ] = None ,
574573 ) -> torch .Tensor :
575- inputs_embeds = self .language_model .get_input_embeddings (input_ids )
576- if multimodal_embeddings is not None \
577- and len (multimodal_embeddings ) != 0 :
574+ # The audio token index is not included in the embedding table, so
575+ # we replace it with token id 0 (on a copy) before the embedding lookup.
576+ safe_input_ids = input_ids .clone ()
577+ safe_input_ids [safe_input_ids == self .config .audio_token_index ] = 0
578+ inputs_embeds = self .language_model .get_input_embeddings (
579+ safe_input_ids )
580+ if multimodal_embeddings is not None and len (
581+ multimodal_embeddings ) > 0 :
578582
579583 # TODO(ywang96): remove this block after v0 is deprecated.
580584 if not envs .VLLM_USE_V1 :
@@ -585,7 +589,7 @@ def get_input_embeddings(
585589 else :
586590 inputs_embeds = merge_multimodal_embeddings (
587591 input_ids , inputs_embeds , multimodal_embeddings ,
588- _AUDIO_PLACEHOLDER_TOKEN )
592+ self . config . audio_token_index )
589593 return inputs_embeds
590594
591595 def forward (self ,
@@ -623,10 +627,14 @@ def forward(self,
623627 multimodal_embeddings )
624628 input_ids = None
625629
626- hidden_states = self .language_model .model (input_ids ,
627- positions ,
628- intermediate_tensors ,
629- inputs_embeds = inputs_embeds )
630+ language_model = self .language_model
631+ if hasattr (language_model , "language_model" ):
632+ language_model = language_model .language_model
633+
634+ hidden_states = language_model .model (input_ids ,
635+ positions ,
636+ intermediate_tensors ,
637+ inputs_embeds = inputs_embeds )
630638 return hidden_states
631639
632640 def compute_logits (self , hidden_states : torch .Tensor ,
0 commit comments