@@ -84,15 +84,16 @@ def score_proposals(
8484 assert len (target_sampler_output ) == 1 , "expected single-step output"
8585 target_sampler_output , _ = target_sampler_output [0 ]
8686
87- all_tokens , all_probs , spec_logprobs , all_extra_output_data = self ._contract_batch (
88- contracted_bs = len (execute_model_req .seq_group_metadata_list ),
89- target_sampler_output = target_sampler_output ,
90- proposals = proposals ,
91- num_scoring_tokens = num_scoring_tokens ,
92- non_spec_indices = non_spec_indices ,
93- spec_indices = spec_indices ,
94- k = execute_model_req .num_lookahead_slots ,
95- )
87+ (all_tokens , all_probs , spec_logprobs ,
88+ all_extra_output_data ) = self ._contract_batch (
89+ contracted_bs = len (execute_model_req .seq_group_metadata_list ),
90+ target_sampler_output = target_sampler_output ,
91+ proposals = proposals ,
92+ num_scoring_tokens = num_scoring_tokens ,
93+ non_spec_indices = non_spec_indices ,
94+ spec_indices = spec_indices ,
95+ k = execute_model_req .num_lookahead_slots ,
96+ )
9697
9798 return SpeculativeScores (
9899 probs = all_probs ,
@@ -217,7 +218,8 @@ def _contract_batch(
217218 all_probs [non_spec_indices , :1 , :] = non_spec_target_probs
218219 all_logprobs [non_spec_indices , :1 , :] = non_spec_target_logprobs
219220
220- if all_extra_output_data and non_spec_target_extra_output_data is not None :
221+ if all_extra_output_data and \
222+ non_spec_target_extra_output_data is not None :
221223 for k in all_extra_output_data :
222224 all_extra_output_data [k ][
223225 non_spec_indices , :
@@ -382,8 +384,9 @@ def _split_scoring_output(
382384 if sampler_output .extra_tensor_data is None :
383385 spec_extra_output_data , no_spec_extra_output_data = (None , None )
384386 else :
385- spec_extra_output_data , no_spec_extra_output_data = sampler_output .extra_tensor_data .split (
386- split_sizes )
387+ spec_extra_output_data , no_spec_extra_output_data = sampler_output \
388+ .extra_tensor_data \
389+ .split (split_sizes )
387390
388391 # Convert scores to tensors.
389392 sampler_output .sampled_token_probs = spec_probs
0 commit comments