
Commit 55dcce9

Authored by houseroad, luccafong, ywang96, and DarkLight1337
Upstream Llama4 Support to Main (#16113)
Signed-off-by: Aston Zhang <[email protected]>
Signed-off-by: Chris Thi <[email protected]>
Signed-off-by: drisspg <[email protected]>
Signed-off-by: Jon Swenson <[email protected]>
Signed-off-by: Keyun Tong <[email protected]>
Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Xiaodong Wang <[email protected]>
Signed-off-by: Yang Chen <[email protected]>
Signed-off-by: Ye (Charlotte) Qi <[email protected]>
Signed-off-by: Yong Hoon Shin <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Lucia Fang <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: Lu Fang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
1 parent 8017c8d commit 55dcce9


43 files changed: +2436 −155 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 1 deletion
@@ -389,7 +389,8 @@ steps:
   - pytest -v -s models/test_transformers.py
   - pytest -v -s models/test_registry.py
   # V1 Test: https://github.com/vllm-project/vllm/issues/14531
-  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py
+  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4'
+  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
 
 - label: Language Models Test (Standard) # 32min
   #mirror_hardwares: [amd]

benchmarks/kernels/benchmark_moe.py

Lines changed: 3 additions & 0 deletions
@@ -553,6 +553,9 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
+        if not hasattr(config, "hidden_size"):
+            # Support for llama4
+            config = config.text_config
         # Default: Mixtral.
         E = config.num_local_experts
         topk = config.num_experts_per_tok
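For context, the Llama4 Hugging Face config nests its language-model fields under `text_config`, which is why the benchmark falls back to it when the top-level config has no `hidden_size`. A minimal sketch of that access pattern (illustrative only, not part of the commit; the printed fields are the ones the benchmark reads):

```python
# Minimal sketch: handling the nested Llama4 config, mirroring the hunk above.
# Requires transformers >= 4.51.0 and access to the checkpoint.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")

# The top-level multimodal config has no `hidden_size`; the MoE fields the
# benchmark needs live on the nested text config.
if not hasattr(config, "hidden_size"):
    config = config.text_config

print(config.num_local_experts, config.num_experts_per_tok)
```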

docs/source/models/supported_models.md

Lines changed: 9 additions & 2 deletions
@@ -24,7 +24,7 @@ vLLM also supports model implementations that are available in Transformers. Thi
 
 To check if the modeling backend is Transformers, you can simply do this:
 
-```python
+```python
 from vllm import LLM
 llm = LLM(model=..., task="generate") # Name or path of your model
 llm.apply_model(lambda model: print(type(model)))
@@ -55,7 +55,7 @@ If your model is neither supported natively by vLLM or Transformers, you can sti
 Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers.
 Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM!
 
-```python
+```python
 from vllm import LLM
 llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
 llm.apply_model(lambda model: print(model.__class__))
@@ -850,6 +850,13 @@ See [this page](#generative-models) for more information on how to use generativ
   *
   * ✅︎
   * ✅︎
+- * `Llama4ForConditionalGeneration`
+  * Llama-4-17B-Omni-Instruct
+  * T + I<sup>+</sup>
+  * `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc.
+  *
+  * ✅︎
+  * ✅︎
 - * `LlavaForConditionalGeneration`
   * LLaVA-1.5
   * T + I<sup>E+</sup>
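The new table row registers `Llama4ForConditionalGeneration` as a text plus multi-image model. A minimal offline-inference sketch using the Scout checkpoint listed above (the engine arguments are borrowed from the example scripts added later in this commit and may need tuning for your hardware; this is illustrative, not part of the docs diff):

```python
# Minimal sketch: text-only generation with Llama 4 Scout in vLLM.
# tensor_parallel_size / max_model_len follow the example scripts in this
# commit; adjust to your GPUs.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    max_model_len=8192,
    max_num_seqs=4,
    tensor_parallel_size=8,
)
outputs = llm.generate(["Summarize the Llama 4 model family in one sentence."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```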

examples/offline_inference/audio_language.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
         model=model_name,
         trust_remote_code=True,
         max_model_len=4096,
-        max_num_seqs=5,
+        max_num_seqs=2,
         limit_mm_per_prompt={"audio": audio_count},
     )

examples/offline_inference/vision_language.py

Lines changed: 37 additions & 0 deletions
@@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
     )
 
 
+def run_llama4(questions: list[str], modality: str):
+    assert modality == "image"
+
+    model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=4,
+        tensor_parallel_size=8,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        gpu_memory_utilization=0.4,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    messages = [[{
+        "role":
+        "user",
+        "content": [{
+            "type": "image"
+        }, {
+            "type": "text",
+            "text": f"{question}"
+        }]
+    }] for question in questions]
+    prompts = tokenizer.apply_chat_template(messages,
+                                            add_generation_prompt=True,
+                                            tokenize=False)
+    stop_token_ids = None
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+        stop_token_ids=stop_token_ids,
+    )
+
+
 # Molmo
 def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
@@ -907,6 +943,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
     "minicpmv": run_minicpmv,
     "mistral3": run_mistral3,
     "mllama": run_mllama,
+    "llama4": run_llama4,
     "molmo": run_molmo,
     "NVLM_D": run_nvlm_d,
     "paligemma": run_paligemma,

examples/offline_inference/vision_language_multi_image.py

Lines changed: 38 additions & 0 deletions
@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
     )
 
 
+def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=4,
+        tensor_parallel_size=8,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = [{"type": "image", "image": url} for url in image_urls]
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            *placeholders,
+            {
+                "type": "text",
+                "text": question
+            },
+        ],
+    }]
+
+    processor = AutoProcessor.from_pretrained(model_name)
+
+    prompt = processor.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+    )
+
+
 def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
 
@@ -567,6 +604,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
     "h2ovl_chat": load_h2ovl,
     "idefics3": load_idefics3,
     "internvl_chat": load_internvl,
+    "llama4": load_llama4,
     "mistral3": load_mistral3,
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,

requirements/common.txt

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ requests >= 2.26.0
 tqdm
 blake3
 py-cpuinfo
-transformers >= 4.50.3
+transformers >= 4.51.0
 huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
 tokenizers >= 0.19.1 # Required for Llama 3.
 protobuf # Required by LlamaTokenizer.
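The Transformers floor moves from 4.50.3 to 4.51.0, which adds the Llama 4 model classes. A small runtime check along these lines can fail fast on stale environments (the version string comes from this diff; `packaging` is assumed to be available):

```python
# Minimal sketch: refuse to run if the installed Transformers predates
# Llama 4 support (4.51.0 per the requirement bump above).
import transformers
from packaging.version import Version

if Version(transformers.__version__) < Version("4.51.0"):
    raise RuntimeError("Llama 4 support requires transformers >= 4.51.0")
```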

requirements/test.in

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ mistral_common[opencv] >= 1.5.4 # required for pixtral test
 opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.8 # required for model evaluation test
-transformers==4.50.3
+transformers==4.51.0
 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
 # quantization
 bitsandbytes>=0.45.3

requirements/test.txt

Lines changed: 1 addition & 1 deletion
@@ -645,7 +645,7 @@ tqdm==4.66.6
     # transformers
 tqdm-multiprocess==0.0.11
     # via lm-eval
-transformers==4.50.3
+transformers==4.51.0
     # via
     #   -r requirements/test.in
     #   genai-perf

tests/models/decoder_only/audio_language/test_ultravox.py

Lines changed: 13 additions & 1 deletion
@@ -12,6 +12,7 @@
 
 from ....conftest import HfRunner, VllmRunner
 from ....utils import RemoteOpenAIServer
+from ...registry import HF_EXAMPLE_MODELS
 from ...utils import check_logprobs_close
 
 MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
@@ -55,7 +56,10 @@ def server(request, audio_assets):
         for key, value in request.param.items()
     ]
 
-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+    with RemoteOpenAIServer(MODEL_NAME,
+                            args,
+                            env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
+                                      "30"}) as remote_server:
         yield remote_server
 
 
@@ -106,6 +110,10 @@ def run_test(
     **kwargs,
 ):
     """Inference result should be the same between hf and vllm."""
+    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+    model_info.check_available_online(on_fail="skip")
+    model_info.check_transformers_version(on_fail="skip")
+
     # NOTE: take care of the order. run vLLM first, and then run HF.
     # vLLM needs a fresh new process without cuda initialization.
     # if we run HF first, the cuda initialization will be done and it
@@ -156,6 +164,10 @@
     num_logprobs: int,
     **kwargs,
 ):
+    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+    model_info.check_available_online(on_fail="skip")
+    model_info.check_transformers_version(on_fail="skip")
+
     with vllm_runner(model,
                      dtype=dtype,
                      enforce_eager=True,
