3636    from  vllm .v1 .request  import  Request 
3737
3838Transfer  =  tuple [int , float ]  # (xfer_handle, start_time) 
39+ EngineId  =  str 
40+ ReqId  =  str 
3941GET_META_MSG  =  b"get_meta_msg" 
4042
4143logger  =  init_logger (__name__ )
@@ -75,7 +77,7 @@ class ReqMeta:
7577class  NixlConnectorMetadata (KVConnectorMetadata ):
7678
7779    def  __init__ (self ):
78-         self .requests : dict [str , ReqMeta ] =  {}
80+         self .requests : dict [ReqId , ReqMeta ] =  {}
7981
8082    def  add_new_req (
8183        self ,
@@ -96,16 +98,17 @@ class NixlConnector(KVConnectorBase_V1):
9698
9799    def  __init__ (self , vllm_config : VllmConfig , role : KVConnectorRole ):
98100        assert  vllm_config .kv_transfer_config  is  not   None 
99-         self .engine_id  =  vllm_config .kv_transfer_config .engine_id 
101+         assert  vllm_config .kv_transfer_config .engine_id  is  not   None 
102+         self .engine_id : EngineId  =  vllm_config .kv_transfer_config .engine_id 
100103
101104        if  role  ==  KVConnectorRole .SCHEDULER :
102105            self .connector_scheduler  : Optional [NixlConnectorScheduler ] =  \
103-                 NixlConnectorScheduler (vllm_config , str ( self .engine_id ) )
106+                 NixlConnectorScheduler (vllm_config , self .engine_id )
104107            self .connector_worker : Optional [NixlConnectorWorker ] =  None 
105108        elif  role  ==  KVConnectorRole .WORKER :
106109            self .connector_scheduler  =  None 
107110            self .connector_worker  =  NixlConnectorWorker (
108-                 vllm_config , str ( self .engine_id ) )
111+                 vllm_config , self .engine_id )
109112
110113    ############################################################ 
111114    # Scheduler Side Methods 
@@ -179,7 +182,7 @@ class NixlConnectorScheduler:
179182    def  __init__ (self , vllm_config : VllmConfig , engine_id : str ):
180183        self .vllm_config  =  vllm_config 
181184        self .block_size  =  vllm_config .cache_config .block_size 
182-         self .engine_id  =  engine_id 
185+         self .engine_id :  EngineId  =  engine_id 
183186        self .side_channel_host  =  envs .VLLM_NIXL_SIDE_CHANNEL_HOST 
184187        self .side_channel_port  =  (
185188            envs .VLLM_NIXL_SIDE_CHANNEL_PORT  + 
@@ -190,7 +193,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
190193        # Requests that need to start recv. 
191194        # New requests are added by update_state_after_alloc in 
192195        # the scheduler. Used to make metadata passed to Worker. 
193-         self ._reqs_need_recv : dict [str , tuple [Request , list [int ]]] =  {}
196+         self ._reqs_need_recv : dict [ReqId , tuple [Request , list [int ]]] =  {}
194197
195198    def  get_num_new_matched_tokens (
196199            self , request : "Request" ,
@@ -332,19 +335,19 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
332335        # Agent. 
333336        self .nixl_wrapper  =  NixlWrapper (str (uuid .uuid4 ()), None )
334337        # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. 
335-         self ._remote_agents : dict [str , dict [int , str ]] =  defaultdict (dict )
338+         self ._remote_agents : dict [EngineId , dict [int , str ]] =  defaultdict (dict )
336339
337340        # NIXL handshake port. 
338341        # NOTE(rob): Within a DP group, each DP rank gets its own 
339342        # base port (which is sent in the KVTransferParams). 
340343        # Each TP rank listens/queries on the base_port + tp_rank. 
341-         self .side_channel_port  =  (
344+         self .side_channel_port :  int  =  (
342345            envs .VLLM_NIXL_SIDE_CHANNEL_PORT  + 
343346            vllm_config .parallel_config .data_parallel_rank_local  * 
344347            vllm_config .parallel_config .tensor_parallel_size )
345348
346349        # Metadata. 
347-         self .engine_id  =  engine_id 
350+         self .engine_id :  EngineId  =  engine_id 
348351        self .tp_rank  =  get_tensor_model_parallel_rank ()
349352        self .world_size  =  get_tensor_model_parallel_world_size ()
350353        self .tp_group  =  get_tp_group ()
@@ -354,7 +357,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
354357
355358        # Map of engine_id -> kv_caches_base_addr. For TP case, each local 
356359        # rank will still only pull from a single remote TP worker. 
357-         self .kv_caches_base_addr : dict [str , list [int ]] =  {}
360+         self .kv_caches_base_addr : dict [EngineId , list [int ]] =  {}
358361
359362        # Number of NIXL regions. Currently one region per cache 
360363        # (so 1 per layer for MLA, otherwise 2 per layer) 
@@ -364,23 +367,23 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
364367        # nixl_prepped_dlist_handle. 
365368        self .src_xfer_side_handle : int  =  0 
366369        # Map of engine_id -> nixl_prepped_dlist_handle (int).
367-         self .dst_xfer_side_handles : dict [str , int ] =  {}
370+         self .dst_xfer_side_handles : dict [EngineId , int ] =  {}
368371
369372        # Map of engine_id -> num_blocks. All ranks in the same deployment will 
370373        # have the same number of blocks. 
371-         self .dst_num_blocks : dict [str , int ] =  {}
374+         self .dst_num_blocks : dict [EngineId , int ] =  {}
372375        self ._registered_descs : list [Any ] =  []
373376
374377        # In progress transfers. 
375378        # [req_id -> list[handle]] 
376-         self ._recving_transfers  =  defaultdict [str , list [Transfer ]](list )
379+         self ._recving_transfers  =  defaultdict [ReqId , list [Transfer ]](list )
377380
378381        # Complete transfer tracker. Used by rank 0 to track finished 
379382        # transactions on ranks 1 to N-1. 
380383        # [req_id -> count] 
381-         self ._done_recving_count : defaultdict [str ,
384+         self ._done_recving_count : defaultdict [ReqId ,
382385                                              int ] =  defaultdict (lambda : 0 )
383-         self ._done_sending_count : defaultdict [str ,
386+         self ._done_sending_count : defaultdict [ReqId ,
384387                                              int ] =  defaultdict (lambda : 0 )
385388
386389        # Background thread for establishing new connections. 
@@ -408,10 +411,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
408411        self ._use_flashinfer  =  attn_backend  ==  _Backend .FLASHINFER_VLLM_V1 
409412        logger .debug ("Detected attention backend %s" , self .backend_name )
410413
411-         self ._tp_size : dict [str , int ] =  {self .engine_id : self .world_size }
414+         self ._tp_size : dict [EngineId , int ] =  {self .engine_id : self .world_size }
412415        # With heterogeneous TP, P must wait for all assigned D TP workers to 
413416        # finish reading before safely freeing the blocks. 
414-         self .consumer_notification_counts_by_req  =  defaultdict [str , int ](int )
417+         self .consumer_notification_counts_by_req  =  defaultdict [ReqId , int ](int )
415418
416419    @staticmethod  
417420    def  _nixl_handshake_listener (metadata : NixlAgentMetadata ,
0 commit comments