87 changes: 87 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/finish_reason.py
@@ -0,0 +1,87 @@
from tensorrt_llm.bindings.executor import FinishReason


class FinishedState:
# State flags
FINISHED_EOS = 1 << 0
FINISHED_STOP_WORDS = 1 << 1
FINISHED_MAX_LENGTH = 1 << 2
FINISHED = FINISHED_EOS | FINISHED_STOP_WORDS | FINISHED_MAX_LENGTH
SKIP_DECODING = 1 << 3

def __init__(self, state=0):
self._state = state

@classmethod
def empty(cls):
return cls(0)

@classmethod
def finished(cls):
return cls(cls.FINISHED)

@classmethod
def skip_decoding(cls):
return cls(cls.SKIP_DECODING)

@classmethod
def finished_eos(cls):
return cls(cls.FINISHED_EOS)

@classmethod
def finished_max_length(cls):
return cls(cls.FINISHED_MAX_LENGTH)

@classmethod
def finished_stop_words(cls):
return cls(cls.FINISHED_STOP_WORDS)

def set_finished_eos(self):
self._state |= self.FINISHED_EOS

@property
def is_finished_eos(self):
return self._any_bit_set(self.FINISHED_EOS)

def set_finished_stop_words(self):
self._state |= self.FINISHED_STOP_WORDS

@property
def is_finished_stop_words(self):
return self._any_bit_set(self.FINISHED_STOP_WORDS)

def set_finished_max_length(self):
self._state |= self.FINISHED_MAX_LENGTH

@property
def is_finished_max_length(self):
return self._any_bit_set(self.FINISHED_MAX_LENGTH)

def set_finished(self):
self._state |= self.FINISHED

@property
def is_finished(self):
return self._any_bit_set(self.FINISHED)

def set_skip_decoding(self):
self._state |= self.SKIP_DECODING

@property
def is_skip_decoding(self):
return self._any_bit_set(self.SKIP_DECODING)

def to_finish_reason(self):
if self.is_finished_eos:
return FinishReason.END_ID
if self.is_finished_stop_words:
return FinishReason.STOP_WORDS
if self.is_finished_max_length:
return FinishReason.LENGTH
return FinishReason.NOT_FINISHED

def to_underlying(self):
return self._state

def _any_bit_set(self, bits):
return (self._state & bits) != 0
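
Below is a minimal usage sketch of the new `FinishedState` helper, assuming the module is importable as `tensorrt_llm._torch.pyexecutor.finish_reason` (the location of the file added above); the state values are illustrative and not taken from this PR.

```python
# Minimal sketch of the FinishedState bitmask helper defined above.
# Assumes the import path matches the new file's location; values are illustrative.
from tensorrt_llm.bindings.executor import FinishReason

from tensorrt_llm._torch.pyexecutor.finish_reason import FinishedState

state = FinishedState.empty()
assert not state.is_finished
assert state.to_finish_reason() == FinishReason.NOT_FINISHED

# Setting the EOS bit marks the sequence as finished and maps to END_ID.
state.set_finished_eos()
assert state.is_finished
assert state.to_finish_reason() == FinishReason.END_ID

# The raw bitmask round-trips through to_underlying(), which mirrors how the
# sampler reconstructs the state from the per-beam finish-reasons tensor.
raw = state.to_underlying()
assert FinishedState(raw).to_finish_reason() == FinishReason.END_ID
```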
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1535,6 +1535,9 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
self.resource_manager.resource_managers[
ResourceManagerType.KV_CACHE_MANAGER].prepare_resources(
disagg_gen_init_to_prepare)
self.resource_manager.resource_managers[
ResourceManagerType.SEQ_SLOT_MANAGER].prepare_resources(
disagg_gen_init_to_prepare)

# Trigger KV cache exchange for new disagg_gen_init_requests
self._recv_disagg_gen_cache(fitting_disagg_gen_init_requests)
9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -23,6 +23,7 @@
from tensorrt_llm.executor.result import Logprob
from tensorrt_llm.mapping import Mapping

from .finish_reason import FinishedState
from .llm_request import LlmRequest, LlmRequestState
from .scheduler import ScheduledRequests

@@ -648,6 +649,7 @@ def update_requests(self, state: SampleStateTRTLLM):
for beam in range(beam_width):
seq_len = sequence_lengths_host_data[seq_slot * beam_width +
beam].item()
seq_len = seq_len + 1 if self.is_trt_overlap else seq_len
num_new_tokens[beam] = min(
num_generated_tokens,
seq_len - request.get_num_tokens(beam))
@@ -678,9 +680,10 @@ def update_requests(self, state: SampleStateTRTLLM):
state.host.cum_log_probs[seq_slot * beam_width +
beam].item())

finish_reason = finish_reasons_host[seq_slot * beam_width +
beam].item()
request.set_finished_reason(FinishReason(finish_reason), beam)
finish_reason = FinishedState(
finish_reasons_host[seq_slot * beam_width +
beam].item()).to_finish_reason()
request.set_finished_reason(finish_reason, beam)

if request.py_return_log_probs:
request.py_result.append_log_probs([log_probs], cum_log_probs)
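
The sampler change above replaces the previous direct `FinishReason(...)` cast: the decoder reports a per-beam `FinishedState` bitmask (several bits can be set at once), and `to_finish_reason()` collapses it into a single reason, with EOS taking priority over stop words and max length. A small illustrative sketch of that resolution, using a made-up raw value:

```python
# Illustrative only: a raw per-beam value with both the EOS and MAX_LENGTH bits
# set, as could happen when a sequence emits its EOS token on the last allowed step.
from tensorrt_llm.bindings.executor import FinishReason

from tensorrt_llm._torch.pyexecutor.finish_reason import FinishedState

raw = FinishedState.FINISHED_EOS | FinishedState.FINISHED_MAX_LENGTH  # 1 | 4 == 5

# The raw value is a FinishedState bitmask, not a FinishReason enum value, so it
# is decoded through FinishedState; END_ID wins when multiple bits are set.
assert FinishedState(raw).to_finish_reason() == FinishReason.END_ID
```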
35 changes: 35 additions & 0 deletions tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml
@@ -0,0 +1,35 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
free_gpu_memory_fraction: 0.2
context_servers:
num_instances: 1
max_batch_size: 1
max_num_tokens: 3000
max_seq_len: 4096
tensor_parallel_size: 1
pipeline_parallel_size: 1
enable_trtllm_sampler: True
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
use_cuda_graph: False
disable_overlap_scheduler: True
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 1
pipeline_parallel_size: 1
max_batch_size: 256
max_num_tokens: 4096
max_seq_len: 4096
enable_trtllm_sampler: True
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
use_cuda_graph: False
disable_overlap_scheduler: False
urls:
- "localhost:8002"
26 changes: 24 additions & 2 deletions tests/integration/defs/disaggregated/test_disaggregated.py
@@ -50,6 +50,8 @@ def get_test_config(test_desc, example_dir, test_root):
(2, f"{test_configs_root}/disagg_config_cuda_graph_padding.yaml"),
"mixed": (2, f"{test_configs_root}/disagg_config_mixed.yaml"),
"overlap": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
"trtllm_sampler":
(2, f"{test_configs_root}/disagg_config_trtllm_sampler.yaml"),
"load_balance":
(4, f"{test_configs_root}/disagg_config_load_balance.yaml"),
"cache_aware_balance":
@@ -179,7 +181,7 @@ def run_disaggregated_test(example_dir,
poll_procs=[workers_proc, server_proc])

# Run the chat completion endpoint test only for TinyLlama
if test_desc == "overlap":
if test_desc == "overlap" or test_desc == "trtllm_sampler":
chat_client_cmd = client_cmd + [
'-e', 'chat', '-o', 'output_chat.json'
]
@@ -198,7 +200,7 @@ def run_disaggregated_test(example_dir,
not_expected_strings = ["Berlin Berlin"]

output_files = ['output.json', 'output_streaming.json']
if test_desc == "overlap":
if test_desc == "overlap" or test_desc == "trtllm_sampler":
# Disable streaming chat completion for overlap test
# due to bug
output_files.extend(['output_chat.json'])
@@ -420,6 +422,26 @@ def test_disaggregated_overlap(disaggregated_test_root, llm_venv,
cwd=llm_venv.get_working_directory())


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_disaggregated_trtllm_sampler(disaggregated_test_root, llm_venv,
disaggregated_example_root,
llama_model_root):
src_dst_dict = {
llama_model_root:
f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)

run_disaggregated_test(disaggregated_example_root,
"trtllm_sampler",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_disaggregated_load_balance(disaggregated_test_root, llm_venv,