From a3748de6d444900600c6cffd9349e7529bdf4df5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Olejniczak?=
Date: Fri, 26 Sep 2025 13:40:51 +0300
Subject: [PATCH] [Core] Fix torch.dynamo compatibility for Qwen models
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Replace NumPy-based make_tensor_with_pad with a pure-PyTorch implementation
- Add None-safety to the penalty application functions
- Fix max() builtin usages that break torch.dynamo compilation
- Enable Qwen2.5-14B-Instruct to run successfully on vllm-gaudi

Fixes compatibility issues where torch.dynamo guard failures occurred
due to dispatch key set mismatches, and where AttributeErrors were
raised when applying repetition penalties to logits.

Signed-off-by: Paweł Olejniczak
---
 vllm/model_executor/layers/utils.py | 30 ++++++++++++++++++++++--------
 vllm/utils/__init__.py              | 25 +++++++++++++++++--------
 vllm/v1/sample/sampler.py           |  1 -
 3 files changed, 39 insertions(+), 17 deletions(-)

diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index 96dd58c0e4d2..fa00e8495c06 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -37,6 +37,8 @@ def get_token_bin_counts_and_mask(
     vocab_size: int,
     num_seqs: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    if tokens is None:
+        return None, None
     # Compute the bin counts for the tokens.
     # vocab_size + 1 for padding.
     bin_counts = torch.zeros((num_seqs, vocab_size + 1),
@@ -73,15 +75,27 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
     output_bin_counts, output_mask = get_token_bin_counts_and_mask(
         output_tokens_tensor, vocab_size, num_seqs)
 
-    # Apply repetition penalties as a custom op
-    from vllm._custom_ops import apply_repetition_penalties
-    apply_repetition_penalties(logits, prompt_mask, output_mask,
-                               repetition_penalties)
+    if prompt_mask is not None or output_mask is not None:
+        from vllm._custom_ops import apply_repetition_penalties
+
+        if prompt_mask is None:
+            prompt_mask = torch.zeros((num_seqs, vocab_size),
+                                      dtype=torch.bool,
+                                      device=logits.device)
+        if output_mask is None:
+            output_mask = torch.zeros((num_seqs, vocab_size),
+                                      dtype=torch.bool,
+                                      device=logits.device)
+
+        apply_repetition_penalties(logits, prompt_mask, output_mask,
+                                   repetition_penalties)
+
+    if output_bin_counts is not None:
+        logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
+
+    if output_mask is not None:
+        logits -= presence_penalties.unsqueeze(dim=1) * output_mask
 
-    # We follow the definition in OpenAI API.
-    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
     return logits
 
 
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index c502a69ea500..f5021039bbf2 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -1175,7 +1175,7 @@ def make_ndarray_with_pad(
     """
     if max_len is None:
         # Unlike for most functions, map is faster than a genexpr over `len`
-        max_len = max(map(len, x), default=0)
+        max_len = max(map(len, x)) if x else 0
 
     padded_x = np.full((len(x), max_len), pad, dtype=dtype)
     for ind, blocktb in enumerate(x):
@@ -1196,18 +1196,27 @@ def make_tensor_with_pad(
 ) -> torch.Tensor:
     """
     Make a padded tensor from 2D inputs.
-
     The padding is applied to the end of each inner list until it
     reaches `max_len`.
""" - np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] - padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) + if max_len is None: + max_len = max(len(row) for row in x) if x else 0 + + padded_tensor = torch.full((len(x), max_len), + fill_value=pad, + dtype=dtype, + device=device) + + for i, row in enumerate(x): + row_len = len(row) + if row_len > 0: + row_tensor = torch.as_tensor(row, dtype=dtype, device=device) + padded_tensor[i, :row_len] = row_tensor - tensor = torch.from_numpy(padded_x).to(device) - if pin_memory: - tensor = tensor.pin_memory() + if pin_memory and padded_tensor.device.type == 'cpu': + return padded_tensor.pin_memory() - return tensor + return padded_tensor def async_tensor_h2d( diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 83ea766b1b4a..381f67c0ca23 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -250,7 +250,6 @@ def apply_penalties( sampling_metadata: SamplingMetadata, ) -> torch.Tensor: if not sampling_metadata.no_penalties: - assert sampling_metadata.prompt_token_ids is not None logits = apply_all_penalties( logits, sampling_metadata.prompt_token_ids,