@@ -96,7 +96,8 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
             self.connector_scheduler = None
-            self.connector_worker = NixlConnectorWorker(str(self.engine_id))
+            self.connector_worker = NixlConnectorWorker(
+                vllm_config, str(self.engine_id))
 
     ############################################################
     # Scheduler Side Methods
@@ -302,7 +303,7 @@ def request_finished(
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""
 
-    def __init__(self, engine_id: str):
+    def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
             raise RuntimeError("NIXL is not available")
@@ -329,6 +330,7 @@ def __init__(self, engine_id: str):
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
         self.num_regions = 0
+        self.num_layers = 0
 
         # nixl_prepped_dlist_handle (int).
         self.src_xfer_side_handle: int = 0
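The comment above pins down the region/layer relationship that the new `num_layers` counter gets checked against later in `_get_block_descs_ids`. A quick illustration of the invariant, with a made-up layer count:

# Toy numbers for illustration only; not taken from the diff.
num_layers = 48
num_regions_mla = 1 * num_layers       # MLA: one combined KV cache per layer
num_regions_default = 2 * num_layers   # otherwise: separate K and V caches
assert (num_regions_mla, num_regions_default) == (48, 96)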
@@ -355,6 +357,14 @@ def __init__(self, engine_id: str):
         # Background thread for establishing new connections.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
 
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+
+        # TODO(mgoin): remove this once we have hybrid memory allocator
+        # Optimization for models with local attention (Llama 4)
+        # List of block window sizes for each layer for local attention
+        self.block_window_per_layer: list[Optional[int]] = []
+
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                  ready_event: threading.Event, rank: int):
@@ -465,6 +475,27 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             kv_caches_base_addr.append(base_addr)
         self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
         self.num_regions = len(caches_data)
+        self.num_layers = len(self.kv_caches.keys())
+
+        # TODO(mgoin): remove this once we have hybrid memory allocator
+        # Optimization for models with local attention (Llama 4)
+        if self.vllm_config.model_config.hf_config.model_type == "llama4":
+            from transformers import Llama4TextConfig
+            assert isinstance(self.vllm_config.model_config.hf_text_config,
+                              Llama4TextConfig)
+            llama4_config = self.vllm_config.model_config.hf_text_config
+            no_rope_layers = llama4_config.no_rope_layers
+            chunk_size = llama4_config.attention_chunk_size
+            chunk_block_size = math.ceil(chunk_size / self.block_size)
+            for layer_idx in range(self.num_layers):
+                # no_rope_layers[layer_idx] == 0 means NoPE (global)
+                # Any other value means RoPE (local chunked)
+                is_local_attention = no_rope_layers[layer_idx] != 0
+                block_window = chunk_block_size if is_local_attention else None
+                self.block_window_per_layer.append(block_window)
+            logger.debug("Llama 4 block window per layer mapping: %s",
+                         self.block_window_per_layer)
+            assert len(self.block_window_per_layer) == self.num_layers
 
         descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
         logger.debug("Registering descs: %s", caches_data)
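For intuition, the mapping this branch builds can be reproduced standalone. A minimal sketch, assuming illustrative config values rather than real Llama 4 ones (`no_rope_layers[i] == 0` marks a NoPE/global layer, anything else a RoPE/local-chunked layer, per the comments above):

import math
from typing import Optional

block_size = 16               # tokens per KV cache block (assumed)
attention_chunk_size = 8192   # local attention window in tokens (assumed)
no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]  # 0 => NoPE (global)

# ceil() because a partially filled block still has to be transferred
chunk_block_size = math.ceil(attention_chunk_size / block_size)

block_window_per_layer: list[Optional[int]] = [
    chunk_block_size if flag != 0 else None for flag in no_rope_layers
]
print(block_window_per_layer)
# [512, 512, 512, None, 512, 512, 512, None]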
@@ -699,10 +730,39 @@ def _read_blocks(
         remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 
         # Get descs ids.
-        remote_block_descs_ids = self._get_block_descs_ids(
-            dst_engine_id, remote_block_ids)
-        local_block_descs_ids = self._get_block_descs_ids(
-            self.engine_id, local_block_ids)
+        local_block_descs_ids: list[int] = []
+        remote_block_descs_ids: list[int] = []
+        if not self.block_window_per_layer:
+            # Default case: assume global attention
+            remote_block_descs_ids = self._get_block_descs_ids(
+                dst_engine_id, remote_block_ids)
+            local_block_descs_ids = self._get_block_descs_ids(
+                self.engine_id, local_block_ids)
+        else:
+            # TODO(mgoin): remove this once we have hybrid memory allocator
+            # Optimization for models with local attention (Llama 4)
+            for layer_idx, block_window in enumerate(
+                    self.block_window_per_layer):
+                # For each layer:
+                if block_window is None:
+                    # If not chunked, we just use the
+                    # full block lists (global attention)
+                    layer_local_block_ids = local_block_ids
+                    layer_remote_block_ids = remote_block_ids
+                else:
+                    # If chunked, get the last block_window blocks
+                    layer_local_block_ids = local_block_ids[-block_window:]
+                    layer_remote_block_ids = remote_block_ids[-block_window:]
+
+                # Get descs ids for the layer.
+                layer_local_desc_ids = self._get_block_descs_ids(
+                    self.engine_id, layer_local_block_ids, layer_idx)
+                layer_remote_desc_ids = self._get_block_descs_ids(
+                    dst_engine_id, layer_remote_block_ids, layer_idx)
+
+                local_block_descs_ids.extend(layer_local_desc_ids)
+                remote_block_descs_ids.extend(layer_remote_desc_ids)
+
         assert len(local_block_descs_ids) == len(remote_block_descs_ids)
 
         # Prepare transfer with Nixl.
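The payoff is the slice in the chunked branch: a local-attention layer can only ever read the trailing `block_window` blocks, so older blocks are simply not transferred for that layer. A toy sketch of just the selection step (the helper name is invented for illustration):

from typing import Optional

def select_layer_blocks(block_ids: list[int],
                        block_window: Optional[int]) -> list[int]:
    """Return the blocks a single layer actually needs."""
    if block_window is None:
        return block_ids               # global attention: all blocks
    return block_ids[-block_window:]   # local attention: trailing window only

local_block_ids = [3, 7, 8, 9, 12]
print(select_layer_blocks(local_block_ids, None))  # [3, 7, 8, 9, 12]
print(select_layer_blocks(local_block_ids, 2))     # [9, 12]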
@@ -721,12 +781,31 @@ def _read_blocks(
         # Use handle to check completion in future step().
         self._recving_transfers[request_id].append(handle)
 
-    def _get_block_descs_ids(self, engine_id: str,
-                             block_ids: list[int]) -> list[int]:
-        """Get the descs ids for a set of block ids."""
+    def _get_block_descs_ids(self,
+                             engine_id: str,
+                             block_ids: list[int],
+                             layer_idx: Optional[int] = None) -> list[int]:
+        """
+        Get the descs ids for a set of block ids.
+        If layer_idx is provided, we use the region_ids for the given layer.
+        Otherwise, we use all regions.
+        """
+
+        if layer_idx is None:
+            region_ids = range(self.num_regions)
+        else:
+            assert layer_idx < self.num_layers
+            if self.num_layers < self.num_regions:
+                # If we have more regions than layers, we assume that
+                # the regions are organized as [K0, V0, K1, V1, ...]
+                # and we select K_i and V_i
+                assert 2 * self.num_layers == self.num_regions
+                region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
+            else:
+                # Otherwise, we assume we have MLA and select i-th layer
+                assert self.num_layers == self.num_regions
+                region_ids = range(layer_idx, layer_idx + 1)
 
-        # range(1) for MLA, range(2) otherwise.
-        region_ids = range(self.num_regions)
         num_blocks = self.dst_num_blocks[engine_id]
 
         # Compute the desc ids for each block.
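The hunk cuts off before the computation that this final comment announces. Under the assumption, consistent with the `region_ids` ranges above, that descriptors are registered region-major (one per (region, block) pair, at flat index `region_id * num_blocks + block_id`), the arithmetic would look roughly like this sketch:

def block_descs_ids(region_ids: range, block_ids: list[int],
                    num_blocks: int) -> list[int]:
    # Flat descriptor index for each (region, block) pair, region-major.
    return [reg_id * num_blocks + block_id
            for reg_id in region_ids
            for block_id in block_ids]

num_blocks = 1000
# Non-MLA cache, layer_idx=1: regions [K1, V1] are region ids 2 and 3.
print(block_descs_ids(range(2, 4), [5, 9], num_blocks))
# [2005, 2009, 3005, 3009]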