vllm/worker/model_runner.py (36 changes: 22 additions, 14 deletions)
@@ -1932,24 +1932,32 @@ def execute_model(

         if model_input.inputs_embeds is not None:
             if self.is_driver_worker:
-                sampled = broadcast_tensor_dict(
-                    {"token_ids": output.sampled_token_ids})
+                sampled_token_ids = []
+                valid_outputs = []
+                for sequence_group_output in output.outputs:
+                    if len(sequence_group_output.samples) == 0:
+                        continue
+                    assert len(sequence_group_output.samples) == 1
+                    valid_outputs.append(sequence_group_output)
+                    sampled_token_ids.append(
+                        sequence_group_output.samples[0].output_token)
+                sampled_token_ids = torch.tensor(sampled_token_ids).to(
+                    self.device)
+                sampled_token_ids = broadcast_tensor_dict(
+                    {"sampled_token_ids":
+                     sampled_token_ids})["sampled_token_ids"]
             else:
-                sampled = broadcast_tensor_dict()
-            if sampled["token_ids"] is not None:
-                sampled_token_embeds = self.model.get_input_embeddings(
-                    sampled["token_ids"].squeeze(1))
+                sampled_token_ids = broadcast_tensor_dict(
+                )["sampled_token_ids"]
+            if len(sampled_token_ids) > 0:
+                sampled_token_embeds = \
+                    self.model.get_input_embeddings(sampled_token_ids)
                 if self.is_driver_worker:
                     self.sampler.include_gpu_probs_tensor = \
                         orig_include_gpu_probs

                     output.sampled_token_embeds = sampled_token_embeds

-                    for token_embed, sequence_group_output in zip(
-                            output.sampled_token_embeds, output.outputs):
-                        assert len(sequence_group_output.samples) == 1
-                        sequence_group_output.samples[
-                            0].output_embed = token_embed
+                    for i, sequence_group_output in enumerate(valid_outputs):
+                        sequence_group_output.samples[0].output_embed = \
+                            sampled_token_embeds[i]

         if not self.is_driver_worker:
             return []
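
What the hunk changes, as read from the diff: the old code broadcast output.sampled_token_ids wholesale and zipped the resulting embeddings against every entry in output.outputs, which assumed each sequence group output carried exactly one sample. The new code has the driver skip groups with no samples, gather each remaining sample's output_token into a tensor, and broadcast that tensor to the non-driver workers; every rank then recomputes the embeddings with self.model.get_input_embeddings, and the driver writes each embedding back onto its sample by position in valid_outputs, so the pairing stays aligned even when some groups produced nothing. Below is a minimal single-process sketch of that gather, broadcast, and write-back pattern. The FakeSample/FakeSequenceGroupOutput classes and the in-memory broadcast_tensor_dict stub are hypothetical stand-ins for illustration, not vLLM's real types or collective.

    import torch

    _bus = {}  # stands in for the collective; vLLM broadcasts over the TP group


    def broadcast_tensor_dict(tensor_dict=None):
        # Mimics the call pattern in the diff: the driver passes a dict,
        # followers call with no argument and receive the same dict.
        global _bus
        if tensor_dict is not None:
            _bus = tensor_dict
        return _bus


    class FakeSample:
        def __init__(self, output_token):
            self.output_token = output_token
            self.output_embed = None


    class FakeSequenceGroupOutput:
        def __init__(self, tokens):
            self.samples = [FakeSample(t) for t in tokens]


    embedding = torch.nn.Embedding(num_embeddings=100, embedding_dim=8)
    outputs = [
        FakeSequenceGroupOutput([7]),
        FakeSequenceGroupOutput([]),  # a group with no sample is skipped
        FakeSequenceGroupOutput([42]),
    ]

    # Driver side: gather token ids only from groups that produced a sample,
    # remembering which groups were kept so the write-back stays aligned.
    sampled_token_ids, valid_outputs = [], []
    for sequence_group_output in outputs:
        if len(sequence_group_output.samples) == 0:
            continue
        assert len(sequence_group_output.samples) == 1
        valid_outputs.append(sequence_group_output)
        sampled_token_ids.append(sequence_group_output.samples[0].output_token)
    sampled_token_ids = broadcast_tensor_dict(
        {"sampled_token_ids": torch.tensor(sampled_token_ids)}
    )["sampled_token_ids"]

    # Every rank: recompute the embeddings for the sampled tokens locally.
    if len(sampled_token_ids) > 0:
        sampled_token_embeds = embedding(sampled_token_ids)
        # Driver only: attach each embedding to the sample it came from.
        for i, sequence_group_output in enumerate(valid_outputs):
            sequence_group_output.samples[0].output_embed = \
                sampled_token_embeds[i]

    print(valid_outputs[1].samples[0].output_embed.shape)  # torch.Size([8])

A likely motivation for recomputing the embeddings on every rank instead of broadcasting them: only a small integer tensor has to cross the wire, and each worker already holds the embedding weights it needs.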