Conversation

@houseroad (Collaborator) commented Apr 6, 2025

As a follow-up to #16104, we upstream the Llama4 support to the main branch.

The goals of this PR:

  • Support Llama4
  • Clean up some hacks

Further enhancements will be tracked in #16114.

Fixes from v0.8.3:

  • Fix missing sampler
  • Fix failing CI on transformers==4.51.0

FIX #16177

@github-actions (bot) commented Apr 6, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify bot added the documentation (Improvements or additions to documentation), ci/build, frontend, multi-modality (Related to multi-modality, #4194), and v1 labels on Apr 6, 2025
@DarkLight1337 (Member)

Can you apply #16104 (comment)? Thanks

Co-authored-by: Aston Zhang <[email protected]>
Co-authored-by: Chris Thi <[email protected]>
Co-authored-by: drisspg <[email protected]>
Co-authored-by: Jon Swenson <[email protected]>
Co-authored-by: Keyun Tong <[email protected]>
Co-authored-by: Lu Fang <[email protected]>
Co-authored-by: Lu Fang <[email protected]>
Co-authored-by: Xiaodong Wang <[email protected]>
Co-authored-by: Yang Chen <[email protected]>
Co-authored-by: Ye (Charlotte) Qi <[email protected]>
Co-authored-by: Yong Hoon Shin <[email protected]>
Co-authored-by: Zijing Liu <[email protected]>

Signed-off-by: Aston Zhang <[email protected]>
Signed-off-by: Chris Thi <[email protected]>
Signed-off-by: drisspg <[email protected]>
Signed-off-by: Jon Swenson <[email protected]>
Signed-off-by: Keyun Tong <[email protected]>
Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Xiaodong Wang <[email protected]>
Signed-off-by: Yang Chen <[email protected]>
Signed-off-by: Ye (Charlotte) Qi <[email protected]>
Signed-off-by: Yong Hoon Shin <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
Signed-off-by: Lu Fang <[email protected]>
@houseroad force-pushed the merge_init_pr_main branch from e8e3bbb to 83cdc27 on April 6, 2025 04:37
@ywang96 (Member) commented Apr 6, 2025

Since we're not in a huge rush to merge this PR into main, a few action items (and notes to myself):

  • Fix model init test with llama4 since irope is only supported on V1 but the test is forced to be on V0
  • Remove Llama4ForCausalLM from registry and update mllama4.py accordingly
  • Remove the usage of BaseModelOutput in mllama4.py
  • Remove the hacks we added in test_registry.py for Llama4ForCausalLM
  • Add PP to the model?

@DarkLight1337 (Member)

> Fix model init test with llama4 since irope is only supported on V1 but the test is forced to be on V0

This one has been addressed by #16112 already

Signed-off-by: Lu Fang <[email protected]>
@luccafong force-pushed the merge_init_pr_main branch from 802c2b6 to 4276ac0 on April 6, 2025 05:42
Signed-off-by: Lucia Fang <[email protected]>
@luccafong (Collaborator)

> Since we're not in a huge rush to merge this PR into main, a few action items (and notes to myself):
>
>   • Fix model init test with llama4 since irope is only supported on V1 but the test is forced to be on V0
>   • Remove Llama4ForCausalLM from registry and update mllama4.py accordingly
>   • Remove the usage of BaseModelOutput in mllama4.py
>   • Remove the hacks we added in test_models.py for Llama4ForCausalLM
>   • Add PP to the model?

These 2 are addressed in latest commits

  • Remove Llama4ForCausalLM from registry and update mllama4.py accordingly
  • Remove the usage of BaseModelOutput in mllama4.py

@luccafong (Collaborator)

> Since we're not in a huge rush to merge this PR into main, a few action items (and notes to myself):
>
>   • Fix model init test with llama4 since irope is only supported on V1 but the test is forced to be on V0
>   • Remove Llama4ForCausalLM from registry and update mllama4.py accordingly
>   • Remove the usage of BaseModelOutput in mllama4.py
>   • Remove the hacks we added in test_models.py for Llama4ForCausalLM
>   • Add PP to the model?

PP is inherited from llama.py for the text part, so it is already there; the multimodal computation is relatively small.

@ywang96 (Member) commented Apr 6, 2025

> Since we're not in a huge rush to merge this PR into main, a few action items (and notes to myself):
>
>   • Fix model init test with llama4 since irope is only supported on V1 but the test is forced to be on V0
>   • Remove Llama4ForCausalLM from registry and update mllama4.py accordingly
>   • Remove the usage of BaseModelOutput in mllama4.py
>   • Remove the hacks we added in test_models.py for Llama4ForCausalLM
>   • Add PP to the model?
>
> These 2 are addressed in latest commits
>
>   • Remove Llama4ForCausalLM from registry and update mllama4.py accordingly
>   • Remove the usage of BaseModelOutput in mllama4.py

Thank you! I made a typo in the list and corrected it too - I think PP is the only one left to be added and it should be fairly straightforward (just an interface change - feel free to take a look at other model files)

@luccafong (Collaborator)

> Since we're not in a huge rush to merge this PR into main, a few action items (and notes to myself):
>
>   • Fix model init test with llama4 since irope is only supported on V1 but the test is forced to be on V0
>   • Remove Llama4ForCausalLM from registry and update mllama4.py accordingly
>   • Remove the usage of BaseModelOutput in mllama4.py
>   • Remove the hacks we added in test_models.py for Llama4ForCausalLM
>   • Add PP to the model?
>
> These 2 are addressed in latest commits
>
>   • Remove Llama4ForCausalLM from registry and update mllama4.py accordingly
>   • Remove the usage of BaseModelOutput in mllama4.py
>
> Thank you! I made a typo in the list and corrected it too - I think PP is the only one left to be added and it should be fairly straightforward (just an interface change - feel free to take a look at other model files)

@ywang96 do you mean add SupportsPP directly to Llama4ForConditionalGeneration?

@ywang96 (Member) commented Apr 6, 2025

> Since we're not in a huge rush to merge this PR into main, a few action items (and notes to myself):
>
>   • Fix model init test with llama4 since irope is only supported on V1 but the test is forced to be on V0
>   • Remove Llama4ForCausalLM from registry and update mllama4.py accordingly
>   • Remove the usage of BaseModelOutput in mllama4.py
>   • Remove the hacks we added in test_models.py for Llama4ForCausalLM
>   • Add PP to the model?
>
> These 2 are addressed in latest commits
>
>   • Remove Llama4ForCausalLM from registry and update mllama4.py accordingly
>   • Remove the usage of BaseModelOutput in mllama4.py
>
> Thank you! I made a typo in the list and corrected it too - I think PP is the only one left to be added and it should be fairly straightforward (just an interface change - feel free to take a look at other model files)
>
> @ywang96 do you mean add SupportsPP directly to Llama4ForConditionalGeneration?

Yep! and update the forward method too
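
For readers following along, the interface change being discussed follows the pattern used by other vLLM model files: the model class mixes in SupportsPP, and its forward either returns hidden states (on the last pipeline rank) or an intermediate-tensors container to hand off to the next rank. The sketch below is a schematic, vLLM-independent illustration of that pattern only; every class, method, and attribute name in it is an illustrative stand-in, not vLLM's actual definition.

```python
# Schematic sketch of the pipeline-parallel (PP) interface pattern discussed
# above. All names here are illustrative stand-ins, not vLLM's actual
# classes; the real interface lives under vllm/model_executor/models/.
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class IntermediateTensorsSketch:
    """Hidden states handed from one PP rank to the next."""
    tensors: dict = field(default_factory=dict)


class SupportsPPSketch:
    """Marker mixin: the model can start/stop at an arbitrary PP stage."""


class ConditionalGenerationSketch(SupportsPPSketch):
    """Multimodal wrapper that delegates the PP plumbing to its text backbone."""

    def __init__(self, language_model):
        # Reusing the Llama text model is why PP comes "for free" here.
        self.language_model = language_model

    def forward(
        self,
        input_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensorsSketch] = None,
    ):
        if intermediate_tensors is not None:
            # Not the first PP rank: resume from the received hidden states.
            hidden = intermediate_tensors.tensors["hidden_states"]
        else:
            hidden = self.language_model.embed(input_ids)
        hidden = self.language_model.run_local_layers(hidden)
        if not self.language_model.is_last_rank:
            # Hand off to the next PP rank instead of producing final states.
            return IntermediateTensorsSketch({"hidden_states": hidden})
        return hidden
```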

Signed-off-by: Lucia Fang <[email protected]>
@luccafong (Collaborator)

@ywang96 please review; forward should already be compatible with PP, since it reuses LlamaModel, which already supports PP.

@LagPixelLOL

Upon further testing, it seems that it does not only happen at 999500; it starts to happen right around the configured max context length minus 500. For example, when setting the max context length to 800000 with --max-model-len 800000 --max-seq-len-to-capture 800000, it starts to happen at about a 799500-token input length.

@ywang96 (Member) commented Apr 7, 2025

> Upon further testing, it seems that it does not only happen at 999500; it starts to happen right around the configured max context length minus 500. For example, when setting the max context length to 800000 with --max-model-len 800000 --max-seq-len-to-capture 800000, it starts to happen at about a 799500-token input length.

That's interesting; I suspect there's something wrong with the KV cache manager. Could you see if you can repro this with vllm==0.8.3? And if you can, could you open an issue so we can track it instead of letting it get mixed in with the PR review comments here? Thanks!
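
For anyone trying to reproduce this, a minimal offline-inference sketch is below. It assumes vllm==0.8.3, the meta-llama/Llama-4-Scout-17B-16E-Instruct checkpoint, and hardware with enough memory for an ~800k-token KV cache; the model id, tensor-parallel size, and the way the long prompt is padded are assumptions, not the reporter's exact setup, so adjust them accordingly.

```python
# Hedged repro sketch: assumes vllm==0.8.3, the Llama 4 Scout checkpoint,
# and enough GPU memory for an ~800k-token KV cache. The model id, TP size,
# and the prompt-padding trick are assumptions, not the reporter's setup.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # assumed checkpoint
    tensor_parallel_size=8,          # adjust to available GPUs
    max_model_len=800_000,           # same flags as in the report above
    max_seq_len_to_capture=800_000,
)

# The failure was reported to start around max_model_len - 500, so pad a
# prompt to just under that length (token count is tokenizer-dependent).
prompt = "hello " * 790_000

outputs = llm.generate(prompt, SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```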

@LagPixelLOL

I managed to get a stack trace that isn't garbled. It seems that the last call still inside vLLM was masked_input, input_mask = get_masked_input_and_mask( in vocab_parallel_embedding.py; after this it goes into PyTorch-compiled code.

(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383] WorkerProc hit an exception: %s
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383] Traceback (most recent call last):
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/vllm/v1/executor/multiproc_executor.py", line 376, in worker_busy_loop
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     output = func(*args, **kwargs)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]              ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     return func(*args, **kwargs)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 242, in execute_model
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     output = self.model_runner.execute_model(scheduler_output)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     return func(*args, **kwargs)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1016, in execute_model
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     inputs_embeds = self.model.get_input_embeddings(input_ids)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/models/mllama4.py", line 787, in get_input_embeddings
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     inputs_embeds = self.language_model.get_input_embeddings(input_ids)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 532, in get_input_embeddings
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     return self.model.get_input_embeddings(input_ids)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 339, in get_input_embeddings
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     return self.embed_tokens(input_ids)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     return forward_call(*args, **kwargs)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py", line 406, in forward
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     masked_input, input_mask = get_masked_input_and_mask(
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]                                ^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     return fn(*args, **kwargs)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]            ^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     return self._torchdynamo_orig_callable(
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1164, in __call__
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     result = self._inner_convert(
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]              ^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     return _compile(
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]            ^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1036, in _compile
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     raise InternalTorchDynamoError(
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     guarded_code = compile_inner(code, one_graph, hooks, transform)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     return _compile_inner(code, one_graph, hooks, transform)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     return function(*args, **kwargs)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     out_code = transform_code_object(code, transform)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     transformations(instructions, code_options)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 246, in _fn
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     torch.cuda.set_rng_state(cuda_rng_state)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/cuda/random.py", line 75, in set_rng_state
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     _lazy_call(cb)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/cuda/__init__.py", line 249, in _lazy_call
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     callable()
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]   File "/usr/local/lib/python3.12/site-packages/torch/cuda/random.py", line 73, in cb
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383]     default_generator.set_state(new_state_copy)
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383] torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: CUDA error: device-side assert triggered
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
(VllmWorker rank=1 pid=18180) ERROR 04-07 05:27:37 [multiproc_executor.py:383] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
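
For context on the frame where the trace leaves vLLM: with tensor-parallel vocab embeddings, each rank owns a slice of the vocabulary, and token ids outside that slice are remapped to a safe local index, embedded as zeros, and recovered from the other ranks by an all-reduce. The snippet below is a simplified, standalone illustration of that idea only; it is not vLLM's get_masked_input_and_mask, which additionally handles added-vocab ranges and padding.

```python
# Simplified, standalone illustration of vocab-parallel embedding masking.
# NOT vLLM's get_masked_input_and_mask: the real function also handles
# added-vocab (e.g. LoRA) ranges and padding. Shown only to clarify where
# an out-of-range token id can turn into a device-side assert.
import torch


def masked_lookup(input_ids: torch.Tensor, shard_weight: torch.Tensor,
                  vocab_start: int, vocab_end: int) -> torch.Tensor:
    """Embed only the token ids owned by this tensor-parallel shard."""
    in_shard = (input_ids >= vocab_start) & (input_ids < vocab_end)
    # Out-of-shard ids are remapped to row 0 so the gather stays in range;
    # their rows are zeroed and filled in by other ranks via all-reduce.
    local_ids = torch.where(in_shard, input_ids - vocab_start,
                            torch.zeros_like(input_ids))
    embeddings = shard_weight[local_ids]
    embeddings[~in_shard] = 0.0
    return embeddings  # vLLM then runs a tensor-parallel all-reduce


# Toy example: vocab 0..15 split across two ranks; this rank owns [8, 16).
weight = torch.randn(8, 4)
ids = torch.tensor([3, 9, 15])
print(masked_lookup(ids, weight, 8, 16).shape)  # torch.Size([3, 4])
```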

@LagPixelLOL

It also happens with 0.8.3; I'm opening an issue.

@LagPixelLOL

#16157

@DarkLight1337 requested a review from simon-mo as a code owner on April 7, 2025 08:57
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
@DarkLight1337 (Member)

CI failures aren't related to this PR.

@vllm-bot merged commit 55dcce9 into vllm-project:main on Apr 7, 2025
66 of 69 checks passed
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025

Labels

ci/build, documentation (Improvements or additions to documentation), frontend, multi-modality (Related to multi-modality, #4194), ready (ONLY add when PR is ready to merge/full CI is needed), v1


Development

Successfully merging this pull request may close these issues.

[Bug]: AttributeError: 'Llama4ForConditionalGeneration' object has no attribute 'sampler' with prompt_logprobs
