
Commit 253af9f

[https://nvbugs/5410391][bug] Support sharing device buffers in attention metadata (#6557)
Signed-off-by: Hui Gao <[email protected]>
1 parent 79f1e6c commit 253af9f

File tree: 6 files changed (+85 / -23 lines)

tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 2 additions & 1 deletion
@@ -170,7 +170,8 @@ def __post_init__(self) -> None:
     def create_cuda_graph_metadata(self,
                                    max_batch_size: int,
                                    sub_cross_metadata: bool = False,
-                                   max_draft_tokens: int = 0) -> Self:
+                                   max_draft_tokens: int = 0,
+                                   buffers=None) -> Self:
         metadata = super().create_cuda_graph_metadata(max_batch_size,
                                                       sub_cross_metadata,
                                                       max_draft_tokens)

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 4 additions & 1 deletion
@@ -137,6 +137,7 @@ class AttentionMetadata:
 
     # This buffer is currently only used for TrtllmAttentionMetadata.
     cache_indirection: Optional[torch.Tensor] = None
+    cuda_graph_buffers: dict[str, list[torch.Tensor]] = None
 
     def __post_init__(self) -> None:
         if self.is_cross:
@@ -282,7 +283,8 @@ def prepare(self):
     def create_cuda_graph_metadata(self,
                                    max_batch_size: int,
                                    sub_cross_metadata: bool = False,
-                                   max_draft_tokens: int = 0) -> Self:
+                                   max_draft_tokens: int = 0,
+                                   buffers=None) -> Self:
         """
         Creates metadata for CUDA graph execution.
         CUDA graphs require to use pre-allocated buffers for all tensors in fields.
@@ -294,6 +296,7 @@ def create_cuda_graph_metadata(self,
 
         cuda_graph_metadata = copy.copy(self)
         cuda_graph_metadata.is_cuda_graph = True
+        cuda_graph_metadata.cuda_graph_buffers = buffers
         if self.has_cross_sub_metadata:
             cuda_graph_metadata.cross = cuda_graph_metadata.cross.create_cuda_graph_metadata(
                 max_batch_size, True)
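
A minimal sketch of the pattern this hunk introduces, using a hypothetical stand-in dataclass (ToyMetadata is illustrative, not a TensorRT-LLM class): the metadata is shallow-copied, flagged for CUDA graph capture, and handed the caller-owned buffer dict, so every per-batch-size metadata object shares one device-buffer pool.

# Sketch only: mirrors the copy.copy() + buffer-attach pattern shown in the hunk above.
import copy
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyMetadata:
    max_num_requests: int
    is_cuda_graph: bool = False
    cuda_graph_buffers: Optional[dict[str, list[torch.Tensor]]] = None

    def create_cuda_graph_metadata(self, max_batch_size: int,
                                   buffers=None) -> "ToyMetadata":
        # Shallow copy, as in the hunk above; the buffer dict is shared,
        # not copied, so every graph-capture metadata reuses the same pool.
        metadata = copy.copy(self)
        metadata.is_cuda_graph = True
        metadata.cuda_graph_buffers = buffers
        metadata.max_num_requests = max_batch_size
        return metadata


shared_buffers: dict[str, list[torch.Tensor]] = {}
meta_bs1 = ToyMetadata(max_num_requests=8).create_cuda_graph_metadata(1, shared_buffers)
meta_bs8 = ToyMetadata(max_num_requests=8).create_cuda_graph_metadata(8, shared_buffers)
assert meta_bs1.cuda_graph_buffers is meta_bs8.cuda_graph_buffers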

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 70 additions & 15 deletions
@@ -600,21 +600,76 @@ def host_kv_cache_pool_mapping(self) -> Optional[torch.Tensor]:
 
     def __post_init__(self) -> None:
         super().__post_init__()
+        self._post_init_with_buffers(self.cuda_graph_buffers)
+
+    def _post_init_with_buffers(self, buffers) -> None:
+
         # Set a default value, as max_num_sequences is not always set.
         if self.max_num_sequences is None:
             self.max_num_sequences = self.max_num_requests
 
-        self.prompt_lens_cuda = torch.empty(
+        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
+                      cache_name: str) -> torch.Tensor:
+            """
+            Finds a compatible, reusable buffer from a cache or creates a new one.
+
+            This function searches for a pre-allocated tensor (buffer) that can be
+            reused for an operation involving a tensor with the shape of `tensor_shape`.
+
+            The compatibility rule is: the buffer's total elements must be >= tensor_shape's.
+
+            If a compatible buffer is found, it's returned immediately. Otherwise, a new
+            buffer is allocated on the 'cuda' device with the given `tensor_shape` and `dtype`.
+
+            Args:
+                tensor_shape: The required shape.
+                dtype: The required dtype.
+                cache_name: The key for the specific list of buffers to search in.
+
+            Returns:
+                An existing compatible buffer or a newly created one.
+            """
+            if buffers is not None:
+                # Safely get the list of candidates. Defaults to an empty list if key is missing.
+                candidate_buffers = buffers.get(cache_name, [])
+                numel_like = math.prod(tensor_shape)
+
+                for buffer in candidate_buffers:
+                    numel_buffer = buffer.numel()
+
+                    # buffer just needs to be large enough.
+                    if numel_buffer >= numel_like:
+                        return buffer[0:numel_like].view(
+                            tensor_shape)  # Found a fit, return immediately.
+
+            # If we get here, no suitable buffer was found in the cache. Create a new one.
+            new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
+            if buffers is not None:
+                buffers.setdefault(cache_name, []).append(new_buffer)
+            return new_buffer
+
+        def get_empty_like(like_tensor: torch.Tensor,
+                           cache_name: str) -> torch.Tensor:
+            return get_empty(
+                like_tensor.shape,
+                cache_name=cache_name,
+                dtype=like_tensor.dtype,
+            )
+
+        self.prompt_lens_cuda = get_empty(
             (self.max_num_sequences, ),
-            device='cuda',
+            cache_name="prompt_lens_cuda",
             dtype=torch.int,
         )
         self.prompt_lens_cpu = torch.empty_like(
             self.prompt_lens_cuda,
             device='cpu',
             pin_memory=True,
         )
-        self.kv_lens_cuda = torch.empty_like(self.prompt_lens_cuda)
+        self.kv_lens_cuda = get_empty_like(
+            self.prompt_lens_cuda,
+            cache_name="kv_lens_cuda",
+        )
         self.kv_lens = torch.empty_like(self.kv_lens_cuda,
                                         device='cpu',
                                         pin_memory=True)
@@ -628,13 +683,13 @@ def __post_init__(self) -> None:
             dtype=torch.int8,
         )
         if self.kv_cache_manager is not None:
-            self.kv_cache_block_offsets = torch.empty(
+            self.kv_cache_block_offsets = get_empty(
                 [
                     self.kv_cache_manager.num_pools, self.max_num_sequences, 2,
                     self.kv_cache_manager.max_blocks_per_seq
                 ],
+                cache_name="kv_cache_block_offsets",
                 dtype=torch.int32,
-                device='cuda',
             )
             self.host_kv_cache_block_offsets = torch.empty_like(
                 self.kv_cache_block_offsets,
@@ -644,37 +699,37 @@ def __post_init__(self) -> None:
             self.block_ids_per_seq = None
             self.kv_block_ids_per_seq = None
             if self.enable_flash_mla:
-                self.block_ids_per_seq = torch.zeros(
+                self.block_ids_per_seq = get_empty(
                     [
                         self.kv_cache_manager.max_batch_size,
                         self.kv_cache_manager.max_blocks_per_seq
                     ],
+                    cache_name="block_ids_per_seq",
                     dtype=torch.int32,
-                    device='cuda',
                 )
-                self.kv_block_ids_per_seq = torch.zeros(
+                self.kv_block_ids_per_seq = get_empty(
                     [
                         self.kv_cache_manager.max_batch_size,
                         self.kv_cache_manager.max_blocks_per_seq
                     ],
+                    cache_name="kv_block_ids_per_seq",
                     dtype=torch.int32,
-                    device='cuda',
                 )
             if self.enable_paged_context_mla:
                 # for kv cache reuse/chunked context in MLA
-                self.ctx_cached_token_indptr = torch.zeros(
+                self.ctx_cached_token_indptr = get_empty(
                     (self.max_num_requests + 1, ),
-                    device='cuda',
+                    cache_name="ctx_cached_token_indptr",
                     dtype=torch.int64,
                 )
                 self.host_ctx_cached_token_indptr = torch.zeros_like(
                     self.ctx_cached_token_indptr,
                     device='cpu',
                     pin_memory=True,
                 )
-                self.ctx_uncached_token_indptr = torch.zeros(
+                self.ctx_uncached_token_indptr = get_empty(
                     (self.max_num_requests + 1, ),
-                    device='cuda',
+                    cache_name="ctx_uncached_token_indptr",
                     dtype=torch.int64,
                 )
                 self.host_ctx_uncached_token_indptr = torch.zeros_like(
@@ -683,9 +738,9 @@ def __post_init__(self) -> None:
                     pin_memory=True,
                 )
                 # context full seqlens include cached tokens and uncached tokens
-                self.ctx_kv_indptr = torch.zeros(
+                self.ctx_kv_indptr = get_empty(
                     (self.max_num_requests + 1, ),
-                    device='cuda',
+                    cache_name="ctx_kv_indptr",
                     dtype=torch.int64,
                 )
                 self.host_ctx_kv_indptr = torch.zeros_like(
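
A standalone sketch of the reuse rule implemented by get_empty above, under the assumption that buffers registered under one cache_name share a dtype; the pool dict and get_buffer helper here are illustrative, not the engine's own API, and allocation is on CPU only so the snippet runs without a GPU.

# Sketch only: reuse is keyed by cache_name, and a cached buffer backs any later
# request whose element count fits, by viewing the first `numel` flat elements
# with the requested shape; otherwise allocate once and register for next time.
import math

import torch


def get_buffer(pool: dict[str, list[torch.Tensor]], cache_name: str,
               shape: tuple[int, ...], dtype: torch.dtype) -> torch.Tensor:
    numel = math.prod(shape)
    for buf in pool.get(cache_name, []):
        if buf.numel() >= numel:
            # Large enough: reuse the existing storage under the new shape.
            return buf.view(-1)[:numel].view(shape)
    # No fit found: allocate (device='cuda' in the real code) and register it.
    new_buf = torch.zeros(shape, dtype=dtype)
    pool.setdefault(cache_name, []).append(new_buf)
    return new_buf


pool: dict[str, list[torch.Tensor]] = {}
a = get_buffer(pool, "prompt_lens", (64,), torch.int)  # allocates
b = get_buffer(pool, "prompt_lens", (16,), torch.int)  # reuses a's storage
assert b.data_ptr() == a.data_ptr()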

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 2 additions & 1 deletion
@@ -73,7 +73,8 @@ def __init__(
         self.optional_extra_model_inputs = ["mrope_position_deltas"]
 
     def __del__(self):
-        self._graph.reset()
+        if self._graph is not None:
+            self._graph.reset()
 
     def capture(
         self,

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 7 additions & 4 deletions
@@ -425,6 +425,7 @@ def __init__(
         self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
         self.lora_model_config: Optional[LoraModelConfig] = None
         self.cuda_graph_dummy_request = None
+        self.cuda_graph_meta_buffers: dict[str, list[torch.Tensor]] = {}
 
         # Setup the local cache indirection buffer only once and reuse it.
         # This way it can also be used for CUDA graphs.
@@ -970,15 +971,16 @@ def _maybe_get_cuda_graph(
 
         num_sequences_in_batch = batch_size * self.max_beam_width
         attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
-            num_sequences_in_batch, False, draft_len)
+            num_sequences_in_batch, False, draft_len,
+            self.cuda_graph_meta_buffers)
+
         assert attn_metadata.is_cuda_graph
 
+        spec_metadata = None
         if self.enable_spec_decode:
             spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
                 num_sequences_in_batch)
             spec_metadata.draft_tokens = self.draft_tokens_cuda
-        else:
-            spec_metadata = None
 
         # Initialize nested dictionary if needed
         if batch_size not in self._cuda_graphs:
@@ -1143,9 +1145,10 @@ def _release_cuda_graphs(self):
             for draft_len, graph in draft_graphs.items():
                 del graph
         self._cuda_graphs.clear()
-        torch.cuda.empty_cache()
        del self._cuda_graph_mem_pool
        self._cuda_graph_mem_pool = None
+        self.cuda_graph_meta_buffers.clear()
+        torch.cuda.empty_cache()
 
    def get_max_num_sequences(self) -> int:
        """

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
@@ -278,7 +278,6 @@ full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b
 full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
 full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696)
 full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5410391)
 accuracy/test_llm_api.py::TestMistral_Nemo_12B_Base::test_fp8 SKIP (https://nvbugs/5413197)
 accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5455140)
