@@ -443,8 +443,8 @@ def _sample2(
     should_optimize = True
 
     for i, seq_group in enumerate(input_metadata.seq_groups[input_metadata.num_prompts:]):
-        _, sampling_params = seq_group
-        if sampling_params.use_beam_search or sampling_params.temperature < _SAMPLING_EPS:
+        seq_ids, sampling_params = seq_group
+        if sampling_params.use_beam_search or sampling_params.temperature < _SAMPLING_EPS or len(seq_ids) != 1:
             should_optimize = False
             break
 
@@ -460,22 +460,26 @@ def _sample_optimized(
 ) -> Dict[int, SequenceOutputs]:
     seq_outputs: Dict[int, SequenceOutputs] = {}
 
-    idx = 0
     num_prompts = input_metadata.num_prompts
 
     gen_probs = probs[num_prompts:]
     gen_next_token_ids = torch.multinomial(gen_probs,
                                            num_samples=1,
                                            replacement=True).squeeze(dim=-1)
+    chosen_logprobs = logprobs[num_prompts:][torch.arange(gen_next_token_ids.shape[0]), gen_next_token_ids]
+    chosen_logprobs = chosen_logprobs.squeeze(dim=-1)
+    if chosen_logprobs.dim() == 0:  # Scalar when gen_next_token_ids.shape == torch.Size([1]), due to torch indexing.
+        chosen_logprobs = chosen_logprobs.unsqueeze(0)  # Add the dimension back.
+    chosen_logprobs = chosen_logprobs.tolist()
+    gen_next_token_ids = gen_next_token_ids.tolist()
 
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         if i < num_prompts:
             # Generate the next tokens for a prompt input.
             assert len(seq_ids) == sampling_params.best_of
-            prob = probs[idx]
-            logprob = logprobs[idx]
-            idx += 1
+            prob = probs[i]
+            logprob = logprobs[i]
 
             # Sample the next tokens.
             next_token_ids = _sample_from_prompt(prob, sampling_params)
@@ -492,33 +496,24 @@ def _sample_optimized(
                                                       output_logprobs)
         else:
             # Generate the next tokens for generation tokens.
-            prob = probs[idx:idx + len(seq_ids)]
-            logprob = logprobs[idx:idx + len(seq_ids)]
+            logprob = logprobs[i]
 
             # Sample the next tokens.
-            next_token_ids = gen_next_token_ids[idx - num_prompts:idx + len(seq_ids) - num_prompts]
-            next_token_ids = next_token_ids.tolist()
-            parent_seq_ids = seq_ids
-            idx += len(seq_ids)
+            next_token_id = gen_next_token_ids[i - num_prompts]
 
             # Get top-k log probabilities for the next tokens.
             next_logprobs: Dict[int, Dict[int, float]] = {}
-            for j, seq_id in enumerate(seq_ids):
-                next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[j], sampling_params.logprobs)
+            seq_id = seq_ids[0]
+            next_logprobs[seq_id] = _get_topk_logprobs(logprob, sampling_params.logprobs)
 
             # Build the output.
-            for seq_id, parent_seq_id, next_token_id in zip(
-                    seq_ids, parent_seq_ids, next_token_ids):
-                j = seq_ids.index(parent_seq_id)
-                output_logprobs = next_logprobs[parent_seq_id].copy()
-                output_logprobs[next_token_id] = logprob[j,
-                                                         next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(
-                    seq_id,
-                    parent_seq_id,
-                    next_token_id,
-                    output_logprobs,
-                )
+            output_logprobs = next_logprobs[seq_id].copy()
+            output_logprobs[next_token_id] = chosen_logprobs[i - num_prompts]
+            seq_outputs[seq_id] = SequenceOutputs(
+                seq_id,
+                seq_id,
+                next_token_id,
+                output_logprobs,
+            )
 
     return seq_outputs
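
For reference, a minimal standalone sketch of the batched sampling and log-prob gather that the new `chosen_logprobs` code relies on. This is not part of the commit; the shapes (3 generation sequences, vocab size 5) are made up for illustration:

# Sketch only: batched multinomial sampling plus gathering each chosen
# token's log probability via advanced indexing, as used in _sample_optimized.
import torch

probs = torch.softmax(torch.randn(3, 5), dim=-1)   # [num_gen_seqs, vocab_size]
logprobs = torch.log(probs)

# One sample per sequence, flattened to shape [num_gen_seqs].
next_token_ids = torch.multinomial(probs, num_samples=1,
                                   replacement=True).squeeze(dim=-1)

# Advanced indexing picks logprobs[i, next_token_ids[i]] for every row i,
# giving a 1-D tensor of shape [num_gen_seqs].
chosen = logprobs[torch.arange(next_token_ids.shape[0]), next_token_ids]

# Edge case guarded in the diff: with a single generation sequence,
# squeeze(dim=-1) turns the shape-[1] result into a 0-dim scalar, so a
# dimension is added back before calling .tolist().
chosen = chosen.squeeze(dim=-1)
if chosen.dim() == 0:
    chosen = chosen.unsqueeze(0)
print(next_token_ids.tolist(), chosen.tolist())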