 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 from pydantic import Field
 from typing_extensions import Annotated
@@ -165,24 +165,34 @@ def _maybe_get_lora( |
         raise ValueError(f"The model `{request.model}` does not exist.")

     def _validate_prompt_and_tokenize(
-            self,
-            request: Union[ChatCompletionRequest, CompletionRequest,
-                           EmbeddingRequest],
-            prompt: Optional[str] = None,
-            prompt_ids: Optional[List[int]] = None,
-            truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
-    ) -> Tuple[List[int], str]:
+            self,
+            request: Union[ChatCompletionRequest, CompletionRequest,
+                           EmbeddingRequest],
+            prompt: Optional[str] = None,
+            prompt_ids: Optional[List[int]] = None,
+            truncate_prompt_tokens: Optional[Annotated[int,
+                                                       Field(ge=1)]] = None,
+            add_special_tokens: bool = True) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")

         if prompt_ids is None:
-            tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
-                "truncation": True,
-                "max_length": truncate_prompt_tokens,
+            # When using OpenAIServingChat for chat completions, the
+            # special tokens (e.g., BOS) have already been added by the
+            # chat template. Therefore, we do not need to add them again.
+            # Set add_special_tokens to False to avoid adding the BOS tokens
+            # again.
+            tokenizer_kwargs: Dict[str, Any] = {
+                "add_special_tokens": add_special_tokens
             }
+            if truncate_prompt_tokens is not None:
+                tokenizer_kwargs.update({
+                    "truncation": True,
+                    "max_length": truncate_prompt_tokens,
+                })
             input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
         elif truncate_prompt_tokens is not None:
             input_ids = prompt_ids[-truncate_prompt_tokens:]
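Why the new add_special_tokens flag matters: when the prompt string was rendered by a chat template, that string typically already contains the BOS token, so tokenizing it with the tokenizer's defaults would prepend a second one. A minimal sketch of the failure mode this diff avoids, assuming a Hugging Face tokenizer whose chat template emits BOS (the model name below is illustrative, not part of this commit):

# Sketch of the double-BOS problem that add_special_tokens=False avoids.
# Assumes a chat template that already includes the BOS token; the
# model name is an illustrative assumption.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
messages = [{"role": "user", "content": "Hello!"}]

# The rendered chat prompt already starts with the BOS token ("<s>").
prompt = tokenizer.apply_chat_template(messages, tokenize=False)

# Default tokenization prepends BOS again, duplicating it.
with_default = tokenizer(prompt).input_ids
# add_special_tokens=False keeps only the BOS from the template.
without_special = tokenizer(prompt, add_special_tokens=False).input_ids

print(with_default.count(tokenizer.bos_token_id))     # 2 (duplicated BOS)
print(without_special.count(tokenizer.bos_token_id))  # 1

Keeping add_special_tokens=True as the default preserves the existing behavior for completion and embedding requests; only the chat path needs to pass False.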