@@ -280,17 +280,30 @@ def _verify_greedy_sampling(self) -> None:
280280 f"Got { self .best_of } ." )
281281
282282 def update_from_generation_config (
283- self , generation_config : Dict [str , Any ]) -> None :
283+ self ,
284+ generation_config : Dict [str , Any ],
285+ model_eos_token_id : Optional [int ] = None ) -> None :
284286 """Update if there are non-default values from generation_config"""
287+
288+ if model_eos_token_id is not None :
289+ # Add the eos token id into the sampling_params to support
290+ # min_tokens processing.
291+ self .all_stop_token_ids .add (model_eos_token_id )
292+
285293 # Update eos_token_id for generation
286- if (not self .ignore_eos ) and (eos_ids :=
287- generation_config .get ("eos_token_id" )):
294+ if (eos_ids := generation_config .get ("eos_token_id" )) is not None :
288295 # it can be either int or list of int
289- if isinstance (eos_ids , int ):
290- eos_ids = [eos_ids ]
291- original_stop_token_ids = set (self .stop_token_ids )
292- original_stop_token_ids .update (eos_ids )
293- self .stop_token_ids = list (original_stop_token_ids )
296+ eos_ids = {eos_ids } if isinstance (eos_ids , int ) else set (eos_ids )
297+ if model_eos_token_id is not None :
298+ # We don't need to include the primary eos_token_id in
299+ # stop_token_ids since it's handled separately for stopping
300+ # purposes.
301+ eos_ids .discard (model_eos_token_id )
302+ if eos_ids :
303+ self .all_stop_token_ids .update (eos_ids )
304+ if not self .ignore_eos :
305+ eos_ids .update (self .stop_token_ids )
306+ self .stop_token_ids = list (eos_ids )
294307
295308 @cached_property
296309 def sampling_type (self ) -> SamplingType :
0 commit comments