@@ -722,7 +722,7 @@ def execute_model(
         kv_caches: List[torch.Tensor],
         extra_inputs: ExtraTensorData = None,
         extra_outputs: Optional[Set[str]] = None,
-    ) -> Tuple[Optional[SamplerOutput], ExtraTensorData]:
+    ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
          lora_requests, lora_mapping, multi_modal_input, prepared_extra_inputs
          ) = self.prepare_input_tensors(seq_group_metadata_list)
@@ -771,19 +771,19 @@ def execute_model(
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

         # Only perform sampling in the driver worker.
-        if self.is_driver_worker:
-            # Sample the next token.
-            output = self.model.sample(
-                logits=logits,
-                sampling_metadata=sampling_metadata,
-            )
+        if not self.is_driver_worker:
+            return None
+
+        # Sample the next token.
+        output = self.model.sample(
+            logits=logits,
+            sampling_metadata=sampling_metadata,
+        )

+        if extra_outputs:
             sampled_extra_tensor_data = extra_tensor_data.index_select(
                 0, sampling_metadata.selected_token_indices)
-        else:
-            output = None

-        if extra_outputs:
             if prefill_meta is not None:
                 for k in extra_tensor_data:
                     extra_tensor_data[k] = extra_tensor_data[k].roll(shifts=1,
@@ -794,12 +794,13 @@ def execute_model(
             if output is not None:
                 _move_extra_tensor_data_to_seq_outputs(
                     output, sampled_extra_tensor_data, sampling_metadata)
+
+                output.extra_tensor_data = extra_tensor_data
         else:
-            extra_tensor_data.clear()
             if output is not None:
                 output.extra_tensor_data = sampled_extra_tensor_data

-        return output, extra_tensor_data
+        return output

     @torch.inference_mode()
     def profile_run(self) -> None:
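For reference, a minimal sketch of the call-site change implied by the new return type; the surrounding worker code and the names `model_runner`, `seq_group_metadata_list`, `kv_caches`, `extra_inputs`, and `extra_outputs` are assumed for illustration and are not shown in this diff. Extra tensor data now rides on the returned SamplerOutput instead of being a second return value, and non-driver workers simply get None.

# Hypothetical call site, before this change (assumed, not part of the diff):
#     output, extra_tensor_data = model_runner.execute_model(
#         seq_group_metadata_list, kv_caches, extra_inputs, extra_outputs)
# After this change, the driver reads the extra tensors off the output itself:
output = model_runner.execute_model(
    seq_group_metadata_list, kv_caches, extra_inputs, extra_outputs)
if output is not None and extra_outputs:
    extra_tensor_data = output.extra_tensor_data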