-
Notifications
You must be signed in to change notification settings - Fork 1.8k
[TRTLLM-5974][feat] Support disaggregated serving in TRTLLM Sampler #5328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
dcampora
merged 6 commits into
NVIDIA:main
from
dcampora:user/dcampora/support_ds_in_trtllm_sampler
Jun 25, 2025
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
45dc34b
Added support for DS on TRTLLM Sampler.
dcampora 523de97
Formatting.
dcampora 0c02df2
Added missing file.
dcampora 634b8c7
Remove prints.
dcampora 6396845
Look for one more token in is_overlap.
dcampora 7cea569
Fix prepare_resources key.
dcampora File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from tensorrt_llm.bindings.executor import FinishReason | ||
|
||
|
||
class FinishedState: | ||
# State flags | ||
FINISHED_EOS = 1 << 0 | ||
FINISHED_STOP_WORDS = 1 << 1 | ||
FINISHED_MAX_LENGTH = 1 << 2 | ||
FINISHED = FINISHED_EOS | FINISHED_STOP_WORDS | FINISHED_MAX_LENGTH | ||
SKIP_DECODING = 1 << 3 | ||
|
||
def __init__(self, state=0): | ||
self._state = state | ||
|
||
@classmethod | ||
def empty(cls): | ||
return cls(0) | ||
|
||
@classmethod | ||
def finished(cls): | ||
return cls(cls.FINISHED) | ||
|
||
@classmethod | ||
def skip_decoding(cls): | ||
return cls(cls.SKIP_DECODING) | ||
|
||
@classmethod | ||
def finished_eos(cls): | ||
return cls(cls.FINISHED_EOS) | ||
|
||
@classmethod | ||
def finished_max_length(cls): | ||
return cls(cls.FINISHED_MAX_LENGTH) | ||
|
||
@classmethod | ||
def finished_stop_words(cls): | ||
return cls(cls.FINISHED_STOP_WORDS) | ||
|
||
def set_finished_eos(self): | ||
self._state |= self.FINISHED_EOS | ||
|
||
@property | ||
def is_finished_eos(self): | ||
return self._any_bit_set(self.FINISHED_EOS) | ||
|
||
def set_finished_stop_words(self): | ||
self._state |= self.FINISHED_STOP_WORDS | ||
|
||
@property | ||
def is_finished_stop_words(self): | ||
return self._any_bit_set(self.FINISHED_STOP_WORDS) | ||
|
||
def set_finished_max_length(self): | ||
self._state |= self.FINISHED_MAX_LENGTH | ||
|
||
@property | ||
def is_finished_max_length(self): | ||
return self._any_bit_set(self.FINISHED_MAX_LENGTH) | ||
|
||
def set_finished(self): | ||
self._state |= self.FINISHED | ||
|
||
@property | ||
def is_finished(self): | ||
return self._any_bit_set(self.FINISHED) | ||
|
||
def set_skip_decoding(self): | ||
self._state |= self.SKIP_DECODING | ||
|
||
@property | ||
def is_skip_decoding(self): | ||
return self._any_bit_set(self.SKIP_DECODING) | ||
|
||
def to_finish_reason(self): | ||
if self.is_finished_eos: | ||
return FinishReason.END_ID | ||
if self.is_finished_stop_words: | ||
return FinishReason.STOP_WORDS | ||
if self.is_finished_max_length: | ||
return FinishReason.LENGTH | ||
return FinishReason.NOT_FINISHED | ||
|
||
def to_underlying(self): | ||
return self._state | ||
|
||
def _any_bit_set(self, bits): | ||
return (self._state & bits) != 0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 35 additions & 0 deletions
35
tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 | ||
hostname: localhost | ||
port: 8000 | ||
backend: "pytorch" | ||
free_gpu_memory_fraction: 0.2 | ||
context_servers: | ||
num_instances: 1 | ||
max_batch_size: 1 | ||
max_num_tokens: 3000 | ||
max_seq_len: 4096 | ||
tensor_parallel_size: 1 | ||
pipeline_parallel_size: 1 | ||
enable_trtllm_sampler: True | ||
kv_cache_config: | ||
free_gpu_memory_fraction: 0.2 | ||
enable_partial_reuse: False | ||
use_cuda_graph: False | ||
disable_overlap_scheduler: True | ||
urls: | ||
- "localhost:8001" | ||
generation_servers: | ||
num_instances: 1 | ||
tensor_parallel_size: 1 | ||
pipeline_parallel_size: 1 | ||
max_batch_size: 256 | ||
max_num_tokens: 4096 | ||
max_seq_len: 4096 | ||
enable_trtllm_sampler: True | ||
kv_cache_config: | ||
free_gpu_memory_fraction: 0.2 | ||
enable_partial_reuse: False | ||
use_cuda_graph: False | ||
disable_overlap_scheduler: False | ||
urls: | ||
- "localhost:8002" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.