From 7f9ccb2d53d9a14ffefcf26a7d444e3407fd36c6 Mon Sep 17 00:00:00 2001
From: Dan Lord
Date: Tue, 26 Sep 2023 12:28:28 -0700
Subject: [PATCH 1/4] Implement keep special tokens as sampling params.

---
 vllm/engine/llm_engine.py | 7 ++++---
 vllm/sampling_params.py   | 6 +++++-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index c8d7164d3b4c..247df42b58f6 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -387,7 +387,7 @@ def _process_sequence_group_samples(
                 child_seqs.append((parent, parent))

         for seq, _ in child_seqs:
-            self._decode_sequence(seq)
+            self._decode_sequence(seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)

         # Non-beam search case
@@ -621,7 +621,8 @@ def _log_system_stats(
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         self.last_logging_time = now

-    def _decode_sequence(self, seq: Sequence) -> None:
+    def _decode_sequence(self, seq: Sequence,
+                         sampling_params: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
@@ -630,7 +631,7 @@ def _decode_sequence(self, seq: Sequence) -> None:
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
-            skip_special_tokens=True,
+            skip_special_tokens=not sampling_params.keep_special_tokens,
         )
         if seq.tokens is None:
             seq.tokens = new_tokens
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 53bd743fce9d..62008de15731 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -60,6 +60,7 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
+        keep_special_tokens: Whether to keep special tokens in the output
     """

     def __init__(
@@ -79,6 +80,7 @@ def __init__(
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        keep_special_tokens: bool = False,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -103,6 +105,7 @@ def __init__(
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.keep_special_tokens = keep_special_tokens

         self._verify_args()
         if self.use_beam_search:
@@ -196,4 +199,5 @@ def __repr__(self) -> str:
                 f"stop={self.stop}, "
                 f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs}, "
+                f"keep_special_tokens={self.keep_special_tokens})")

From dc3a5e1c0f423bb83b102fa0e55b572eaf08bfc7 Mon Sep 17 00:00:00 2001
From: Dan Lord
Date: Tue, 26 Sep 2023 14:55:51 -0700
Subject: [PATCH 2/4] Pass new arg through entrypoints.

---
 vllm/entrypoints/openai/api_server.py | 2 ++
 vllm/entrypoints/openai/protocol.py   | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index d260396e47c4..199dc0c90f22 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -225,6 +225,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             top_k=request.top_k,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
+            keep_special_tokens=request.keep_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -426,6 +427,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            keep_special_tokens=request.keep_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 473400a7faf9..b21ec40dd4c1 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -71,6 +71,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    keep_special_tokens: Optional[bool] = False


 class CompletionRequest(BaseModel):
@@ -96,6 +97,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    keep_special_tokens: Optional[bool] = False


 class LogProbs(BaseModel):

From 8f34113da9e8ac337188d3e19c2424cc43a98f5b Mon Sep 17 00:00:00 2001
From: Dan Lord
Date: Wed, 27 Sep 2023 16:34:00 -0700
Subject: [PATCH 3/4] Change `keep_special_tokens` to `skip_special_tokens`
 and invert default logic.

---
 vllm/engine/llm_engine.py             | 2 +-
 vllm/entrypoints/openai/api_server.py | 4 ++--
 vllm/entrypoints/openai/protocol.py   | 4 ++--
 vllm/sampling_params.py               | 8 ++++----
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 247df42b58f6..9e39fe44da71 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -631,7 +631,7 @@ def _decode_sequence(self, seq: Sequence,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
-            skip_special_tokens=not sampling_params.keep_special_tokens,
+            skip_special_tokens=sampling_params.skip_special_tokens,
         )
         if seq.tokens is None:
             seq.tokens = new_tokens
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 199dc0c90f22..643dd06cb17d 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -225,7 +225,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             top_k=request.top_k,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
-            keep_special_tokens=request.keep_special_tokens,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -427,7 +427,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
-            keep_special_tokens=request.keep_special_tokens,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index b21ec40dd4c1..12b7453de819 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -71,7 +71,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
-    keep_special_tokens: Optional[bool] = False
+    skip_special_tokens: Optional[bool] = True


 class CompletionRequest(BaseModel):
@@ -97,7 +97,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
-    keep_special_tokens: Optional[bool] = False
+    skip_special_tokens: Optional[bool] = True


 class LogProbs(BaseModel):
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 62008de15731..54b55b058b40 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -60,7 +60,7 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
-        keep_special_tokens: Whether to keep special tokens in the output
+        skip_special_tokens: Whether to skip special tokens in the output. Defaults to true.
     """

     def __init__(
@@ -80,7 +80,7 @@ def __init__(
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
-        keep_special_tokens: bool = False,
+        skip_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -105,7 +105,7 @@ def __init__(
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
-        self.keep_special_tokens = keep_special_tokens
+        self.skip_special_tokens = skip_special_tokens

         self._verify_args()
         if self.use_beam_search:
@@ -200,4 +200,4 @@ def __repr__(self) -> str:
                 f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
                 f"logprobs={self.logprobs}, "
-                f"keep_special_tokens={self.keep_special_tokens})")
+                f"skip_special_tokens={self.skip_special_tokens})")

From d25e74407e0c9915c16afb8ba5f17633a3dd24cc Mon Sep 17 00:00:00 2001
From: Dan Lord
Date: Wed, 27 Sep 2023 17:14:52 -0700
Subject: [PATCH 4/4] Fix linter error.

---
 vllm/sampling_params.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 54b55b058b40..5206eb0b8c4d 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -60,7 +60,8 @@ class SamplingParams:
            tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
-        skip_special_tokens: Whether to skip special tokens in the output. Defaults to true.
+        skip_special_tokens: Whether to skip special tokens in the output.
+            Defaults to true.
     """

     def __init__(