Skip to content

Conversation

@alexm-redhat
Copy link
Collaborator

@alexm-redhat alexm-redhat commented Feb 20, 2025

This PR adds tensor-parallel support to [V1] TPU via Ray executor without changing the SMPD and Ray compile flags that are used for the NVIDIA codepath. As a result, NVIDIA's Ray executor is mostly reused for the TPU codepath. Verified correctness via:

VLLM_USE_V1=1 pytest -s -v tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify
Copy link

mergify bot commented Feb 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @alexm-redhat.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 20, 2025
@alexm-redhat
Copy link
Collaborator Author

@bvrockwell this is the TP PR. I did not fully mimick the logic from V0, since I wanted to preserve the Ray DAG compilation which is actually used here. The PR is basically done, just needs to become green and remove debug cruft.

@alexm-redhat
Copy link
Collaborator Author

/ready

@alexm-redhat
Copy link
Collaborator Author

@mgoin what's nicolo's username?

@mergify mergify bot removed the needs-rebase label Feb 20, 2025
@alexm-redhat
Copy link
Collaborator Author

@mgoin The PR is ready for review.

@mgoin
Copy link
Member

mgoin commented Feb 20, 2025

cc @NickLucche

@alexm-redhat alexm-redhat self-assigned this Feb 20, 2025
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary? I thought intermediate_tensors was just needed for PP

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not used, but it is part of the API, else it errors.

@robertgshaw2-redhat robertgshaw2-redhat added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 20, 2025
@brittrock
Copy link

@bvrockwell this is the TP PR. I did not fully mimick the logic from V0, since I wanted to preserve the Ray DAG compilation which is actually used here. The PR is basically done, just needs to become green and remove debug cruft.

Thanks @alexm-redhat ! Tagging @lsy323 who is working on this section of the code currently.

from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import msgspec
import torch
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to be a lazy import

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexm-redhat can you move the torch import back down?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved down

Copy link
Collaborator

@NickLucche NickLucche left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm!
Any chance we can add a tiny unit test under tests/v1/tpu to make sure tp is working/using gpus as intended without running the whole correctness test?

if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device)
if current_platform.is_tpu():
# TODO: [AlexM] Verify if set_device is necessary here
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has this been verified?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah

@alexm-redhat
Copy link
Collaborator Author

@NickLucche thanks for the suggestion, added a new test in tests/v1/tpu/test_basic.py that performs a quick correctness sanity check without running the full evaluation suite.

Copy link
Collaborator

@NickLucche NickLucche left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the test!
I wasn't able to run it on my tpu pod (tp=4) despite tuning down all args that would cause oom. I end up with

INFO 03-07 09:53:04 [kv_cache_utils.py:537] GPU KV cache size: 1,106,000 tokens
INFO 03-07 09:53:04 [kv_cache_utils.py:540] Maximum concurrency for 128 tokens per request: 8640.62x
INFO 03-07 09:53:04 [kv_cache_utils.py:537] GPU KV cache size: 1,106,000 tokens
INFO 03-07 09:53:04 [kv_cache_utils.py:540] Maximum concurrency for 128 tokens per request: 8640.62x
INFO 03-07 09:53:04 [kv_cache_utils.py:537] GPU KV cache size: 1,106,000 tokens
INFO 03-07 09:53:04 [kv_cache_utils.py:540] Maximum concurrency for 128 tokens per request: 8640.62x
INFO 03-07 09:53:04 [kv_cache_utils.py:537] GPU KV cache size: 1,106,000 tokens
INFO 03-07 09:53:04 [kv_cache_utils.py:540] Maximum concurrency for 128 tokens per request: 8640.62x
INFO 03-07 09:53:04 [core.py:116] init engine (profile, create kv cache, warmup model) took 19.98 seconds
Processed prompts:   0%|                                                                                                                  | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]INFO 03-07 09:53:04 [ray_distributed_executor.py:534] VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = False
INFO 03-07 09:53:04 [ray_distributed_executor.py:536] VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = False
ERROR 03-07 09:53:14 [core.py:303] EngineCore hit an exception: Traceback (most recent call last):
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py", line 2344, in _execute_until
ERROR 03-07 09:53:14 [core.py:303]     result = self._dag_output_fetcher.read(timeout)
ERROR 03-07 09:53:14 [core.py:303]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/ray/experimental/channel/common.py", line 318, in read
ERROR 03-07 09:53:14 [core.py:303]     outputs = self._read_list(timeout)
ERROR 03-07 09:53:14 [core.py:303]               ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/ray/experimental/channel/common.py", line 409, in _read_list
ERROR 03-07 09:53:14 [core.py:303]     raise e
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/ray/experimental/channel/common.py", line 391, in _read_list
ERROR 03-07 09:53:14 [core.py:303]     result = c.read(min(remaining_timeout, iteration_timeout))
ERROR 03-07 09:53:14 [core.py:303]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/ray/experimental/channel/shared_memory_channel.py", line 776, in read
ERROR 03-07 09:53:14 [core.py:303]     return self._channel_dict[self._resolve_actor_id()].read(timeout)
ERROR 03-07 09:53:14 [core.py:303]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/ray/experimental/channel/shared_memory_channel.py", line 612, in read
ERROR 03-07 09:53:14 [core.py:303]     output = self._buffers[self._next_read_index].read(timeout)
ERROR 03-07 09:53:14 [core.py:303]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/ray/experimental/channel/shared_memory_channel.py", line 480, in read
ERROR 03-07 09:53:14 [core.py:303]     ret = self._worker.get_objects(
ERROR 03-07 09:53:14 [core.py:303]           ^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/ray/_private/worker.py", line 893, in get_objects
ERROR 03-07 09:53:14 [core.py:303]     ] = self.core_worker.get_objects(
ERROR 03-07 09:53:14 [core.py:303]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 09:53:14 [core.py:303]   File "python/ray/_raylet.pyx", line 3189, in ray._raylet.CoreWorker.get_objects
ERROR 03-07 09:53:14 [core.py:303]   File "python/ray/includes/common.pxi", line 106, in ray._raylet.check_status
ERROR 03-07 09:53:14 [core.py:303] ray.exceptions.RayChannelTimeoutError: System error: Timed out waiting for object available to read. ObjectID: 00025bfe7d0aed89fa1684762ecae37e05ee43550100000002e1f505
ERROR 03-07 09:53:14 [core.py:303] 
ERROR 03-07 09:53:14 [core.py:303] The above exception was the direct cause of the following exception:
ERROR 03-07 09:53:14 [core.py:303] 
ERROR 03-07 09:53:14 [core.py:303] Traceback (most recent call last):
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/vllm/v1/engine/core.py", line 296, in run_engine_core
ERROR 03-07 09:53:14 [core.py:303]     engine_core.run_busy_loop()
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/vllm/v1/engine/core.py", line 339, in run_busy_loop
ERROR 03-07 09:53:14 [core.py:303]     outputs = step_fn()
ERROR 03-07 09:53:14 [core.py:303]               ^^^^^^^^^
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/vllm/v1/engine/core.py", line 154, in step
ERROR 03-07 09:53:14 [core.py:303]     output = self.model_executor.execute_model(scheduler_output)
ERROR 03-07 09:53:14 [core.py:303]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/vllm/v1/executor/ray_distributed_executor.py", line 57, in execute_model
ERROR 03-07 09:53:14 [core.py:303]     return refs[0].get()
ERROR 03-07 09:53:14 [core.py:303]            ^^^^^^^^^^^^^
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/ray/experimental/compiled_dag_ref.py", line 124, in get
ERROR 03-07 09:53:14 [core.py:303]     self._dag._execute_until(
ERROR 03-07 09:53:14 [core.py:303]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py", line 2350, in _execute_until
ERROR 03-07 09:53:14 [core.py:303]     raise RayChannelTimeoutError(
ERROR 03-07 09:53:14 [core.py:303] ray.exceptions.RayChannelTimeoutError: System error: If the execution is expected to take a long time, increase RAY_CGRAPH_get_timeout which is currently 10 seconds. Otherwise, this may indicate that the execution is hanging.
ERROR 03-07 09:53:14 [core.py:303] 
INFO 03-07 09:53:14 [ray_distributed_executor.py:108] Shutting down Ray distributed executor. If you see error log from logging.cc regarding SIGTERM received, please ignore because this is the expected termination process in Ray.

Is this working fine for you guys with ray==2.43.0?
Might even be some issue with libtpu, but just posting it here to double check.

@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This is a basic test for TPU only")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should set this to bfloat16 to avoid
The TPU backend currently does not support torch.float16. Using bfloat16 instead.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I removed the dtype, since it is not really necessary

@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
def test_models(
monkeypatch,
hf_runner,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hf_runner is not used

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, removed

@alexm-redhat
Copy link
Collaborator Author

@NickLucche I have rebased over the new ragged paged attn v2, however, it does not work now for TP==4 due to new kernel's limitation: "ValueError: Not implemented: num_kv_heads=1 can not be XLA fully tiled.”. TP==1 works fine.

@alexm-redhat
Copy link
Collaborator Author

@yaochengji is aware of the TP==4 issue. Anyway, I think the PR can be merged so that Chengji and Jevin can reproduce the issue locally.

@alexm-redhat
Copy link
Collaborator Author

@NickLucche btw, I was using use_kernel=True to test this PR before because I was getting constant OOM.

Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@mgoin mgoin added the tpu Related to Google TPUs label Mar 7, 2025
@alexm-redhat
Copy link
Collaborator Author

@mgoin @NickLucche I was able to verify correctness for TP==4 for llama 3.1 8B (the new kernel works for this model). The other issue is Qwen specific.

@robertgshaw2-redhat robertgshaw2-redhat merged commit cb8bdfa into main Mar 8, 2025
39 checks passed
@robertgshaw2-redhat robertgshaw2-redhat deleted the v1_tpu_tp branch March 8, 2025 13:19
@NickLucche
Copy link
Collaborator

great job here @alexm-redhat ! I didn't know about use_kernel=True perhaps pallas.py could use a few more comments.

Alexei-V-Ivanov-AMD added a commit to ROCm/vllm that referenced this pull request Mar 11, 2025
* Fix `head_dim` not existing in all model configs (Transformers backend) (vllm-project#14141)

Signed-off-by: Harry Mellor <[email protected]>

* [V0][Metrics] Remove unimplemented `vllm:tokens_total` (vllm-project#14134)

Signed-off-by: Mark McLoughlin <[email protected]>

* [V0][Metrics] Deprecate some KV/prefix cache metrics (vllm-project#14136)

Signed-off-by: Mark McLoughlin <[email protected]>

* [V1] Simplify stats logging (vllm-project#14082)

Signed-off-by: Nick Hill <[email protected]>

* [WIP][[V1][Metrics] Implement max_num_generation_tokens,  request_params_n, and request_params_max_tokens metrics (vllm-project#14055)

Signed-off-by: Mark McLoughlin <[email protected]>

* [Bugfix] Allow shared_experts skip quantization for DeepSeekV2/V3 (vllm-project#14100)

Signed-off-by: mgoin <[email protected]>

* [Kernel] Optimize moe intermediate_cache usage (vllm-project#13625)

Signed-off-by: mgoin <[email protected]>

* [Docs] Add GPTQModel (vllm-project#14056)

Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>

* [v1] Add comments to the new ragged paged attention Pallas kernel (vllm-project#14155)

Signed-off-by: Xiongfei Wei <[email protected]>
Co-authored-by: Michael Goin <[email protected]>

* [Model] Add support for GraniteMoeShared models (vllm-project#13313)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

* [core] moe fp8 block quant tuning support (vllm-project#14068)

Signed-off-by: Divakar Verma <[email protected]>

* [Misc] Remove lru_cache in NvmlCudaPlatform (vllm-project#14156)

Signed-off-by: Cody Yu <[email protected]>

* [core] Pass all driver env vars to ray workers unless excluded (vllm-project#14099)

Signed-off-by: Rui Qiao <[email protected]>

* Use math.prod instead of np.prod for trivial ops (vllm-project#14142)

* Fix benchmark_moe.py tuning for CUDA devices (vllm-project#14164)

* [platform] add debug logging during inferring the device type (vllm-project#14195)

Signed-off-by: youkaichao <[email protected]>

* [sleep mode] error out with expandable_segments (vllm-project#14189)

Signed-off-by: youkaichao <[email protected]>

* [doc] add "Failed to infer device type" to faq (vllm-project#14200)

Signed-off-by: youkaichao <[email protected]>

* [Bugfix] Restrict MacOS CPU detection (vllm-project#14210)

Signed-off-by: mgoin <[email protected]>

* [V1][BugFix] Fix remaining sync engine client shutdown errors/hangs (vllm-project#13869)

Signed-off-by: Nick Hill <[email protected]>

* [V0][Metrics] Deprecate some questionable request time metrics (vllm-project#14135)

Signed-off-by: Mark McLoughlin <[email protected]>

* [V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (vllm-project#14161)

* add cutlass support for blackwell fp8 gemm (vllm-project#13798)

* [TPU][Profiler] Support start_profile/stop_profile in TPU worker (vllm-project#13988)

Signed-off-by: Siyuan Liu <[email protected]>
Co-authored-by: mgoin <[email protected]>

* Fix performance when `--generation-config` is not `None` (vllm-project#14223)

Signed-off-by: Harry Mellor <[email protected]>

* [Frontend] Do `prompt_logprobs` clamping for chat as well as completions (vllm-project#14225)

Signed-off-by: Harry Mellor <[email protected]>

* [Docs] Update Dockerfile dependency image (vllm-project#14215)

Signed-off-by: mgoin <[email protected]>

* [v1][Metrics] Add design doc (vllm-project#12745)

Signed-off-by: Mark McLoughlin <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Cody Yu <[email protected]>

* [Security] Serialize using safetensors instead of pickle in Mooncake Pipe (vllm-project#14228)

Signed-off-by: KuntaiDu <[email protected]>

* Clean up unused padding_idx variables across many model definitions (vllm-project#13240)

Signed-off-by: Tyler Michael Smith <[email protected]>

* [ROCm] Disable a few more kernel tests that are broken on ROCm (vllm-project#14145)

Signed-off-by: Sage Moore <[email protected]>

* [V1][TPU] TPU multimodal model support for ragged attention (vllm-project#14158)

Signed-off-by: Michael Goin <[email protected]>

* [misc] announce china meetup (vllm-project#14248)

Signed-off-by: youkaichao <[email protected]>

* Moved numba from common requirements to cuda/rocm specific requirements (vllm-project#14199)

Signed-off-by: Nishidha Panpaliya <[email protected]>

* Disable GPTQ AllSpark kernels for CUDA Compiler < 12.0 (vllm-project#14157)

Signed-off-by: mgoin <[email protected]>

* [Bugfix] Fix gptq_marlin for deepseek-v3 (vllm-project#13750)

Signed-off-by: dangshunya <[email protected]>
Co-authored-by: dangshunya <[email protected]>

* [V1][Bugfix] Do not reset prefix caching metrics (vllm-project#14235)

* [Model] New model support for Phi-4-multimodal-instruct (vllm-project#14119)

* [V1] EP/TP MoE + DP Attention (vllm-project#13931)

* [platforms] improve rocm debugging info (vllm-project#14257)

* Temporarily disable test_awq_gemm_opcheck (vllm-project#14251)

Signed-off-by: mgoin <[email protected]>

* [Frontend] Allow return_tokens_as_token_ids to be passed as a request param (vllm-project#14066)

Signed-off-by: Benjamin Chislett <[email protected]>

* [Misc][V1] Avoid using `envs.VLLM_USE_V1` in mm processing (vllm-project#14256)

Signed-off-by: Roger Wang <[email protected]>

* [Bugfix][V1] Fix allowed_token_ids for v1 Sampler (vllm-project#14169)

Signed-off-by: Lu Fang <[email protected]>

* [Doc] Update nginx guide: remove privileged from vllm container run and add target GPU ID (vllm-project#14217)

Signed-off-by: Iacopo Poli <[email protected]>

* [Doc] [3/N] Refer code examples for common cases in dev multimodal processor (vllm-project#14278)

Signed-off-by: DarkLight1337 <[email protected]>

* Small update for external_launcher backend docs (vllm-project#14288)

* [V1][Frontend] Add Testing For V1 Runtime Parameters (vllm-project#14159)

Signed-off-by: [email protected] <[email protected]>

* [LoRA] Remove linear hack outside transformers backend (vllm-project#14177)

Signed-off-by: Isotr0py <[email protected]>

* [Misc] Add Qwen2MoeForCausalLM moe tuning support  (vllm-project#14276)

Signed-off-by: Jee Jee Li <[email protected]>

* prefix_caching.md: Fixed typo (vllm-project#14293)

Signed-off-by: Daivid Savernin-Frenk <[email protected]>

* [Bugfix] Fix broken vision language example (vllm-project#14292)

Signed-off-by: Isotr0py <[email protected]>

* [Docs] Add Meta Slides (vllm-project#14297)

Signed-off-by: simon-mo <[email protected]>

* [V1][Minor] Remove obsolete FIXME comment (vllm-project#14304)

Signed-off-by: Nick Hill <[email protected]>

* Deprecate `best_of` Sampling Parameter in anticipation for vLLM V1 (vllm-project#13997)

Signed-off-by: vincent-4 <[email protected]>
Signed-off-by: Brayden Zhong <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Brayden Zhong <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>

* [V1][BugFix] Fix for mixed top_k batch (vllm-project#14301)

Signed-off-by: Nick Hill <[email protected]>


Co-authored-by: Ye Cao <[email protected]>

* [misc] Add FlashMLA as a new option of VLLM_ATTENTION_BACKEND env (vllm-project#14267)

* [V1][Easy] Add empty allowed_token_ids in the v1 sampler test (vllm-project#14308)

Signed-off-by: Lu Fang <[email protected]>

* init

Signed-off-by: Sage Moore <[email protected]>

* [Bugfix] Fix DeepSeek MTP crash when using TP1ModelRunner with CUDA graph due to shape mismatch (vllm-project#14237)

Signed-off-by: pyc96 <[email protected]>

* [Bugfix] Remove num_tokens_across_dp (vllm-project#14302)

Signed-off-by: Tyler Michael Smith <[email protected]>

* [BugFix] Fix prefix caching V0 MLA (vllm-project#14255)

Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Ying Zhong <[email protected]>

* [CI/Build] Use spawn multiprocessing mode for V1 test pipeline (vllm-project#14243)

Signed-off-by: Russell Bryant <[email protected]>

* Add benchmark for DeepGEMM and vLLM Block FP8 Dense GEMM (vllm-project#13917)

Signed-off-by: mgoin <[email protected]>

* [Build] Add UV_HTTP_TIMEOUT to avoid timeout during installation (vllm-project#13850)

Signed-off-by: Yuan Tang <[email protected]>

* [BugFix] MLA + V1, illegal memory access and accuracy issues (vllm-project#14253)

Signed-off-by: Lucas Wilkinson <[email protected]>

* [misc] Mention `ray list nodes` command to troubleshoot ray issues (vllm-project#14318)

Signed-off-by: Rui Qiao <[email protected]>

* [Bugfix][Structured Output] Support outlines engine with reasoning outputs for DeepSeek R1 (vllm-project#14114)

* [V1] LoRA - Enable more V1 tests (vllm-project#14315)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>

* [Bugfix][CI] ALiBi test case in xformers multi_query_kv_attention (vllm-project#11301)

* [Hardware] Update the flash attn tag to support Blackwell (vllm-project#14244)

* [Model] Update Paligemma multimodal processing with PromptUpdate  (vllm-project#14015)

Signed-off-by: Kyle Huang <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

* [V1][VLM][Pixtral-HF] Support Pixtral-HF on V1 (vllm-project#14275)

Signed-off-by: Linkun Chen <[email protected]>

* [Core] Optimizing cross-attention `QKVParallelLinear` computation (vllm-project#12325)

Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: NickLucche <[email protected]>

* [Frontend][Docs] Transcription API streaming (vllm-project#13301)

Signed-off-by: NickLucche <[email protected]>

* [Doc] Update reasoning with stream example to use OpenAI library (vllm-project#14077)

Signed-off-by: liuyanyi <[email protected]>

* [Doc] Correct beam_search using in generative_models.md (vllm-project#14363)

* [Kernel] [V1] Improved performance for V1 Triton (ROCm) backend  (vllm-project#14152)

* [Bugfix][Core] fix abort_seq_group and memory leak when n>1 (vllm-project#14326)

Signed-off-by: courage17340 <[email protected]>

* [Core] Don't use cache during multi-modal profiling (vllm-project#14336)

* [Doc] Fix date typo in README.md (vllm-project#14366)

Signed-off-by: Jitse Klomp <[email protected]>

* [RLHF] use worker_extension_cls for compatibility with V0 and V1 (vllm-project#14185)

Signed-off-by: youkaichao <[email protected]>

* Reinstate `best_of` for V0 (vllm-project#14356)

Signed-off-by: Harry Mellor <[email protected]>

* Adding cpu inference with VXE ISA for s390x architecture (vllm-project#12613)

Signed-off-by: Dilip Gowda Bhagavan <[email protected]>
Signed-off-by: Rishika Kedia <[email protected]>
Co-authored-by: Rishika Kedia <[email protected]>

* Add authors to license header. (vllm-project#14371)

Signed-off-by: Thomas Parnell <[email protected]>
Co-authored-by: Burkhard Ringlein <[email protected]>
Co-authored-by: Jan van Lunteren <[email protected]>

* Fix mla prefill context performance (vllm-project#13897)

Signed-off-by: ZhongYingMatrix <[email protected]>

* [V1] Do not detokenize if sampling param detokenize is False (vllm-project#14224)

Signed-off-by: Himanshu Jaju <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>

* [Distributed] Add enable_expert_parallel arg (vllm-project#14305)

Signed-off-by: Tyler Michael Smith <[email protected]>

* [CI/Build] Use uv python for docker rather than ppa:deadsnakes/ppa (vllm-project#13569)

Signed-off-by: mgoin <[email protected]>

* [CI] Disable spawn when running V1 Test (vllm-project#14345)

Signed-off-by: Thomas Parnell <[email protected]>

* [Kernel] Add needs_fixed_stride_order tag to most GEMMs (vllm-project#14306)

Signed-off-by: Tyler Michael Smith <[email protected]>

* [Bugfix] Fix use_direct_call condition in FusedMoE layer for  (vllm-project#14382)

Signed-off-by: Tyler Michael Smith <[email protected]>

* [Bug] Fix Attention when ignored in by quant_method (vllm-project#14313)

Signed-off-by: mgoin <[email protected]>

* [V1][Bugfix] Standardize quantized kv cache rejection for attention backends (vllm-project#14221)

Signed-off-by: mgoin <[email protected]>

* [Docs] Add nsight guide to profiling docs (vllm-project#14298)

Signed-off-by: mgoin <[email protected]>

* cleanup boolean logic

Signed-off-by: Sage Moore <[email protected]>

* [Hardware][TPU]Enable ragged paged attention kernel and resolve recompilation issue (vllm-project#14310)

Signed-off-by: Chengji Yao <[email protected]>

* [Doc] Fix a typo (vllm-project#14385)

* [Bugfix] Correctly call `cudaProfilerStop` in benchmarks script (vllm-project#14183)

Signed-off-by: Brayden Zhong <[email protected]>

* [Perf] Reduce MLA CPU overheads in V1 (vllm-project#14384)

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>

* [FP8] Refactor apply_fp8_linear and apply_fp8_linear_generic into an object (vllm-project#14390)

Signed-off-by: luka <[email protected]>

* [BugFix] Illegal Memory Access in the blockwise cutlass fp8 GEMMs (vllm-project#14396)

* [Bugfix] Fix JambaForCausalLM LoRA  (vllm-project#14370)

Signed-off-by: Jee Jee Li <[email protected]>

* [Build] Add nightly wheel fallback when latest commit wheel unavailable (vllm-project#14358)

Signed-off-by: Isotr0py <[email protected]>

* OpenVINO: added CPU-like conditions (vllm-project#14338)

Signed-off-by: Ilya Lavrenov <[email protected]>

* [GH] Auto-apply multi-modality label to relevant PRs (vllm-project#14402)

Signed-off-by: DarkLight1337 <[email protected]>

* correct wrong markdown syntax (vllm-project#14414)

Signed-off-by: vincent-pli <[email protected]>

* [Bugfix] Further clean up LoRA test (vllm-project#14422)

Signed-off-by: Jee Jee Li <[email protected]>

* [Bugfix] Clean up multi-modal processors (vllm-project#14417)

Signed-off-by: DarkLight1337 <[email protected]>

* [Misc] Set default value of seed to None (vllm-project#14274)

Signed-off-by: மனோஜ்குமார் பழனிச்சாமி <[email protected]>

* [BUGFIX] Skip tokenization support for throughput benchmark (vllm-project#12712)

Signed-off-by: root <[email protected]>
Signed-off-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>

* Fix missing `kv_caches` and `attn_metadata` in `OpenVINOCausalLM` (vllm-project#14271)

Signed-off-by: Harry Mellor <[email protected]>

* Use the optimized block sizes after tuning the kernel. (vllm-project#14329)

* [V1][Core] Support for Structured Outputs (vllm-project#12388)

Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: Russell Bryant <[email protected]>
Co-authored-by: Russell Bryant <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Nick Hill <[email protected]>

* [Doc] Update prefix_caching.md to match the example image (vllm-project#14420)

* [Benchmarks] Make detokenization optional in benchmark scripts (vllm-project#11697)

Signed-off-by: Jeremy Arnold <[email protected]>

* comments

Signed-off-by: Sage Moore <[email protected]>

* [Kernel] optimize performance of gptq marlin kernel when n is small (vllm-project#14138)

Signed-off-by: Jinzhen Lin <[email protected]>

* [Misc] Add Phi4-MM example (vllm-project#14343)

Signed-off-by: Jee Jee Li <[email protected]>

* [v1] torch.compile integration explanation (vllm-project#14437)

Signed-off-by: youkaichao <[email protected]>

* [V1] Eagerly remove finished requests from the batch (vllm-project#14388)

Signed-off-by: Nick Hill <[email protected]>

* [V1][Metrics] Fix traceback with preemptions+LoRA (vllm-project#14220)

Signed-off-by: Mark McLoughlin <[email protected]>

* [Bugfix] Fix torch_xla which can't handle None seed introduced in vllm-project#14274 (vllm-project#14459)

Signed-off-by: Yarong Mu <[email protected]>

* [V1] Prompt logprobs + APC compatibility; prompt logprobs reqs cannot fill APC (vllm-project#13949)

* [Bugfix][V1] Handle MLA in kv_cache_interface (vllm-project#14462)

Signed-off-by: Tyler Michael Smith <[email protected]>

* Revert "[Perf] Reduce MLA CPU overheads in V1 (vllm-project#14384)" (vllm-project#14471)

* [Bugfix][Disaggregated] Add a check in send_kv_caches_and_hidden_states and fix the reshape of the KVCache (vllm-project#14369)

Signed-off-by: Mathis Felardos <[email protected]>

* [MISC][V1] Register process killing handler only in the main thread (vllm-project#14380)

Signed-off-by: Cody Yu <[email protected]>

* [core] add `extra_args` to `SamplingParams` (vllm-project#13300)

Signed-off-by: Aviv Keshet <[email protected]>

* [CI/Build] refactor: set timezone of container to UTC (vllm-project#12888)

Signed-off-by: Roger Meier <[email protected]>

* Default to `generation_config` from model (vllm-project#12622)

Signed-off-by: Harry Mellor <[email protected]>

* [Doc]add doc for Qwen models tool calling (vllm-project#14478)

Signed-off-by: WangErXiao <[email protected]>

* [Doc] Added QwQ-32B to the supported models list in the reasoning out… (vllm-project#14479)

Signed-off-by: WangErXiao <[email protected]>

* [Bugfix] Make the deviceprofiler include LoRA memory. (vllm-project#14469)

Signed-off-by: Jee Jee Li <[email protected]>

* Add training doc signposting to TRL (vllm-project#14439)

Signed-off-by: Harry Mellor <[email protected]>

* [Build/BugFix] Fix hopper 12.8 build (vllm-project#14354)

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>

* Add RLHF document (vllm-project#14482)

Signed-off-by: Harry Mellor <[email protected]>

* [CI/Build] Use a fixed seed to avoid flaky tests (vllm-project#14480)

Signed-off-by: DarkLight1337 <[email protected]>

* [V1] TPU - Add tensor parallel support via Ray (vllm-project#13618)

Signed-off-by: Alexander Matveev <[email protected]>

* [VLM] Add TP support for Phi-4-MM (vllm-project#14453)

Signed-off-by: Isotr0py <[email protected]>

* [Misc] add `use_tqdm_on_load` to reduce logs (vllm-project#14407)

Signed-off-by: Aaron Pham <[email protected]>

* [V1][Core] Fix memory issue with logits & sampling (vllm-project#13776)

Signed-off-by: Roger Wang <[email protected]>

* [benchmarks] Add option to use unique jsonschema for each request (vllm-project#14457)

Signed-off-by: Russell Bryant <[email protected]>

* [Misc] Don't run ruff at all on 3rd party libs (vllm-project#14493)

Signed-off-by: DarkLight1337 <[email protected]>

* Move requirements into their own directory (vllm-project#12547)

Signed-off-by: Harry Mellor <[email protected]>

* [Bugfix] DeepSeek Accuracy (vllm-project#14476)

Signed-off-by: Lucas Wilkinson <[email protected]>

* [Bugfix] Fix profiling OOM and decouple encoder multimodal profiling (vllm-project#14361)

Signed-off-by: Isotr0py <[email protected]>

* Update CODEOWNERS for structured output (vllm-project#14496)

Signed-off-by: Russell Bryant <[email protected]>

* [Misc] Upgrade to Python 3.9 typing for additional directories (vllm-project#14492)

Signed-off-by: DarkLight1337 <[email protected]>

* [V1] Support bad_words in sampler (vllm-project#13376)

Signed-off-by: 22quinn <[email protected]>
Co-authored-by: Nick Hill <[email protected]>

* Revert "[V1][Core] Fix memory issue with logits & sampling" (vllm-project#14504)

Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>

* [Attention] Default to FlashMLA backend for MLA (vllm-project#14451)

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>

* [V1][TPU] Remove unnecessary padding for running on TPU. (vllm-project#14467)

* [Feat] Support chunked prefill for LMCache connector (vllm-project#14505)

Signed-off-by: YaoJiayi <[email protected]>

* [Bugfix] Fix tqdm progress bar when SamplingParams.n > 1 (vllm-project#12428)

Signed-off-by: Yuchen Yan <[email protected]>

* [Bugfix] Revert QKVCrossParallelLinear usage in Mllama to keep BNB quantization work (vllm-project#14498)

Signed-off-by: Isotr0py <[email protected]>

* [Hardware][TPU] Fix the recompiling issue in logits processor after warmup (vllm-project#14510)

Signed-off-by: Chengji Yao <[email protected]>

* [Misc] Ensure out-of-tree quantization method recognize by cli args (vllm-project#14328)

Signed-off-by: liuyanyi <[email protected]>

* [Bugfix] Wrong requirements path - rocm (vllm-project#14527)

Signed-off-by: Martin Hoyer <[email protected]>

* [Feature] Consolidate performance benchmark datasets (vllm-project#14036)

Signed-off-by: Jennifer Zhao <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Jennifer Zhao <[email protected]>
Co-authored-by: Roger Wang <[email protected]>

* [Misc] Add log information for handle_process_request. (vllm-project#14130)

Signed-off-by: chaunceyjiang <[email protected]>

* [Docs] Mention `model_impl` arg when explaining Transformers fallback (vllm-project#14552)

Signed-off-by: Harry Mellor <[email protected]>

* [Frontend] support image embeds (vllm-project#13955)

Signed-off-by: chaunceyjiang <[email protected]>

* [Kernel] Add more dtype support for GGUF kernels (vllm-project#14043)

Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>

* [Doc] Update PaliGemma note to a warning (vllm-project#14565)

Signed-off-by: DarkLight1337 <[email protected]>

* V1 rocm support (#469)

* Initial commit for V1 successfull compilation

* Small improvement for linear

* Small improvement for linear

* making use of forward_cuda for all except ROPE in llama

---------

Co-authored-by: maleksan85 <[email protected]>

* nightly_fixed_aiter_integration_final_20250305 README update (#470)

* nightly_fixed_aiter_integration_final_20250305 README update (perf results only)

* Update Docker Manifest git hash

* Update Docker Manifest and added nightly_fixed_aiter_integration_final_20250305

* some more updates

* Update AITER section with example

* Updated AITER command with larger batch size and model name

* Fixing typo

* Removed --max-model-len in AITER command

* Updating AITER instructions

* typo

* Another typo

* Whitespace

* modifying whats new section

* Another typo

---------

Co-authored-by: arakowsk-amd <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>

---------

Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Mark McLoughlin <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Xiongfei Wei <[email protected]>
Signed-off-by: Travis Johnson <[email protected]>
Signed-off-by: Divakar Verma <[email protected]>
Signed-off-by: Cody Yu <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Siyuan Liu <[email protected]>
Signed-off-by: KuntaiDu <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Sage Moore <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: Nishidha Panpaliya <[email protected]>
Signed-off-by: dangshunya <[email protected]>
Signed-off-by: Benjamin Chislett <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Iacopo Poli <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Daivid Savernin-Frenk <[email protected]>
Signed-off-by: simon-mo <[email protected]>
Signed-off-by: vincent-4 <[email protected]>
Signed-off-by: Brayden Zhong <[email protected]>
Signed-off-by: pyc96 <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Russell Bryant <[email protected]>
Signed-off-by: Yuan Tang <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Kyle Huang <[email protected]>
Signed-off-by: Linkun Chen <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: liuyanyi <[email protected]>
Signed-off-by: courage17340 <[email protected]>
Signed-off-by: Jitse Klomp <[email protected]>
Signed-off-by: Dilip Gowda Bhagavan <[email protected]>
Signed-off-by: Rishika Kedia <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: ZhongYingMatrix <[email protected]>
Signed-off-by: Himanshu Jaju <[email protected]>
Signed-off-by: Chengji Yao <[email protected]>
Signed-off-by: luka <[email protected]>
Signed-off-by: Ilya Lavrenov <[email protected]>
Signed-off-by: vincent-pli <[email protected]>
Signed-off-by: மனோஜ்குமார் பழனிச்சாமி <[email protected]>
Signed-off-by: root <[email protected]>
Signed-off-by: Aleksandr Malyshev <[email protected]>
Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: Jeremy Arnold <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
Signed-off-by: Yarong Mu <[email protected]>
Signed-off-by: Mathis Felardos <[email protected]>
Signed-off-by: Aviv Keshet <[email protected]>
Signed-off-by: Roger Meier <[email protected]>
Signed-off-by: WangErXiao <[email protected]>
Signed-off-by: Alexander Matveev <[email protected]>
Signed-off-by: 22quinn <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
Signed-off-by: Yuchen Yan <[email protected]>
Signed-off-by: Martin Hoyer <[email protected]>
Signed-off-by: Jennifer Zhao <[email protected]>
Signed-off-by: chaunceyjiang <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Mark McLoughlin <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
Co-authored-by: mgoin <[email protected]>
Co-authored-by: iefgnoix <[email protected]>
Co-authored-by: Travis Johnson <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Divakar Verma <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
Co-authored-by: Rui Qiao <[email protected]>
Co-authored-by: Zhanwen Chen <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: lkchen <[email protected]>
Co-authored-by: kushanam <[email protected]>
Co-authored-by: Siyuan Liu <[email protected]>
Co-authored-by: Kuntai Du <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Sage Moore <[email protected]>
Co-authored-by: Nishidha <[email protected]>
Co-authored-by: rainkert <[email protected]>
Co-authored-by: dangshunya <[email protected]>
Co-authored-by: Congcong Chen <[email protected]>
Co-authored-by: Benjamin Chislett <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: Lu Fang <[email protected]>
Co-authored-by: Iacopo Poli <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Zhe Zhang <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Co-authored-by: DaividFrank <[email protected]>
Co-authored-by: Simon Mo <[email protected]>
Co-authored-by: Vincent <[email protected]>
Co-authored-by: Brayden Zhong <[email protected]>
Co-authored-by: Ye Cao <[email protected]>
Co-authored-by: Serena <[email protected]>
Co-authored-by: pyc96 <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Ying Zhong <[email protected]>
Co-authored-by: Russell Bryant <[email protected]>
Co-authored-by: Yuan Tang <[email protected]>
Co-authored-by: Ce Gao <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>
Co-authored-by: Pavani Majety <[email protected]>
Co-authored-by: kYLe <[email protected]>
Co-authored-by: NickLucche <[email protected]>
Co-authored-by: Yanyi Liu <[email protected]>
Co-authored-by: Irina Yuryeva <[email protected]>
Co-authored-by: Thomas Parnell <[email protected]>
Co-authored-by: courage17340 <[email protected]>
Co-authored-by: Jitse Klomp <[email protected]>
Co-authored-by: Dilip Gowda Bhagavan <[email protected]>
Co-authored-by: Rishika Kedia <[email protected]>
Co-authored-by: Burkhard Ringlein <[email protected]>
Co-authored-by: Jan van Lunteren <[email protected]>
Co-authored-by: Himanshu Jaju <[email protected]>
Co-authored-by: Chengji Yao <[email protected]>
Co-authored-by: Daniel Li <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Co-authored-by: Ilya Lavrenov <[email protected]>
Co-authored-by: Peng Li <[email protected]>
Co-authored-by: மனோஜ்குமார் பழனிச்சாமி <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: Aaron Pham <[email protected]>
Co-authored-by: York-RDWang <[email protected]>
Co-authored-by: Jeremy Arnold <[email protected]>
Co-authored-by: Jinzhen Lin <[email protected]>
Co-authored-by: yarongmu-google <[email protected]>
Co-authored-by: afeldman-nm <[email protected]>
Co-authored-by: Mathis Felardos <[email protected]>
Co-authored-by: Aviv Keshet <[email protected]>
Co-authored-by: Roger Meier <[email protected]>
Co-authored-by: Robin <[email protected]>
Co-authored-by: Alexander Matveev <[email protected]>
Co-authored-by: 22quinn <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: Jiayi Yao <[email protected]>
Co-authored-by: Yuchen Yan <[email protected]>
Co-authored-by: Martin Hoyer <[email protected]>
Co-authored-by: Jennifer Zhao <[email protected]>
Co-authored-by: Jennifer Zhao <[email protected]>
Co-authored-by: Chauncey <[email protected]>
Co-authored-by: Szymon Ożóg <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Mcirino1 <[email protected]>
Co-authored-by: arakowsk-amd <[email protected]>
captainzmc pushed a commit to captainzmc/vllm that referenced this pull request Mar 12, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build ready ONLY add when PR is ready to merge/full CI is needed tpu Related to Google TPUs v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants