
Commit 0236a0a

chore: Adjust cache indirection passing to AttentionMetadata
- Moved the cache indirection buffer into AttentionMetadata instead of TrtllmAttentionMetadata.
- Updated PyTorchModelEngine to pass the cache indirection buffer conditionally, based on the attention backend.
- Combined the beam search test cases for overlap scheduling and CUDA graphs.
- Adjusted the size estimation of the cache indirection buffer in model_engine to correctly cover overlap scheduling.

Signed-off-by: Stefan Niebler <[email protected]>
1 parent 974c6c4 commit 0236a0a

4 files changed, +15 −17 lines changed


tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 3 additions & 0 deletions
@@ -135,6 +135,9 @@ class AttentionMetadata:
     _num_ctx_tokens: int = field(init=False, default=0, repr=False)
     _num_tokens: int = field(init=False, default=0, repr=False)
 
+    # This buffer is currently only used for TrtllmAttentionMetadata.
+    cache_indirection: Optional[torch.Tensor] = None
+
     def __post_init__(self) -> None:
         if self.is_cross:
             assert self.cross is None or self.cross is self, "Cross attention metadata should not have sub metadata"
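
For orientation, here is a minimal standalone sketch (plain dataclasses, not the real tensorrt_llm classes; every name other than cache_indirection is made up) of what hoisting the buffer into the base AttentionMetadata means for callers: the buffer can now be passed to any backend's metadata, and backends that ignore it simply see None.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AttentionMetadataSketch:
    """Hypothetical stand-in for the AttentionMetadata base class."""
    max_num_requests: int
    # The beam-search cache indirection buffer now lives on the base class,
    # so it can be passed to any backend's metadata; backends that do not
    # use it simply leave it as None.
    cache_indirection: Optional[torch.Tensor] = None


# Callers construct the metadata the same way regardless of backend.
meta = AttentionMetadataSketch(
    max_num_requests=8,
    cache_indirection=torch.zeros((8, 4, 256), dtype=torch.int32),
)
print(meta.cache_indirection.shape)  # torch.Size([8, 4, 256])
```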

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 1 addition & 2 deletions
@@ -517,10 +517,9 @@ def is_nvfp4_output_kernel_available(
 class TrtllmAttentionMetadata(AttentionMetadata):
     workspace: Optional[torch.Tensor] = None
 
-    # TrtllmAttention needs to know the beam width and access to the cache indirection buffer,
+    # TrtllmAttention needs to know the beam width to access the cache indirection buffer
     # when beam search is enabled.
     beam_width: int = 1
-    cache_indirection: Optional[torch.Tensor] = None
 
     # TrtllmAttention needs to know the max sequence length.
     # Implemented as a property to support no cache mode.
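
As a rough, hypothetical illustration (not the TrtllmAttention kernel itself) of why beam_width and the now-inherited cache_indirection buffer are used together: the buffer records, per (request, beam, position), which beam's KV-cache entry should be read during generation. The shapes and the fork scenario below are invented for the sketch.

```python
import torch

# Hypothetical shapes; the real buffer is allocated by the model engine.
batch_size, beam_width, max_seq_len = 2, 4, 16
cache_indirection = torch.zeros((batch_size, beam_width, max_seq_len),
                                dtype=torch.int32)

# Suppose beam 2 of request 0 was forked from beam 0 at step 5: its earlier
# positions should then be read from beam 0's KV-cache entries.
cache_indirection[0, 2, :5] = 0
cache_indirection[0, 2, 5:] = 2

# Attending over position t for (request b, beam k) reads the KV-cache row
# selected by cache_indirection[b, k, t].
b, k, t = 0, 2, 3
print(int(cache_indirection[b, k, t]))  # 0
```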

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 8 deletions
@@ -426,7 +426,8 @@ def __init__(
         # This way it can also be used for CUDA graphs.
         if self.use_beam_search:
             self.cache_indirection_attention = torch.zeros(
-                (self.batch_size, self.max_beam_width, self.max_seq_len),
+                (self.batch_size, self.max_beam_width, self.max_seq_len +
+                 (0 if self._disable_overlap_scheduler else 1)),
                 device="cuda",
                 dtype=torch.int32)
         else:
@@ -753,11 +754,7 @@ def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
             self.model.model_config.pretrained_config) and (
                 self.attn_runtime_features.cache_reuse
                 or self.attn_runtime_features.chunked_prefill)
-        # Cache indirection is only used for beam search on generation requests with TRTLLM backend.
-        if self.attn_backend.Metadata is TrtllmAttentionMetadata:
-            kwargs = {"cache_indirection": self.cache_indirection_attention}
-        else:
-            kwargs = {}
+        cache_indirection = self.cache_indirection_attention if self.attn_backend.Metadata is TrtllmAttentionMetadata else None
         if kv_cache_manager is None:
             return self.attn_backend.Metadata(
                 max_num_requests=self.batch_size,
@@ -768,7 +765,7 @@ def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
                 runtime_features=self.attn_runtime_features,
                 enable_flash_mla=self.model.model_config.enable_flash_mla,
                 enable_paged_context_mla=enable_paged_context_mla,
-                **kwargs)
+                cache_indirection=cache_indirection)
 
         if self.attn_metadata is not None:
             # This assertion can be relaxed if needed: just create a new metadata
@@ -785,7 +782,7 @@ def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
             runtime_features=self.attn_runtime_features,
             enable_flash_mla=self.model.model_config.enable_flash_mla,
             enable_paged_context_mla=enable_paged_context_mla,
-            **kwargs)
+            cache_indirection=cache_indirection)
 
         return self.attn_metadata
 
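
A condensed sketch of the two model_engine.py changes above, using made-up local values in place of the engine's attributes (batch_size, max_beam_width, max_seq_len, and the backend check are placeholders):

```python
import torch

# Placeholder values standing in for PyTorchModelEngine attributes.
batch_size, max_beam_width, max_seq_len = 8, 4, 256
use_beam_search = True
disable_overlap_scheduler = False
# Stands in for `self.attn_backend.Metadata is TrtllmAttentionMetadata`.
backend_uses_cache_indirection = True

# Size estimation: overlap scheduling prepares step N+1 while step N is still
# in flight, so the buffer needs one extra sequence slot in that mode.
if use_beam_search:
    cache_indirection_attention = torch.zeros(
        (batch_size, max_beam_width,
         max_seq_len + (0 if disable_overlap_scheduler else 1)),
        dtype=torch.int32)  # the real buffer is allocated with device="cuda"
else:
    cache_indirection_attention = None

# Passing: instead of a backend-specific **kwargs dict, the buffer is always
# forwarded as `cache_indirection`, defaulting to None for other backends.
cache_indirection = (cache_indirection_attention
                     if backend_uses_cache_indirection else None)
print(None if cache_indirection is None else tuple(cache_indirection.shape))
```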

tests/unittest/_torch/test_beam_search.py

Lines changed: 6 additions & 7 deletions
@@ -51,7 +51,7 @@ def llm(fixed_params, input_prompts):
 
 
 @pytest.fixture(scope="module")
-def llm_overlap(fixed_params, input_prompts):
+def llm_cuda_graph(fixed_params, input_prompts):
     return LLM(
         model=os.path.join(llm_models_root(), "llama-models-v2",
                            "TinyLlama-1.1B-Chat-v1.0"),
@@ -63,8 +63,7 @@ def llm_overlap(fixed_params, input_prompts):
         enable_trtllm_sampler=True,
         max_beam_width=fixed_params["max_beam_width"],
         disable_overlap_scheduler=False,
-        #TODO: remove this once we have a proper fix for CUDA graph in beam search
-        cuda_graph_config=None,
+        cuda_graph_config=CudaGraphConfig(enabled=True),
     )
 
 
@@ -131,10 +130,10 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
 @pytest.mark.parametrize("num_output_beams", [1, 2])
 @pytest.mark.parametrize("num_prompts", [1, 2])
 @pytest.mark.threadleak(enabled=False)
-def test_beam_search_output_shapes_overlap(
+def test_beam_search_output_shapes_cuda_graph_and_overlap(
         gather_context_logits: bool, gather_generation_logits: bool,
         return_log_probs: bool, num_output_beams: int, num_prompts: int,
-        llm_overlap, fixed_params, input_prompts, expected_outputs):
+        llm_cuda_graph, fixed_params, input_prompts, expected_outputs):
     if return_log_probs and num_prompts > 1:
         pytest.skip(
             "Beam search currently does not support return_log_probs with multiple prompts"
@@ -148,8 +147,8 @@ def test_beam_search_output_shapes_overlap(
         return_generation_logits=gather_generation_logits,
         logprobs=return_log_probs,
     )
-    outputs = llm_overlap.generate(input_prompts[:num_prompts],
-                                   sampling_params=sampling_params)
+    outputs = llm_cuda_graph.generate(input_prompts[:num_prompts],
+                                      sampling_params=sampling_params)
     assert len(outputs) == num_prompts
     for output_idx, output in enumerate(outputs):
         if gather_context_logits:
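
For reference, a hedged usage sketch of the combined configuration the renamed fixture now exercises (overlap scheduling and CUDA graphs together). The import paths, model identifier, and sampling settings here are assumptions, not copied from the test module.

```python
from tensorrt_llm import LLM, SamplingParams     # assumed import path
from tensorrt_llm.llmapi import CudaGraphConfig  # assumed import path

# Assumed model id; the real test loads a local TinyLlama-1.1B-Chat-v1.0 copy.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    enable_trtllm_sampler=True,
    max_beam_width=2,
    disable_overlap_scheduler=False,                  # overlap scheduling on
    cuda_graph_config=CudaGraphConfig(enabled=True),  # CUDA graphs on
)

# Beam search request; the parameter choices are illustrative only.
outputs = llm.generate(
    ["The capital of France is"],
    sampling_params=SamplingParams(max_tokens=8, n=2, use_beam_search=True),
)
print(outputs[0].outputs[0].text)
```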
