@@ -611,3 +611,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
         prompt_logprobs_dict={},
     )
     scheduler.update_from_output(scheduler_output1, model_runner_output)
+
+
+# Note - these test cases mirror some of those in test_rejection_sampler.py
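+# In each case, `expected` is (num_draft_tokens, num_accepted_tokens).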
+@pytest.mark.parametrize(
+    "spec_tokens,output_tokens,expected",
+    [
+        ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)),  # perfect match
+        ([[1, 2, 3]], [[1, 5]], (3, 1)),  # early mismatch
+        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)),  # multiple sequences
+        ([[1]], [[1, 2]], (1, 1)),  # single token sequence
+        ([[]], [[5]], (0, 0)),  # empty sequence
+        ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
+         (6, 3)),  # multiple mismatches
+    ])
+def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
+    """Test scheduling behavior with speculative decoding.
+
+    This test verifies that:
+    1. Speculated tokens get scheduled correctly
+    2. Spec decoding stats properly count the number of draft and accepted tokens
+    """
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
+    req_ids = []
+    req_to_index = {}
+    for i, request in enumerate(requests):
+        scheduler.add_request(request)
+        req_ids.append(request.request_id)
+        req_to_index[request.request_id] = i
+
+    # Schedule a decode, which will also draft speculative tokens
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == len(requests)
+    assert output.total_num_scheduled_tokens == len(requests)
+    for i in range(len(requests)):
+        req_id = requests[i].request_id
+        assert output.num_scheduled_tokens[req_id] == 1
+        assert req_id not in output.scheduled_spec_decode_tokens
+
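+    # One sampled token per request, plus the drafted tokens for the next step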
+    model_runner_output = ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=spec_tokens,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    engine_core_outputs = scheduler.update_from_output(output,
+                                                       model_runner_output)
+
+    for i in range(len(requests)):
+        running_req = scheduler.running[i]
+        # The prompt token
+        assert running_req.num_computed_tokens == 1
+        # The prompt token and the sampled token
+        assert running_req.num_tokens == 2
+        # The prompt token, the sampled token, and the speculated tokens
+        assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
+
+    # No draft or accepted tokens counted yet
+    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
+    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
+    assert stats.num_draft_tokens == 0
+    assert stats.num_accepted_tokens == 0
+
+    # Schedule the speculated tokens for validation
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 0
+    # The sampled token and speculated tokens
+    assert output.total_num_scheduled_tokens == \
+        len(requests) + sum(len(ids) for ids in spec_tokens)
+    for i in range(len(requests)):
+        req_id = requests[i].request_id
+        assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
+        if spec_tokens[i]:
+            assert len(output.scheduled_spec_decode_tokens[req_id]) == \
+                len(spec_tokens[i])
+        else:
+            assert req_id not in output.scheduled_spec_decode_tokens
+
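+    # The target model's sampled tokens determine how many drafts are accepted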
+    model_runner_output = ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_to_index,
+        sampled_token_ids=output_tokens,
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    engine_core_outputs = scheduler.update_from_output(output,
+                                                       model_runner_output)
+
+    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
+    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
+    assert stats.num_draft_tokens == expected[0]
+    assert stats.num_accepted_tokens == expected[1]