3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -709,6 +709,9 @@ def _process_requests(self,
                 batched_strategy = strategies[0]
             else:
                 batched_strategy = None
+        else:
+            assert len(set(strategies)) == 1, "mixed sampler is not enabled"
+            batched_strategy = strategies[0]
         generator = self.get_generator(raw_logits.device)
         if batched_strategy is not None:
             logits = raw_logits[:sum_steps]
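For context, a minimal standalone sketch of the batching rule this hunk enforces, assuming the branch is gated on a mixed-sampler flag; `mixed_sampler_enabled` and the tuple-style `Strategy` are illustrative stand-ins, not the actual sampler types.

```python
# Hypothetical sketch, not the real TorchSampler code: when mixed sampling is off,
# every request in the batch must share one strategy, and that single strategy is
# then applied batch-wide (mirroring the new assert above).
from typing import Optional, Sequence, Tuple

Strategy = Tuple  # e.g. ("greedy",) or ("top_p", 0.9); placeholder for illustration

def pick_batched_strategy(strategies: Sequence[Strategy],
                          mixed_sampler_enabled: bool) -> Optional[Strategy]:
    if mixed_sampler_enabled:
        # A shared strategy can still be batched; otherwise fall back to
        # per-request sampling, signalled by None.
        return strategies[0] if len(set(strategies)) == 1 else None
    assert len(set(strategies)) == 1, "mixed sampler is not enabled"
    return strategies[0]

# Example: three greedy requests with the mixed sampler disabled.
print(pick_batched_strategy([("greedy",), ("greedy",), ("greedy",)], False))
```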
8 changes: 2 additions & 6 deletions tensorrt_llm/serve/openai_protocol.py
@@ -220,13 +220,12 @@ class CompletionRequest(OpenAIBaseModel):
     stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
     temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
+    top_p: Optional[float] = Field(None)
     user: Optional[str] = None
     lora_request: Optional[LoRARequest] = None

     # doc: begin-completion-sampling-params
     use_beam_search: bool = False
-    top_k: int = 0
     top_p_min: float = 0.0
     min_p: float = 0.0
     repetition_penalty: float = 1.0
@@ -279,7 +278,6 @@ def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams:

             # completion-sampling-params
             use_beam_search=self.use_beam_search,
-            top_k=self.top_k,
             top_p_min=self.top_p_min if self.top_p_min > 0 else None,
             min_p=self.min_p,
             repetition_penalty=self.repetition_penalty,
@@ -510,7 +508,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
+    top_p: Optional[float] = Field(None)
     tools: Optional[List[ChatCompletionToolsParam]] = None
     tool_choice: Optional[Union[Literal["none", "auto"],
                                 ChatCompletionNamedToolChoiceParam]] = "none"
@@ -527,7 +525,6 @@
     # doc: begin-chat-completion-sampling-params
     best_of: Optional[int] = None
     use_beam_search: bool = False
-    top_k: int = 0
     top_p_min: float = 0.0
     min_p: float = 0.0
     repetition_penalty: float = 1.0
@@ -618,7 +615,6 @@ def to_sampling_params(self,
             # chat-completion-sampling-params
             best_of=self.best_of,
             use_beam_search=self.use_beam_search,
-            top_k=self.top_k,
             top_p=self.top_p,
             top_p_min=self.top_p_min if self.top_p_min > 0 else None,
             min_p=self.min_p,
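The practical effect of swapping `1.0` for `Field(None)` is that an omitted `top_p` now reaches `to_sampling_params` as `None` instead of a forced `1.0`, so the backend default can apply (and `top_k` is no longer exposed on these requests at all). A minimal sketch with a throwaway Pydantic model, not the real `CompletionRequest`:

```python
# Illustrative only: a stand-in model showing how the default changes what the
# server sees when a client omits top_p.
from typing import Optional
from pydantic import BaseModel, Field

class OldStyleRequest(BaseModel):
    top_p: Optional[float] = 1.0          # previous default: always a concrete value

class NewStyleRequest(BaseModel):
    top_p: Optional[float] = Field(None)  # new default: "not provided"

print(OldStyleRequest().top_p)           # 1.0
print(NewStyleRequest().top_p)           # None -> backend can pick its own default
print(NewStyleRequest(top_p=0.9).top_p)  # 0.9 -> explicit values pass through unchanged
```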
17 changes: 12 additions & 5 deletions tests/unittest/_torch/speculative/test_draft_target.py
@@ -14,10 +14,12 @@
 from utils.util import similar


-@pytest.mark.parametrize("use_cuda_graph,attn_backend",
-                         [[False, "TRTLLM"], [True, "TRTLLM"]])
+@pytest.mark.parametrize(
+    "use_cuda_graph,attn_backend,use_greedy_sampling",
+    [[False, "TRTLLM", True], [True, "TRTLLM", True], [True, "TRTLLM", False]])
 @pytest.mark.high_cuda_memory
-def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):
+def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str,
+                            use_greedy_sampling: bool):
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 60:
         pytest.skip("Not enough memory to load target model")
@@ -52,7 +54,10 @@ def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):
         "The capital of France is",
         "The president of the United States is",
     ]
-    sampling_params = SamplingParams(max_tokens=32)
+
+    sampling_params = SamplingParams(
+        max_tokens=32) if use_greedy_sampling else SamplingParams(
+            max_tokens=32, top_p=0.9, temperature=1.0)

     llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
     results_spec = llm_spec.generate(prompts, sampling_params)
@@ -66,7 +71,9 @@ def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):

     for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
         # The spec decode algorithm currently guarantees identical results
-        assert similar(text_spec, text_ref)
+        # Skip the reference check for non-greedy sampling as the output is not deterministic
+        if use_greedy_sampling:
+            assert similar(text_spec, text_ref)


 if __name__ == "__main__":
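For reference, a small sketch of how the new parametrization resolves to sampling parameters and whether the similarity check runs; the values come from the diff, and the `tensorrt_llm` import path is assumed from the test's context.

```python
# Standalone illustration, not part of the test module: each parametrized case picks
# greedy or non-greedy SamplingParams, and only greedy cases keep the reference check.
from tensorrt_llm import SamplingParams  # import path assumed

cases = [(False, "TRTLLM", True), (True, "TRTLLM", True), (True, "TRTLLM", False)]
for use_cuda_graph, attn_backend, use_greedy_sampling in cases:
    sampling_params = (SamplingParams(max_tokens=32) if use_greedy_sampling else
                       SamplingParams(max_tokens=32, top_p=0.9, temperature=1.0))
    check_reference = use_greedy_sampling  # non-greedy output is not deterministic
    print(use_cuda_graph, attn_backend, use_greedy_sampling, check_reference)
```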