diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index d2ee22a214c..9726b69e864 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -709,6 +709,9 @@ def _process_requests(self,
                 batched_strategy = strategies[0]
             else:
                 batched_strategy = None
+        else:
+            assert len(set(strategies)) == 1, "mixed sampler is not enabled"
+            batched_strategy = strategies[0]
         generator = self.get_generator(raw_logits.device)
         if batched_strategy is not None:
             logits = raw_logits[:sum_steps]
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 3ba66ab5fc0..abf202f31fb 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -220,13 +220,12 @@ class CompletionRequest(OpenAIBaseModel):
     stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
     temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
+    top_p: Optional[float] = Field(None)
     user: Optional[str] = None
     lora_request: Optional[LoRARequest] = None
 
     # doc: begin-completion-sampling-params
     use_beam_search: bool = False
-    top_k: int = 0
     top_p_min: float = 0.0
     min_p: float = 0.0
     repetition_penalty: float = 1.0
@@ -279,7 +278,6 @@ def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams:
 
             # completion-sampling-params
             use_beam_search=self.use_beam_search,
-            top_k=self.top_k,
             top_p_min=self.top_p_min if self.top_p_min > 0 else None,
             min_p=self.min_p,
             repetition_penalty=self.repetition_penalty,
@@ -510,7 +508,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
+    top_p: Optional[float] = Field(None)
     tools: Optional[List[ChatCompletionToolsParam]] = None
     tool_choice: Optional[Union[Literal["none", "auto"],
                                 ChatCompletionNamedToolChoiceParam]] = "none"
@@ -527,7 +525,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # doc: begin-chat-completion-sampling-params
     best_of: Optional[int] = None
     use_beam_search: bool = False
-    top_k: int = 0
     top_p_min: float = 0.0
     min_p: float = 0.0
     repetition_penalty: float = 1.0
@@ -618,7 +615,6 @@ def to_sampling_params(self,
             # chat-completion-sampling-params
             best_of=self.best_of,
             use_beam_search=self.use_beam_search,
-            top_k=self.top_k,
             top_p=self.top_p,
             top_p_min=self.top_p_min if self.top_p_min > 0 else None,
             min_p=self.min_p,
diff --git a/tests/unittest/_torch/speculative/test_draft_target.py b/tests/unittest/_torch/speculative/test_draft_target.py
index 9aaa81e8375..f21b36bb61c 100644
--- a/tests/unittest/_torch/speculative/test_draft_target.py
+++ b/tests/unittest/_torch/speculative/test_draft_target.py
@@ -14,10 +14,12 @@
 from utils.util import similar
 
 
-@pytest.mark.parametrize("use_cuda_graph,attn_backend",
-                         [[False, "TRTLLM"], [True, "TRTLLM"]])
+@pytest.mark.parametrize(
+    "use_cuda_graph,attn_backend,use_greedy_sampling",
+    [[False, "TRTLLM", True], [True, "TRTLLM", True], [True, "TRTLLM", False]])
 @pytest.mark.high_cuda_memory
-def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):
+def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str,
+                            use_greedy_sampling: bool):
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 60:
         pytest.skip("Not enough memory to load target model")
@@ -52,7 +54,10 @@ def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):
         "The capital of France is",
         "The president of the United States is",
     ]
-    sampling_params = SamplingParams(max_tokens=32)
+
+    sampling_params = SamplingParams(
+        max_tokens=32) if use_greedy_sampling else SamplingParams(
+            max_tokens=32, top_p=0.9, temperature=1.0)
 
     llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
     results_spec = llm_spec.generate(prompts, sampling_params)
@@ -66,7 +71,9 @@ def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):
 
     for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
         # The spec decode algorithm currently guarantees identical results
-        assert similar(text_spec, text_ref)
+        # Skip the reference check for non-greedy sampling as the output is not deterministic
+        if use_greedy_sampling:
+            assert similar(text_spec, text_ref)
 
 
 if __name__ == "__main__":
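For context, a minimal sketch of the greedy vs. non-greedy switch the parametrized test above exercises. The `make_sampling_params` helper is hypothetical and the `from tensorrt_llm import SamplingParams` import path is an assumption for illustration; the parameter values (`max_tokens=32`, `top_p=0.9`, `temperature=1.0`) come straight from the patch.

# Illustrative sketch only; not part of the patch.
from tensorrt_llm import SamplingParams  # assumed import path

def make_sampling_params(use_greedy_sampling: bool) -> SamplingParams:
    if use_greedy_sampling:
        # Default decoding path: no top_p/temperature overrides.
        return SamplingParams(max_tokens=32)
    # Non-greedy case added by this change: top-p (nucleus) sampling,
    # whose output is not deterministic, so the test skips the reference check.
    return SamplingParams(max_tokens=32, top_p=0.9, temperature=1.0)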