@@ -600,21 +600,76 @@ def host_kv_cache_pool_mapping(self) -> Optional[torch.Tensor]:
600600
def __post_init__(self) -> None:
    """Run the parent post-init, then finish setup with the CUDA graph buffers.

    The remaining initialization is delegated to `_post_init_with_buffers`,
    which receives `self.cuda_graph_buffers` as its buffer cache.
    """
    super().__post_init__()
    graph_buffers = self.cuda_graph_buffers
    self._post_init_with_buffers(graph_buffers)
604+
605+ def _post_init_with_buffers (self , buffers ) -> None :
606+
603607 # Set a default value, as max_num_sequences is not always set.
604608 if self .max_num_sequences is None :
605609 self .max_num_sequences = self .max_num_requests
606610
607- self .prompt_lens_cuda = torch .empty (
def get_empty(tensor_shape: list[int], dtype: torch.dtype,
              cache_name: str) -> torch.Tensor:
    """
    Return a reusable buffer from the cache, or allocate a new one.

    The enclosing scope's `buffers` mapping (cache name -> list of tensors)
    is searched for a tensor registered under `cache_name` that is
    compatible with the request. A buffer is compatible when:

      * its dtype equals `dtype`, and
      * its total number of elements is >= the number of elements
        required by `tensor_shape`.

    When a compatible buffer is found, a view over its first
    `prod(tensor_shape)` elements, reshaped to `tensor_shape`, is returned.
    The view aliases the cached buffer and its contents are returned as-is
    (they are NOT reset). Otherwise a new zero-initialized tensor is
    allocated on the 'cuda' device with the given `tensor_shape` and
    `dtype`, and, when `buffers` is not None, registered under
    `cache_name` for later reuse.

    Args:
        tensor_shape: The required shape.
        dtype: The required dtype.
        cache_name: The key for the specific list of buffers to search in.

    Returns:
        A view of an existing compatible buffer, or a newly created tensor.
    """
    numel_needed = math.prod(tensor_shape)

    if buffers is not None:
        # A missing key simply means nothing has been cached under this name.
        for buffer in buffers.get(cache_name, []):
            # Never hand back a buffer of the wrong dtype or one too small.
            if buffer.dtype != dtype or buffer.numel() < numel_needed:
                continue
            # Flatten before slicing: indexing `buffer[0:n]` directly would
            # only slice the first dimension, so a larger multi-dimensional
            # buffer would keep too many elements and the final view would
            # fail. `view(-1)` (rather than `reshape`) keeps the result
            # aliased to the cached storage and fails loudly if the cached
            # tensor is not contiguous.
            return buffer.view(-1)[:numel_needed].view(tensor_shape)

    # No suitable buffer was found in the cache. Create a new one.
    new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
    if buffers is not None:
        buffers.setdefault(cache_name, []).append(new_buffer)
    return new_buffer
650+
def get_empty_like(like_tensor: torch.Tensor,
                   cache_name: str) -> torch.Tensor:
    """Request a buffer from `get_empty` using `like_tensor`'s shape and dtype.

    Args:
        like_tensor: Tensor whose shape and dtype the buffer must match.
        cache_name: The cache key forwarded to `get_empty`.

    Returns:
        Whatever `get_empty` returns for that shape, dtype and cache key.
    """
    wanted_shape = like_tensor.shape
    wanted_dtype = like_tensor.dtype
    return get_empty(wanted_shape, wanted_dtype, cache_name)
658+
659+ self .prompt_lens_cuda = get_empty (
608660 (self .max_num_sequences , ),
609- device = 'cuda' ,
661+ cache_name = "prompt_lens_cuda" ,
610662 dtype = torch .int ,
611663 )
612664 self .prompt_lens_cpu = torch .empty_like (
613665 self .prompt_lens_cuda ,
614666 device = 'cpu' ,
615667 pin_memory = True ,
616668 )
617- self .kv_lens_cuda = torch .empty_like (self .prompt_lens_cuda )
669+ self .kv_lens_cuda = get_empty_like (
670+ self .prompt_lens_cuda ,
671+ cache_name = "kv_lens_cuda" ,
672+ )
618673 self .kv_lens = torch .empty_like (self .kv_lens_cuda ,
619674 device = 'cpu' ,
620675 pin_memory = True )
@@ -628,13 +683,13 @@ def __post_init__(self) -> None:
628683 dtype = torch .int8 ,
629684 )
630685 if self .kv_cache_manager is not None :
631- self .kv_cache_block_offsets = torch . empty (
686+ self .kv_cache_block_offsets = get_empty (
632687 [
633688 self .kv_cache_manager .num_pools , self .max_num_sequences , 2 ,
634689 self .kv_cache_manager .max_blocks_per_seq
635690 ],
691+ cache_name = "kv_cache_block_offsets" ,
636692 dtype = torch .int32 ,
637- device = 'cuda' ,
638693 )
639694 self .host_kv_cache_block_offsets = torch .empty_like (
640695 self .kv_cache_block_offsets ,
@@ -644,37 +699,37 @@ def __post_init__(self) -> None:
644699 self .block_ids_per_seq = None
645700 self .kv_block_ids_per_seq = None
646701 if self .enable_flash_mla :
647- self .block_ids_per_seq = torch . zeros (
702+ self .block_ids_per_seq = get_empty (
648703 [
649704 self .kv_cache_manager .max_batch_size ,
650705 self .kv_cache_manager .max_blocks_per_seq
651706 ],
707+ cache_name = "block_ids_per_seq" ,
652708 dtype = torch .int32 ,
653- device = 'cuda' ,
654709 )
655- self .kv_block_ids_per_seq = torch . zeros (
710+ self .kv_block_ids_per_seq = get_empty (
656711 [
657712 self .kv_cache_manager .max_batch_size ,
658713 self .kv_cache_manager .max_blocks_per_seq
659714 ],
715+ cache_name = "kv_block_ids_per_seq" ,
660716 dtype = torch .int32 ,
661- device = 'cuda' ,
662717 )
663718 if self .enable_paged_context_mla :
664719 # for kv cache reuse/chunked context in MLA
665- self .ctx_cached_token_indptr = torch . zeros (
720+ self .ctx_cached_token_indptr = get_empty (
666721 (self .max_num_requests + 1 , ),
667- device = 'cuda' ,
722+ cache_name = "ctx_cached_token_indptr" ,
668723 dtype = torch .int64 ,
669724 )
670725 self .host_ctx_cached_token_indptr = torch .zeros_like (
671726 self .ctx_cached_token_indptr ,
672727 device = 'cpu' ,
673728 pin_memory = True ,
674729 )
675- self .ctx_uncached_token_indptr = torch . zeros (
730+ self .ctx_uncached_token_indptr = get_empty (
676731 (self .max_num_requests + 1 , ),
677- device = 'cuda' ,
732+ cache_name = "ctx_uncached_token_indptr" ,
678733 dtype = torch .int64 ,
679734 )
680735 self .host_ctx_uncached_token_indptr = torch .zeros_like (
@@ -683,9 +738,9 @@ def __post_init__(self) -> None:
683738 pin_memory = True ,
684739 )
685740 # context full seqlens include cached tokens and uncached tokens
686- self .ctx_kv_indptr = torch . zeros (
741+ self .ctx_kv_indptr = get_empty (
687742 (self .max_num_requests + 1 , ),
688- device = 'cuda' ,
743+ cache_name = "ctx_kv_indptr" ,
689744 dtype = torch .int64 ,
690745 )
691746 self .host_ctx_kv_indptr = torch .zeros_like (
0 commit comments