diff --git a/tensorrt_llm/scaffolding/task.py b/tensorrt_llm/scaffolding/task.py
index 92a278ebfb0..4900a650db3 100644
--- a/tensorrt_llm/scaffolding/task.py
+++ b/tensorrt_llm/scaffolding/task.py
@@ -1,9 +1,11 @@
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union

 import torch

+from tensorrt_llm.serve.openai_protocol import StreamOptions
+

 class ScaffoldingOutput:

@@ -37,10 +39,28 @@ class GenerationTask(Task):
     skip_tokenizer: bool = False
     skip_detokenizer: bool = False

-    # sampling params
+    # sampling params for OpenAI
+    # Ordered as in the official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/completions/create
+    # The special case is `num_logprobs`: its original name is `logprobs`, but that conflicts with the result field
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    num_logprobs: Optional[int] = None
     max_tokens: Optional[int] = 2048
+    n: int = 1
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = field(default_factory=list)
+    stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
+    suffix: Optional[str] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None
+    user: Optional[str] = None
+
+    # sampling params
     top_k: Optional[int] = None

     return_context_logits: Optional[bool] = False
diff --git a/tensorrt_llm/scaffolding/worker.py b/tensorrt_llm/scaffolding/worker.py
index 69086392648..667fb2614d6 100644
--- a/tensorrt_llm/scaffolding/worker.py
+++ b/tensorrt_llm/scaffolding/worker.py
@@ -73,9 +73,25 @@ def convert_task_params(self, task: GenerationTask):
             "model": self.model,
             "prompt": task.input_str,
         }
+        add_param_if_not_none(params, "best_of", [task.best_of])
+        add_param_if_not_none(params, "echo", [task.echo])
+        add_param_if_not_none(params, "frequency_penalty",
+                              [task.frequency_penalty])
+        add_param_if_not_none(params, "logit_bias", [task.logit_bias])
+        add_param_if_not_none(params, "logprobs", [task.num_logprobs])
         add_param_if_not_none(params, "max_tokens", [task.max_tokens])
+        add_param_if_not_none(params, "n", [task.n])
+        add_param_if_not_none(params, "presence_penalty",
+                              [task.presence_penalty])
+        add_param_if_not_none(params, "seed", [task.seed])
+        add_param_if_not_none(params, "stop", [task.stop])
+        add_param_if_not_none(params, "stream", [task.stream])
+        add_param_if_not_none(params, "stream_options", [task.stream_options])
+        add_param_if_not_none(params, "suffix", [task.suffix])
         add_param_if_not_none(params, "temperature", [task.temperature])
         add_param_if_not_none(params, "top_p", [task.top_p])
+        add_param_if_not_none(params, "user", [task.user])
+
         return params

     def fill_generation_task_with_response(self, task: GenerationTask,
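
Reviewer note: below is a minimal sketch of how the new fields travel from GenerationTask into the request dict. The body of add_param_if_not_none is not part of this diff, so the version here is a hypothetical stand-in that assumes the helper drops None values and unwraps the single-element list it receives; the model name and prompt are placeholders.

from typing import Any, Dict, List

def add_param_if_not_none(params: Dict[str, Any], name: str,
                          values: List[Any]) -> None:
    # Hypothetical stand-in for the worker.py helper (body not in this diff):
    # add `name` to the request only when the wrapped value is not None.
    if values[0] is not None:
        params[name] = values[0]

params: Dict[str, Any] = {"model": "placeholder-model", "prompt": "2 + 2 ="}
add_param_if_not_none(params, "seed", [42])
add_param_if_not_none(params, "stop", [["</answer>"]])
add_param_if_not_none(params, "logprobs", [None])  # num_logprobs unset -> key omitted

assert params == {
    "model": "placeholder-model",
    "prompt": "2 + 2 =",
    "seed": 42,
    "stop": ["</answer>"],
}

Note that at request time the task's num_logprobs field maps back to the API's logprobs key, matching the rename called out in the task.py comment.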