From e12904140b615d36b8abe6b892e0b18b7d7ff7a5 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 28 Feb 2024 19:37:48 -0800 Subject: [PATCH 1/3] [Fix] Don't deep-copy LogitsProcessors when copying SamplingParams --- vllm/engine/llm_engine.py | 5 +++-- vllm/sampling_params.py | 9 +++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f0fd7efdef81..f0cb53f998bc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -483,8 +483,9 @@ def add_request( prompt_token_ids[:prefix_pos], lora_request.lora_int_id if lora_request else 0) if prefix_pos is not None else None - # Defensive copy of SamplingParams, which are used by the sampler - sampling_params = copy.deepcopy(sampling_params) + # Defensive copy of SamplingParams, which are used by the sampler, + # this doesn't deep-copy LogitsProcessor objects + sampling_params = sampling_params.copy() # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 51d39220ca9c..814298cd5866 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,4 +1,5 @@ """Sampling parameters for text generation.""" +import copy from enum import IntEnum from functools import cached_property from typing import Callable, List, Optional, Union @@ -237,6 +238,14 @@ def sampling_type(self) -> SamplingType: return SamplingType.RANDOM_SEED return SamplingType.RANDOM + def copy(self) -> "SamplingParams": + """ Deep copy excluding LogitsProcessor objects""" + logit_processor_refs = None if self.logits_processors is None else { + id(lp): lp + for lp in self.logits_processors + } + return copy.deepcopy(self, memo=logit_processor_refs) + def __repr__(self) -> str: return ( f"SamplingParams(n={self.n}, " From a49a5dfcb7c2e0b2df9d989c516e3802fd56e056 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 28 Feb 2024 21:21:37 -0800 Subject: [PATCH 2/3] Address review comments from @Yard1 --- vllm/engine/llm_engine.py | 2 +- vllm/sampling_params.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f0cb53f998bc..aec31c4ad60b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -485,7 +485,7 @@ def add_request( # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects - sampling_params = sampling_params.copy() + sampling_params = sampling_params.clone() # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 814298cd5866..34e7098bf377 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -238,8 +238,10 @@ def sampling_type(self) -> SamplingType: return SamplingType.RANDOM_SEED return SamplingType.RANDOM - def copy(self) -> "SamplingParams": - """ Deep copy excluding LogitsProcessor objects""" + def clone(self) -> "SamplingParams": + """Deep copy excluding LogitsProcessor objects, which may contain an + arbitrary/nontrivial amount of data.""" + logit_processor_refs = None if self.logits_processors is None else { id(lp): lp for lp in self.logits_processors From 6f5cd6f276aca3d83e71629440c30f70a4706ecc Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 29 Feb 2024 07:54:00 -0800 Subject: [PATCH 3/3] Add ref to original issue in docstring --- vllm/sampling_params.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 34e7098bf377..8103f3c2b24b 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -239,8 +239,12 @@ def sampling_type(self) -> SamplingType: return SamplingType.RANDOM def clone(self) -> "SamplingParams": - """Deep copy excluding LogitsProcessor objects, which may contain an - arbitrary/nontrivial amount of data.""" + """Deep copy excluding LogitsProcessor objects. + + LogitsProcessor objects are excluded because they may contain an + arbitrary, nontrivial amount of data. + See https://github.com/vllm-project/vllm/issues/3087 + """ logit_processor_refs = None if self.logits_processors is None else { id(lp): lp