@@ -20,6 +20,8 @@
 logger = init_logger(__name__)
 
 _PAD_SLOT_ID = 0  # FIXME(woosuk)
+# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
+_ENABLE_TOP_P = False
 
 
 class TPUModelRunner:
@@ -339,9 +341,34 @@ def _prepare_sample(
             assert seq_group_metadata.sampling_params is not None
             sampling_params = seq_group_metadata.sampling_params
 
+            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
+            # low temperature. This is not accurate.
             t.append(sampling_params.temperature
                      if sampling_params.temperature >= 1e-5 else 1e-5)
+            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
+                raise NotImplementedError(
+                    "Top-p sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
             p.append(sampling_params.top_p)
+            if sampling_params.top_k != -1:
+                raise NotImplementedError(
+                    "Top-k sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
+            if sampling_params.best_of > 1:
+                raise NotImplementedError(
+                    "best_of > 1 is not currently supported by the TPU "
+                    "backend.")
+            if sampling_params.use_beam_search:
+                raise NotImplementedError(
+                    "Beam search is not supported by the TPU backend.")
+            if sampling_params.logprobs is not None:
+                raise NotImplementedError(
+                    "logprobs is not currently supported by the TPU backend.")
+            if sampling_params.prompt_logprobs is not None:
+                raise NotImplementedError(
+                    "prompt_logprobs is not currently supported by the TPU "
+                    "backend.")
+
         num_paddings = padded_batch_size - len(seq_group_metadata_list)
         t += [1.0] * num_paddings
         p += [1.0] * num_paddings
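
Aside (not part of the diff): the NOTE above works because dividing logits by a
tiny temperature makes the largest logit dominate the softmax, so multinomial
sampling returns the argmax with probability close to 1. A minimal standalone
illustration:

    import torch

    logits = torch.tensor([2.0, 1.0, 0.5])
    # With t = 1e-5, the scaled gap between the top two logits is ~1e5, so
    # softmax puts essentially all probability mass on the argmax.
    probs = torch.softmax(logits / 1e-5, dim=-1)
    print(probs)  # tensor([1., 0., 0.])
    # torch.multinomial(probs, 1) now behaves like logits.argmax().
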
@@ -350,35 +377,32 @@ def _prepare_sample(
         p = torch.tensor(p, dtype=torch.float32, device=self.device)
         return t, p
 
-    def prepare_inputs(
+    def _execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ):
-        assert seq_group_metadata_list is not None
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> List[CompletionSequenceGroupOutput]:
+        # Prepare inputs.
         assert len(seq_group_metadata_list) > 0
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
-        if seq_group_metadata_list[0].is_prompt:
+        is_prompt = seq_group_metadata_list[0].is_prompt
+        if is_prompt:
             inputs = self._prepare_prompt(seq_group_metadata_list)
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
         padded_batch_size = inputs[0].shape[0]
-        sample_inputs = self._prepare_sample(seq_group_metadata_list,
-                                             padded_batch_size)
-        return inputs + sample_inputs
+        t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
 
-    def _execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> List[CompletionSequenceGroupOutput]:
-        inputs = self.prepare_inputs(seq_group_metadata_list)
+        # Execute the model.
         next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                    *inputs[2:])
-        if not self.is_driver_worker:
-            return []
+                                    *inputs[2:], t, p)
+        # Retrieve the outputs to CPU.
         next_token_ids = next_token_ids.cpu().tolist()
 
+        # NOTE(woosuk): Minimal code to construct the sampler outputs.
+        # The TPU backend does not reuse the sampler, since the TPU backend
+        # does not support the advanced sampling parameters such as logprobs.
         i = 0
         sampler_outputs = []
         for seq_group_metadata in seq_group_metadata_list:
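
The hunk ends mid-loop. For context, the loop body presumably appends one
SequenceOutput per sequence and wraps each group in a
CompletionSequenceGroupOutput; a hedged sketch of how it might continue (the
Logprob(0.0) placeholder is an assumption, consistent with logprobs being
unsupported here):

    for seq_group_metadata in seq_group_metadata_list:
        seq_outputs = []
        for seq_id in seq_group_metadata.seq_data:
            next_token_id = next_token_ids[i]
            # Attach a dummy logprob, since the TPU backend does not compute
            # real logprobs (see the NOTE above).
            seq_outputs.append(
                SequenceOutput(seq_id, next_token_id,
                               {next_token_id: Logprob(0.0)}))
            i += 1
        sampler_outputs.append(
            CompletionSequenceGroupOutput(seq_outputs, None))
    return sampler_outputs
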
@@ -400,6 +424,7 @@ def execute_model(
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> SamplerOutput:
         assert seq_group_metadata_list is not None
+        assert len(seq_group_metadata_list) > 0
         if seq_group_metadata_list[0].is_prompt:
             # NOTE(woosuk): To reduce the compilation time, we only compile the
             # prefill inputs with batch size 1. Because the scheduler is not
@@ -492,8 +517,8 @@ def forward(
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
 
         logits = logits / t.unsqueeze(dim=1)
-        # FIXME(woosuk): Disabled top-p sampling since it's too slow.
-        # logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+        if _ENABLE_TOP_P:
+            logits = _apply_top_p(logits, p.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
         # FIXME(woosuk): best_of > 1 is not supported.
         next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
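
`_apply_top_p` itself is outside this hunk. For reference, a minimal sketch of
a cutoff-based top-p mask over logits (this assumes `p` has shape
`(batch_size, 1)` after the `unsqueeze` above, and is an illustration, not
necessarily the function's actual body):

    import torch

    def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        # Sort logits descending and find, per row, the smallest logit still
        # inside the top-p nucleus (cumulative probability >= p).
        logits_sorted = torch.sort(logits, dim=-1, descending=True).values
        cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
        cutoff_index = torch.sum(cum_probs < p, dim=-1, keepdim=True)
        cutoff_logit = torch.gather(logits_sorted, dim=-1, index=cutoff_index)
        # Mask out everything below the cutoff; the nucleus keeps its logits.
        return logits.masked_fill(logits < cutoff_logit, -float("inf"))

Because the mask only depends on tensor ops with static shapes, this style of
top-p is XLA-friendly, though the full-vocabulary sort is the likely source of
the slowness the FIXME refers to.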