Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
040ea59
feat(scaffolding): support more parameters in openai worker
ccs96307 Jun 11, 2025
e4ce9de
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 11, 2025
8d3a825
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 12, 2025
c1b897e
refactor: Revert param conversion to explicit calls
ccs96307 Jun 12, 2025
f5f3b13
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 12, 2025
cb7dd61
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 12, 2025
b5a2e30
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 13, 2025
d917766
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 13, 2025
a4dac83
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 15, 2025
82233a4
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 16, 2025
66d0206
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 17, 2025
771f64b
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 18, 2025
838f60d
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 19, 2025
a72aaa1
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 20, 2025
45631a9
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 21, 2025
13c5537
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 22, 2025
e4ec130
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jun 26, 2025
7ce90f1
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jul 1, 2025
5f312df
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jul 3, 2025
8957a45
Merge branch 'main' into add-scaffolding-openai-worker-parameters
ccs96307 Jul 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions tensorrt_llm/scaffolding/task.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import torch

from tensorrt_llm.serve.openai_protocol import StreamOptions


class ScaffoldingOutput:

Expand Down Expand Up @@ -37,10 +39,28 @@ class GenerationTask(Task):
skip_tokenizer: bool = False
skip_detokenizer: bool = False

# sampling params
# sampling params for openai
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
# The special case is `num_logprobs`: its original name is `logprobs`, but it conflicts with the result field
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
num_logprobs: Optional[int] = None
max_tokens: Optional[int] = 2048
n: int = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
user: Optional[str] = None

# sampling params
top_k: Optional[int] = None
return_context_logits: Optional[bool] = False

Expand Down
16 changes: 16 additions & 0 deletions tensorrt_llm/scaffolding/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,25 @@ def convert_task_params(self, task: GenerationTask):
"model": self.model,
"prompt": task.input_str,
}
add_param_if_not_none(params, "best_of", [task.best_of])
add_param_if_not_none(params, "echo", [task.echo])
add_param_if_not_none(params, "frequency_penalty",
[task.frequency_penalty])
add_param_if_not_none(params, "logit_bias", [task.logit_bias])
add_param_if_not_none(params, "logprobs", [task.num_logprobs])
add_param_if_not_none(params, "max_tokens", [task.max_tokens])
add_param_if_not_none(params, "n", [task.n])
add_param_if_not_none(params, "presence_penalty",
[task.presence_penalty])
add_param_if_not_none(params, "seed", [task.seed])
add_param_if_not_none(params, "stop", [task.stop])
add_param_if_not_none(params, "stream", [task.stream])
add_param_if_not_none(params, "stream_options", [task.stream_options])
add_param_if_not_none(params, "suffix", [task.suffix])
add_param_if_not_none(params, "temperature", [task.temperature])
add_param_if_not_none(params, "top_p", [task.top_p])
add_param_if_not_none(params, "user", [task.user])

return params

def fill_generation_task_with_response(self, task: GenerationTask,
Expand Down