diff --git a/tensorrt_llm/_torch/pyexecutor/finish_reason.py b/tensorrt_llm/_torch/pyexecutor/finish_reason.py
new file mode 100644
index 00000000000..6ed723fd94a
--- /dev/null
+++ b/tensorrt_llm/_torch/pyexecutor/finish_reason.py
@@ -0,0 +1,87 @@
+from tensorrt_llm.bindings.executor import FinishReason
+
+
+class FinishedState:
+    # State flags
+    FINISHED_EOS = 1 << 0
+    FINISHED_STOP_WORDS = 1 << 1
+    FINISHED_MAX_LENGTH = 1 << 2
+    FINISHED = FINISHED_EOS | FINISHED_STOP_WORDS | FINISHED_MAX_LENGTH
+    SKIP_DECODING = 1 << 3
+
+    def __init__(self, state=0):
+        self._state = state
+
+    @classmethod
+    def empty(cls):
+        return cls(0)
+
+    @classmethod
+    def finished(cls):
+        return cls(cls.FINISHED)
+
+    @classmethod
+    def skip_decoding(cls):
+        return cls(cls.SKIP_DECODING)
+
+    @classmethod
+    def finished_eos(cls):
+        return cls(cls.FINISHED_EOS)
+
+    @classmethod
+    def finished_max_length(cls):
+        return cls(cls.FINISHED_MAX_LENGTH)
+
+    @classmethod
+    def finished_stop_words(cls):
+        return cls(cls.FINISHED_STOP_WORDS)
+
+    def set_finished_eos(self):
+        self._state |= self.FINISHED_EOS
+
+    @property
+    def is_finished_eos(self):
+        return self._any_bit_set(self.FINISHED_EOS)
+
+    def set_finished_stop_words(self):
+        self._state |= self.FINISHED_STOP_WORDS
+
+    @property
+    def is_finished_stop_words(self):
+        return self._any_bit_set(self.FINISHED_STOP_WORDS)
+
+    def set_finished_max_length(self):
+        self._state |= self.FINISHED_MAX_LENGTH
+
+    @property
+    def is_finished_max_length(self):
+        return self._any_bit_set(self.FINISHED_MAX_LENGTH)
+
+    def set_finished(self):
+        self._state |= self.FINISHED
+
+    @property
+    def is_finished(self):
+        return self._any_bit_set(self.FINISHED)
+
+    def set_skip_decoding(self):
+        self._state |= self.SKIP_DECODING
+
+    @property
+    def is_skip_decoding(self):
+        return self._any_bit_set(self.SKIP_DECODING)
+
+    def to_finish_reason(self):
+        if self.is_finished_eos:
+            return FinishReason.END_ID
+        if self.is_finished_stop_words:
+            return FinishReason.STOP_WORDS
+        if self.is_finished_max_length:
+            return FinishReason.LENGTH
+        return FinishReason.NOT_FINISHED
+
+    def to_underlying(self):
+        return self._state
+
+    def _any_bit_set(self, bits):
+        return (self._state & bits) != 0
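FinishedState ports the sampler's packed per-beam finish flags to Python: one bit each for end-id, stop-word, and max-length termination, plus a skip-decoding bit, with to_finish_reason() collapsing the bit set into the single FinishReason value the executor API reports. A standalone sketch of that collapse and its precedence order (FinishReasonStub is a hypothetical stand-in for tensorrt_llm.bindings.executor.FinishReason, used only so the sketch runs without the bindings):

```python
from enum import Enum


class FinishReasonStub(Enum):
    # Stand-in for tensorrt_llm.bindings.executor.FinishReason; only the
    # members used by the mapping below are modeled here.
    NOT_FINISHED = 0
    END_ID = 1
    STOP_WORDS = 2
    LENGTH = 3


# Bit flags, matching FinishedState above.
FINISHED_EOS = 1 << 0
FINISHED_STOP_WORDS = 1 << 1
FINISHED_MAX_LENGTH = 1 << 2


def to_finish_reason(state: int) -> FinishReasonStub:
    # Same precedence as FinishedState.to_finish_reason(): EOS wins over
    # stop words, which win over max length.
    if state & FINISHED_EOS:
        return FinishReasonStub.END_ID
    if state & FINISHED_STOP_WORDS:
        return FinishReasonStub.STOP_WORDS
    if state & FINISHED_MAX_LENGTH:
        return FinishReasonStub.LENGTH
    return FinishReasonStub.NOT_FINISHED


# A beam that emits an end token exactly at the length cap sets two bits
# but still reports END_ID.
assert to_finish_reason(FINISHED_EOS | FINISHED_MAX_LENGTH) == FinishReasonStub.END_ID
assert to_finish_reason(0) == FinishReasonStub.NOT_FINISHED
```

This is also why the sampler.py change below stops calling FinishReason(raw_value) directly: the raw value is a bit set, so combined states (or the skip-decoding bit, 1 << 3) would map to wrong or out-of-range enum members.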
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index edab51fb7f5..73d8efdb524 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1535,6 +1535,9 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
         self.resource_manager.resource_managers[
             ResourceManagerType.KV_CACHE_MANAGER].prepare_resources(
                 disagg_gen_init_to_prepare)
+        self.resource_manager.resource_managers[
+            ResourceManagerType.SEQ_SLOT_MANAGER].prepare_resources(
+                disagg_gen_init_to_prepare)
 
         # Trigger KV cache exchange for new disagg_gen_init_requests
         self._recv_disagg_gen_cache(fitting_disagg_gen_init_requests)
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 04a07dc502b..885dd0c47b6 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -23,6 +23,7 @@
 from tensorrt_llm.executor.result import Logprob
 from tensorrt_llm.mapping import Mapping
 
+from .finish_reason import FinishedState
 from .llm_request import LlmRequest, LlmRequestState
 from .scheduler import ScheduledRequests
 
@@ -648,6 +649,7 @@ def update_requests(self, state: SampleStateTRTLLM):
             for beam in range(beam_width):
                 seq_len = sequence_lengths_host_data[seq_slot * beam_width +
                                                      beam].item()
+                seq_len = seq_len + 1 if self.is_trt_overlap else seq_len
                 num_new_tokens[beam] = min(
                     num_generated_tokens,
                     seq_len - request.get_num_tokens(beam))
@@ -678,9 +680,10 @@ def update_requests(self, state: SampleStateTRTLLM):
                         state.host.cum_log_probs[seq_slot * beam_width +
                                                  beam].item())
 
-                finish_reason = finish_reasons_host[seq_slot * beam_width +
-                                                    beam].item()
-                request.set_finished_reason(FinishReason(finish_reason), beam)
+                finish_reason = FinishedState(
+                    finish_reasons_host[seq_slot * beam_width +
+                                        beam].item()).to_finish_reason()
+                request.set_finished_reason(finish_reason, beam)
 
             if request.py_return_log_probs:
                 request.py_result.append_log_probs([log_probs], cum_log_probs)
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml
new file mode 100644
index 00000000000..6dc423b61d6
--- /dev/null
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml
@@ -0,0 +1,35 @@
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+hostname: localhost
+port: 8000
+backend: "pytorch"
+free_gpu_memory_fraction: 0.2
+context_servers:
+  num_instances: 1
+  max_batch_size: 1
+  max_num_tokens: 3000
+  max_seq_len: 4096
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  enable_trtllm_sampler: True
+  kv_cache_config:
+    free_gpu_memory_fraction: 0.2
+    enable_partial_reuse: False
+  use_cuda_graph: False
+  disable_overlap_scheduler: True
+  urls:
+      - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  max_batch_size: 256
+  max_num_tokens: 4096
+  max_seq_len: 4096
+  enable_trtllm_sampler: True
+  kv_cache_config:
+    free_gpu_memory_fraction: 0.2
+    enable_partial_reuse: False
+  use_cuda_graph: False
+  disable_overlap_scheduler: False
+  urls:
+      - "localhost:8002"
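The new config enables enable_trtllm_sampler on both server types and pairs a context server running with the overlap scheduler disabled against a generation server running with it enabled, which exercises the is_trt_overlap sequence-length adjustment above. A quick sanity check of those invariants (a sketch assuming PyYAML and a repo-root working directory):

```python
import yaml

# Load the new disaggregated test config and check the properties the
# trtllm_sampler test case relies on.
with open("tests/integration/defs/disaggregated/test_configs/"
          "disagg_config_trtllm_sampler.yaml") as f:
    cfg = yaml.safe_load(f)

# Both server pools opt into the TRTLLM sampler.
for section in ("context_servers", "generation_servers"):
    assert cfg[section]["enable_trtllm_sampler"] is True

# Only the generation side keeps the overlap scheduler on.
assert cfg["context_servers"]["disable_overlap_scheduler"] is True
assert cfg["generation_servers"]["disable_overlap_scheduler"] is False
```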
diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py
index df28c8d13b5..34987c286ce 100644
--- a/tests/integration/defs/disaggregated/test_disaggregated.py
+++ b/tests/integration/defs/disaggregated/test_disaggregated.py
@@ -50,6 +50,8 @@ def get_test_config(test_desc, example_dir, test_root):
         (2, f"{test_configs_root}/disagg_config_cuda_graph_padding.yaml"),
         "mixed": (2, f"{test_configs_root}/disagg_config_mixed.yaml"),
         "overlap": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
+        "trtllm_sampler":
+        (2, f"{test_configs_root}/disagg_config_trtllm_sampler.yaml"),
         "load_balance":
         (4, f"{test_configs_root}/disagg_config_load_balance.yaml"),
         "cache_aware_balance":
@@ -179,7 +181,7 @@ def run_disaggregated_test(example_dir,
                            poll_procs=[workers_proc, server_proc])
 
         # Run the chat completion endpoint test only for TinyLlama
-        if test_desc == "overlap":
+        if test_desc == "overlap" or test_desc == "trtllm_sampler":
             chat_client_cmd = client_cmd + [
                 '-e', 'chat', '-o', 'output_chat.json'
             ]
@@ -198,7 +200,7 @@ def run_disaggregated_test(example_dir,
         not_expected_strings = ["Berlin Berlin"]
 
         output_files = ['output.json', 'output_streaming.json']
-        if test_desc == "overlap":
+        if test_desc == "overlap" or test_desc == "trtllm_sampler":
             # Disable streaming chat completion for overlap test
             # due to bug
             output_files.extend(['output_chat.json'])
@@ -420,6 +422,26 @@ def test_disaggregated_overlap(disaggregated_test_root, llm_venv,
                            cwd=llm_venv.get_working_directory())
 
 
+@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
+                         indirect=True)
+def test_disaggregated_trtllm_sampler(disaggregated_test_root, llm_venv,
+                                      disaggregated_example_root,
+                                      llama_model_root):
+    src_dst_dict = {
+        llama_model_root:
+        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_test(disaggregated_example_root,
+                           "trtllm_sampler",
+                           env=llm_venv._new_env,
+                           cwd=llm_venv.get_working_directory())
+
+
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
 def test_disaggregated_load_balance(disaggregated_test_root, llm_venv,
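For reference, the new case can be run in isolation; a sketch of one way to do so from Python (the parametrized node id is an assumption derived from the llama_model_root value above and may differ locally):

```python
import pytest

# Equivalent to invoking pytest on the command line with this node id.
pytest.main([
    "tests/integration/defs/disaggregated/test_disaggregated.py"
    "::test_disaggregated_trtllm_sampler[TinyLlama-1.1B-Chat-v1.0]",
    "-v",
])
```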