
Commit 369e16d

[OPTIMIZATION] Coalesces reads from CPU to GPU in sampler.py
1 parent 81ee050 commit 369e16d
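
The change follows a common PyTorch pattern: rather than reading one sampled token id and one log probability off the GPU per sequence inside the Python loop (each read forces a device-to-host synchronization), the values are gathered on the device and copied to the host once with .tolist(). A minimal sketch of that pattern, using illustrative tensor names and shapes rather than the actual sampler code:

    import torch

    # Illustrative tensors, not vLLM's actual data.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logprobs = torch.log_softmax(torch.randn(8, 32000, device=device), dim=-1)
    token_ids = torch.randint(0, 32000, (8,), device=device)

    # Before: one GPU -> CPU read (and implicit sync) per sequence in the Python loop.
    per_item = [logprobs[i, token_ids[i]].item() for i in range(len(token_ids))]

    # After: gather on the device, then transfer everything in a single call.
    coalesced = logprobs[torch.arange(len(token_ids), device=device), token_ids].tolist()
    ids = token_ids.tolist()

    assert per_item == coalesced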

1 file changed: +22 −27

vllm/model_executor/layers/sampler.py

Lines changed: 22 additions & 27 deletions
@@ -443,8 +443,8 @@ def _sample2(
     should_optimize = True
 
     for i, seq_group in enumerate(input_metadata.seq_groups[input_metadata.num_prompts:]):
-        _, sampling_params = seq_group
-        if sampling_params.use_beam_search or sampling_params.temperature < _SAMPLING_EPS:
+        seq_ids, sampling_params = seq_group
+        if sampling_params.use_beam_search or sampling_params.temperature < _SAMPLING_EPS or len(seq_ids) != 1:
             should_optimize = False
             break
 
@@ -460,22 +460,26 @@ def _sample_optimized(
 ) -> Dict[int, SequenceOutputs]:
     seq_outputs: Dict[int, SequenceOutputs] = {}
 
-    idx = 0
     num_prompts = input_metadata.num_prompts
 
     gen_probs = probs[num_prompts:]
     gen_next_token_ids = torch.multinomial(gen_probs,
                                            num_samples=1,
                                            replacement=True).squeeze(dim=-1)
+    chosen_logprobs = logprobs[num_prompts:][torch.arange(gen_next_token_ids.shape[0]), gen_next_token_ids]
+    chosen_logprobs = chosen_logprobs.squeeze(dim=-1)
+    if chosen_logprobs.dim() == 0:  # If it's a scalar (happens when `gen_next_token_ids.shape == torch.Size([1])`, due to torch indexing)
+        chosen_logprobs = chosen_logprobs.unsqueeze(0)  # Add a dimension back
+    chosen_logprobs = chosen_logprobs.tolist()
+    gen_next_token_ids = gen_next_token_ids.tolist()
 
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         if i < num_prompts:
             # Generate the next tokens for a prompt input.
             assert len(seq_ids) == sampling_params.best_of
-            prob = probs[idx]
-            logprob = logprobs[idx]
-            idx += 1
+            prob = probs[i]
+            logprob = logprobs[i]
 
             # Sample the next tokens.
             next_token_ids = _sample_from_prompt(prob, sampling_params)
@@ -492,33 +496,24 @@ def _sample_optimized(
                                                   output_logprobs)
         else:
             # Generate the next tokens for generation tokens.
-            prob = probs[idx:idx + len(seq_ids)]
-            logprob = logprobs[idx:idx + len(seq_ids)]
+            logprob = logprobs[i]
 
             # Sample the next tokens.
-            next_token_ids = gen_next_token_ids[idx - num_prompts:idx + len(seq_ids) - num_prompts]
-            next_token_ids = next_token_ids.tolist()
-            parent_seq_ids = seq_ids
-            idx += len(seq_ids)
+            next_token_id = gen_next_token_ids[i - num_prompts]
 
             # Get top-k log probabilities for the next tokens.
             next_logprobs: Dict[int, Dict[int, float]] = {}
-            for j, seq_id in enumerate(seq_ids):
-                next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[j], sampling_params.logprobs)
+            seq_id = seq_ids[0]
+            next_logprobs[seq_id] = _get_topk_logprobs([logprob], sampling_params.logprobs)
 
             # Build the output.
-            for seq_id, parent_seq_id, next_token_id in zip(
-                    seq_ids, parent_seq_ids, next_token_ids):
-                j = seq_ids.index(parent_seq_id)
-                output_logprobs = next_logprobs[parent_seq_id].copy()
-                output_logprobs[next_token_id] = logprob[j,
-                                                         next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(
-                    seq_id,
-                    parent_seq_id,
-                    next_token_id,
-                    output_logprobs,
-                )
+            output_logprobs = next_logprobs[seq_id].copy()
+            output_logprobs[next_token_id] = chosen_logprobs[i - num_prompts]
+            seq_outputs[seq_id] = SequenceOutputs(
+                seq_id,
+                seq_id,
+                next_token_id,
+                output_logprobs,
+            )
 
     return seq_outputs
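
The generation branch above relies on integer-array indexing to pick each sampled token's log probability in one shot, plus a guard for the zero-dimensional tensor that squeeze() produces when the batch holds a single generation sequence. A small standalone sketch of that gather and edge case, with illustrative tensor names rather than the sampler's variables:

    import torch

    # Illustrative single-sequence batch: the edge case the dim() == 0 guard covers.
    logprobs = torch.log_softmax(torch.randn(1, 32000), dim=-1)
    next_ids = torch.multinomial(logprobs.exp(), num_samples=1,
                                 replacement=True).squeeze(dim=-1)   # shape [1]

    # Pick each sampled token's log probability with integer-array indexing.
    chosen = logprobs[torch.arange(next_ids.shape[0]), next_ids].squeeze(dim=-1)
    if chosen.dim() == 0:                 # squeeze collapsed the single value to a scalar
        chosen = chosen.unsqueeze(0)      # restore the batch dimension, as the commit does

    # One host transfer per tensor instead of per-element .item() reads.
    print(next_ids.tolist(), chosen.tolist())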
