
Commit 62965de

farzadab and liPatrick authored
[Model] Ultravox: Support Llama 4 and Gemma 3 backends (#17818)
Signed-off-by: Farzad Abdolhosseini <[email protected]>
Signed-off-by: Patrick Li <[email protected]>
Co-authored-by: Patrick Li <[email protected]>
1 parent 7ae75fa commit 62965de

File tree: 4 files changed, +39 −24 lines

tests/models/registry.py
Lines changed: 2 additions & 0 deletions

@@ -221,6 +221,8 @@ def check_available_online(
                               "fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}),  # noqa: E501
     "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
                                         is_available_online=False),
+    "Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
+                                         is_available_online=False),
     "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
     "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"),
     "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"),  # noqa: E501

vllm/model_executor/models/registry.py
Lines changed: 1 addition & 0 deletions

@@ -89,6 +89,7 @@
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
+    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MambaForCausalLM": ("mamba", "MambaForCausalLM"),

vllm/model_executor/models/ultravox.py
Lines changed: 23 additions & 15 deletions

@@ -39,9 +39,7 @@
     merge_multimodal_embeddings,
     merge_multimodal_embeddings_from_map)

-_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
-_AUDIO_PLACEHOLDER_TOKEN = 128002
-_AUDIO_TOKENS_PER_SECOND = 6.25
+_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
 _MAX_ENCODER_BATCH_SIZE = 16


@@ -80,14 +78,15 @@ def get_hf_processor(
         sampling_rate: Optional[int] = None,
         **kwargs: object,
     ) -> ProcessorMixin:
+        config = self.ctx.model_config.hf_config
         hf_processor = self.ctx.get_hf_processor(**kwargs)

         # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
         # placeholder that will cause confusion with the actual end of turn
-        # token, thus we override placeholder with a reserved special
-        # token.
+        # token, thus we override placeholder with a reserved token.
         hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
-        hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN
+        hf_processor.audio_replacement_token_id = config.audio_token_index
+
         return hf_processor

     def get_feature_extractor(
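
Why this override changed: the placeholder id was previously hardcoded to 128002, a reserved token valid only for Llama 3 vocabularies; with Gemma 3 and Llama 4 backends the id must come from `config.audio_token_index`. A toy illustration of what an audio-placeholder replacement step looks like, assuming the processor expands each placeholder into one replacement id per audio frame (hypothetical helper, not the Ultravox processor's actual code):

```python
# Hypothetical sketch: expand an audio placeholder into per-frame slots.
# The real HF Ultravox processor also handles tokenization and features.
def expand_audio_placeholder(prompt_ids: list[int], placeholder_id: int,
                             replacement_id: int, num_frames: int) -> list[int]:
    out: list[int] = []
    for tok in prompt_ids:
        if tok == placeholder_id:
            # Reserve num_frames positions to be filled with audio embeddings.
            out.extend([replacement_id] * num_frames)
        else:
            out.append(tok)
    return out

# e.g. expand_audio_placeholder([1, 7, 2], placeholder_id=7,
#                               replacement_id=9, num_frames=3)
# -> [1, 9, 9, 9, 2]
```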
@@ -274,7 +273,7 @@ def __init__(self, config: UltravoxConfig):
         else:
             self.act = get_act_fn(config.projector_act)

-        dim_out = config.text_config.hidden_size
+        dim_out = config.text_hidden_size
         self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)

         # Ultravox v0.4.1 and below use layer_norm after the second linear layer
@@ -572,9 +571,14 @@ def get_input_embeddings(
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
-        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if multimodal_embeddings is not None \
-            and len(multimodal_embeddings) != 0:
+        # The audio token index is not included in the embedding table
+        # We need to remove it before embedding lookup
+        safe_input_ids = input_ids.clone()
+        safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0
+        inputs_embeds = self.language_model.get_input_embeddings(
+            safe_input_ids)
+        if multimodal_embeddings is not None and len(
+                multimodal_embeddings) > 0:

             # TODO(ywang96): remove this block after v0 is deprecated.
             if not envs.VLLM_USE_V1:
@@ -585,7 +589,7 @@ def get_input_embeddings(
             else:
                 inputs_embeds = merge_multimodal_embeddings(
                     input_ids, inputs_embeds, multimodal_embeddings,
-                    _AUDIO_PLACEHOLDER_TOKEN)
+                    self.config.audio_token_index)
         return inputs_embeds

     def forward(self,
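
The substantive change in these two hunks: Gemma 3 and Llama 4 place the audio token outside the language model's embedding table, so embedding `input_ids` directly would raise an IndexError. The ids are masked to a safe value before the lookup, and the audio embeddings are merged back into those positions afterwards (vLLM's `merge_multimodal_embeddings` does the merge; it is inlined as a boolean-mask assignment in this sketch):

```python
# Minimal sketch of the mask-then-merge pattern, with toy sizes.
import torch
import torch.nn as nn

vocab_size, hidden = 8, 4
audio_token_index = 9                      # deliberately out of vocab range
embed = nn.Embedding(vocab_size, hidden)

input_ids = torch.tensor([1, 9, 9, 2])     # two audio placeholder slots
audio_embeds = torch.randn(2, hidden)      # produced by the audio encoder

# embed(input_ids) would fail: index 9 >= vocab_size. Mask it to 0 first.
safe_input_ids = input_ids.clone()
safe_input_ids[safe_input_ids == audio_token_index] = 0
inputs_embeds = embed(safe_input_ids)

# Merge the real audio embeddings back into the masked positions.
inputs_embeds[input_ids == audio_token_index] = audio_embeds
print(inputs_embeds.shape)                 # torch.Size([4, 4])
```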
@@ -623,10 +627,14 @@ def forward(self,
                                               multimodal_embeddings)
             input_ids = None

-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        language_model = self.language_model
+        if hasattr(language_model, "language_model"):
+            language_model = language_model.language_model
+
+        hidden_states = language_model.model(input_ids,
+                                             positions,
+                                             intermediate_tensors,
+                                             inputs_embeds=inputs_embeds)
         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor,
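
The `hasattr` unwrap exists because decoder-only backends such as `LlamaForCausalLM` expose the decoder as `.model`, while the conditional-generation wrappers used by the Llama 4 and Gemma 3 backends nest the text decoder one level deeper under `.language_model`. A minimal sketch of the two layouts (dummy classes, not vLLM's):

```python
# Dummy stand-ins for the two module layouts the forward pass must handle.
class TextBackend:                 # e.g. LlamaForCausalLM: decoder at .model
    def __init__(self):
        self.model = "decoder"

class MultimodalBackend:           # e.g. a Llama 4 / Gemma 3 wrapper
    def __init__(self):
        self.language_model = TextBackend()

def resolve_decoder(language_model):
    # Unwrap one level when the backend nests its text decoder.
    if hasattr(language_model, "language_model"):
        language_model = language_model.language_model
    return language_model.model

assert resolve_decoder(TextBackend()) == "decoder"
assert resolve_decoder(MultimodalBackend()) == "decoder"
```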

vllm/transformers_utils/configs/ultravox.py
Lines changed: 13 additions & 9 deletions

@@ -45,6 +45,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
     """

     model_type = "ultravox"
+    audio_token = "<|audio|>"
     is_composition = False

     def __init__(
@@ -80,29 +81,32 @@ def __init__(
             # Avoid circular import
             from vllm.transformers_utils.config import get_config

-            self.text_config = get_config(text_model_id,
-                                          trust_remote_code=False)
+            text_config_obj = get_config(text_model_id,
+                                         trust_remote_code=False)
         else:
             text_config = text_config or {}
-            self.text_config = transformers.CONFIG_MAPPING[text_config.get(
+            text_config_obj = transformers.CONFIG_MAPPING[text_config.get(
                 "model_type", "llama")](**text_config)

+        inner_text_config = text_config_obj.get_text_config()
+
         if audio_model_id is not None:
             # Avoid circular import
             from vllm.transformers_utils.config import get_config

-            self.audio_config = get_config(audio_model_id,
-                                           trust_remote_code=False)
+            audio_config = get_config(audio_model_id, trust_remote_code=False)
         else:
             audio_config = audio_config or {}
-            self.audio_config = transformers.CONFIG_MAPPING[audio_config.get(
+            audio_config = transformers.CONFIG_MAPPING[audio_config.get(
                 "model_type", "whisper")](**audio_config)

+        self.text_config = text_config_obj
+        self.audio_config = audio_config
         self.text_model_lora_config = text_model_lora_config or {}
         self.audio_model_lora_config = audio_model_lora_config or {}

-        self.vocab_size = self.text_config.vocab_size
-
-        self.initializer_range = self.text_config.initializer_range
+        self.vocab_size = inner_text_config.vocab_size
+        self.initializer_range = inner_text_config.initializer_range
+        self.text_hidden_size = inner_text_config.hidden_size

         super().__init__(**kwargs)
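
The config change mirrors the model change: Llama 4 and Gemma 3 ship composite configs whose `vocab_size`, `hidden_size`, and `initializer_range` live on a nested text config, while plain Llama configs carry them at the top level. `PretrainedConfig.get_text_config()` abstracts over both, returning the config itself when it is already a text config; caching `text_hidden_size` here is what lets the projector above size `linear_2` without touching the nested config again. A hedged example (assumes a `transformers` version that ships these model types):

```python
import transformers

# Plain text config: get_text_config() returns the config itself.
llama_cfg = transformers.CONFIG_MAPPING["llama"]()
assert llama_cfg.get_text_config() is llama_cfg

# Composite multimodal config: the text fields live on a nested sub-config.
gemma3_cfg = transformers.CONFIG_MAPPING["gemma3"]()
inner = gemma3_cfg.get_text_config()
print(inner.hidden_size, inner.vocab_size)   # read from the nested text config
```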
