55from  array  import  array 
66from  collections  import  defaultdict 
77from  dataclasses  import  dataclass 
8+ from  functools  import  cached_property , reduce 
89from  typing  import  TYPE_CHECKING , Any , Callable , Dict , List , Mapping , Optional 
910from  typing  import  Sequence  as  GenericSequence 
1011from  typing  import  Set , Tuple , Union , cast 
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
169170    # It is used to compute mrope_position_ids. 
170171    _mrope_position_delta : Optional [int ] =  None 
171172
173+     @staticmethod  
174+     def  from_counts (counts_by_token : Mapping [int , int ]) ->  "SequenceData" :
175+         if  len (counts_by_token ) ==  0 :
176+             return  SequenceData .from_seqs ([])
177+ 
178+         arrs  =  [
179+             array (VLLM_TOKEN_ID_ARRAY_TYPE , [token_id ]) *  count 
180+             for  token_id , count  in  counts_by_token .items ()
181+         ]
182+ 
183+         return  SequenceData (reduce (array .__add__ , arrs ))
184+ 
185+     @staticmethod  
186+     def  from_seqs (
187+         prompt_token_ids : GenericSequence [int ],
188+         output_token_ids : Optional [GenericSequence [int ]] =  None ,
189+     ) ->  "SequenceData" :
190+         prompt_token_ids_arr  =  array (VLLM_TOKEN_ID_ARRAY_TYPE ,
191+                                      prompt_token_ids )
192+ 
193+         if  output_token_ids  is  None :
194+             return  SequenceData (prompt_token_ids_arr )
195+ 
196+         output_token_ids_arr  =  array (VLLM_TOKEN_ID_ARRAY_TYPE ,
197+                                      output_token_ids )
198+ 
199+         return  SequenceData (prompt_token_ids_arr ,
200+                             _output_token_ids = output_token_ids_arr )
201+ 
172202    def  __post_init__ (self ) ->  None :
173203        assert  self ._prompt_token_ids .typecode  ==  "l" 
174204        assert  self ._output_token_ids .typecode  ==  "l" 
@@ -370,8 +400,6 @@ def __init__(
370400        self .lora_request  =  lora_request 
371401        self .prompt_adapter_request  =  prompt_adapter_request 
372402        self .from_decoder_prompt  =  from_decoder_prompt 
373-         self ._prompt : Optional [str ] =  None 
374-         self ._prompt_token_ids : Optional [List [int ]] =  None 
375403
376404        # For decoder-only models, a Sequence is constructed 
377405        # from an LLMInputs instance (the `inputs` arg.) 
@@ -400,8 +428,7 @@ def __init__(
400428                             f"invalid input { inputs }  ; did you forget the " 
401429                             "encoder input prompt fields?" )
402430
403-         self .data  =  SequenceData (
404-             array (VLLM_TOKEN_ID_ARRAY_TYPE , self .prompt_token_ids ))
431+         self .data  =  SequenceData .from_seqs (self .prompt_token_ids )
405432        self .output_logprobs : SampleLogprobs  =  []
406433        self .output_text  =  "" 
407434
@@ -422,37 +449,23 @@ def __init__(
422449    def  n_blocks (self ) ->  int :
423450        return  (self .get_len () +  self .block_size  -  1 ) //  self .block_size 
424451
425-     @property  
452+     @cached_property  
426453    def  prompt (self ) ->  Optional [str ]:
427-         if  self ._prompt  is  not   None :
428-             # Reuse precomputed prompt string 
429-             return  self ._prompt 
430- 
431-         # Select decoder or encoder input prompt str, 
432-         # as appropriate 
454+         # Select decoder or encoder input prompt str, as appropriate 
433455        prompt_key : str  =  ("prompt" 
434456                           if  self .from_decoder_prompt  else  "encoder_prompt" )
435457
436-         # Cache prompt 
437-         self ._prompt  =  cast (Optional [str ], self .inputs .get (prompt_key ))
438-         return  self ._prompt 
458+         return  cast (Optional [str ], self .inputs .get (prompt_key ))
439459
440-     @property  
460+     @cached_property  
441461    def  prompt_token_ids (self ) ->  List [int ]:
442-         if  self ._prompt_token_ids  is  not   None :
443-             # Reuse precomputed prompt token ids 
444-             return  self ._prompt_token_ids 
445- 
446-         # Select decoder or encoder input prompt 
447-         # token ids, as appropriate 
462+         # Select decoder or encoder input prompt token ids, as appropriate 
448463        prompt_token_ids_key : str  =  ("prompt_token_ids" 
449464                                     if  self .from_decoder_prompt  else 
450465                                     "encoder_prompt_token_ids" )
451466
452467        # Cache computed prompt token ids 
453-         self ._prompt_token_ids  =  cast (List [int ],
454-                                       self .inputs .get (prompt_token_ids_key ))
455-         return  self ._prompt_token_ids 
468+         return  cast (List [int ], self .inputs .get (prompt_token_ids_key ))
456469
457470    @property  
458471    def  multi_modal_data (self ) ->  "MultiModalDataDict" :
0 commit comments