@@ -611,3 +611,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
611611        prompt_logprobs_dict = {},
612612    )
613613    scheduler .update_from_output (scheduler_output1 , model_runner_output )
614+ 
615+ 
# Note - these test cases mirror some of those in test_rejection_sampler.py
@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)),  # perfect match
        ([[1, 2, 3]], [[1, 5]], (3, 1)),  # early mismatch
        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)),  # multiple sequences
        ([[1]], [[1, 2]], (1, 1)),  # single token sequence
        ([[]], [[5]], (0, 0)),  # empty sequence
        ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
         (6, 3)),  # multiple mismatches
    ])
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
    """Test scheduling behavior with speculative decoding.

    This test verifies that:
    1. Speculated tokens get scheduled correctly
    2. Spec decoding stats properly count number of draft and accepted tokens
    """
    scheduler = create_scheduler()
    requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
    for request in requests:
        scheduler.add_request(request)
    req_ids = [request.request_id for request in requests]
    req_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}

    # First pass: schedule a decode step, which will also draft spec tokens.
    scheduler_output = scheduler.schedule()
    assert len(scheduler_output.scheduled_new_reqs) == len(requests)
    assert scheduler_output.total_num_scheduled_tokens == len(requests)
    for req_id in req_ids:
        # Exactly one token scheduled per request; no spec tokens exist yet.
        assert scheduler_output.num_scheduled_tokens[req_id] == 1
        assert req_id not in scheduler_output.scheduled_spec_decode_tokens

    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[0] for _ in requests],
        spec_token_ids=spec_tokens,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)

    for running_req, drafts in zip(scheduler.running, spec_tokens):
        # The prompt token
        assert running_req.num_computed_tokens == 1
        # The prompt token and the sampled token
        assert running_req.num_tokens == 2
        # The prompt token, the sampled token, and the speculated tokens
        assert running_req.num_tokens_with_spec == 2 + len(drafts)

    # No draft or accepted tokens counted yet
    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
    assert stats is not None
    assert stats.num_draft_tokens == 0
    assert stats.num_accepted_tokens == 0

    # Second pass: schedule the speculated tokens for validation.
    scheduler_output = scheduler.schedule()
    assert len(scheduler_output.scheduled_new_reqs) == 0
    # The sampled token plus all speculated tokens should be scheduled.
    total_spec = sum(len(ids) for ids in spec_tokens)
    assert scheduler_output.total_num_scheduled_tokens == \
        len(requests) + total_spec
    for req_id, drafts in zip(req_ids, spec_tokens):
        assert scheduler_output.num_scheduled_tokens[req_id] == 1 + len(drafts)
        if drafts:
            assert len(scheduler_output.
                       scheduled_spec_decode_tokens[req_id]) == len(drafts)
        else:
            assert req_id not in scheduler_output.scheduled_spec_decode_tokens

    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=output_tokens,
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)

    # Stats now reflect the draft/accept counts from validating the drafts.
    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
    assert stats is not None
    assert stats.num_draft_tokens == expected[0]
    assert stats.num_accepted_tokens == expected[1]
0 commit comments