Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

generators: Dict[str, torch.Generator] = {}

def test_sampling():
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
pin_memory=is_pin_memory_available(),
generators=generators)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)

Expand Down
1 change: 1 addition & 0 deletions tests/spec_decode/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
per_test_common_llm_kwargs, distinct_llm_kwargs,
seed):
print("CREATE LLM GENERATOR")
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
Expand Down
54 changes: 53 additions & 1 deletion tests/spec_decode/e2e/test_mlp_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

import pytest

from .conftest import run_greedy_equality_correctness_test
from .conftest import (run_equality_correctness_test,
run_greedy_equality_correctness_test)

# main model
MAIN_MODEL = "JackFram/llama-160m"
Expand Down Expand Up @@ -77,6 +78,57 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,

# Speculative model
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize("seed", [None])
def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int,
temperature: float):
"""Verify seeded runs produce the same output."""
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=True,
force_output_len=True)

# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=False,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
Expand Down
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"output_len",
[
# Use smaller output len for fast test.
10,
20,
])
@pytest.mark.parametrize("seed", [None])
def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
Expand Down
14 changes: 14 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,20 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
"usage": completion.usage,
})

# test seeded random sampling
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=5,
seed=33,
temperature=1.0)

results.append({
"test": "seeded_sampling",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})

# test simple list
batch = client.completions.create(
model=model,
Expand Down
1 change: 0 additions & 1 deletion vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
Expand Down
95 changes: 39 additions & 56 deletions vllm/model_executor/layers/rejection_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch
import torch.jit
Expand Down Expand Up @@ -36,7 +36,7 @@ def forward(
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
Expand Down Expand Up @@ -66,6 +66,9 @@ def forward(
probabilities.
shape = [batch_size, num_speculative_tokens]

seeded_seqs: Dict of batch row index to torch generator, for
sequences using seeded generation.

Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
Expand All @@ -83,7 +86,7 @@ def forward(
target_probs,
draft_probs,
draft_token_ids,
generators,
seeded_seqs,
))

output_token_ids = self._create_output(
Expand All @@ -100,7 +103,7 @@ def _batch_modified_rejection_sampling(
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.

Expand All @@ -117,23 +120,17 @@ def _batch_modified_rejection_sampling(

# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids, generators)
draft_token_ids, seeded_seqs)

recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)

seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators, k=k)

# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(
recovered_probs,
num_samples=1,
k=k,
generators=generators,
seed_indices=seed_indices,
# this arg is unused when None but torch.jit requires a list
non_seed_indices=non_seed_indices or [],
seeded_seqs=seeded_seqs or {},
).reshape(batch_size, k)

return accepted, recovered_token_ids
Expand All @@ -143,7 +140,7 @@ def _get_accepted(
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
Expand Down Expand Up @@ -178,24 +175,26 @@ def _get_accepted(
selected_target_probs = target_probs[batch_indices, probs_indicies,
draft_token_ids]

seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators)

if len(seed_indices) == 0:
if not seeded_seqs:
uniform_rand = torch.rand_like(selected_target_probs)
else:
uniform_rand = torch.empty_like(selected_target_probs)

for idx in seed_indices:
uniform_rand[idx, :] = torch.rand(1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generators[idx])

if non_seed_indices:
uniform_rand[non_seed_indices, :] = torch.rand(
len(non_seed_indices),
non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(
1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k,
dtype=self.probs_dtype,
device=target_probs.device)
Expand Down Expand Up @@ -272,27 +271,6 @@ def _smallest_positive_value(self) -> float:
"""
return torch.finfo(self.probs_dtype).tiny

# partition batch into indices for which a generator is provided
# and indicies for which no generator is provided
@staticmethod
def _split_batch_by_seeded(
generators: List[Optional[torch.Generator]],
k: int = 1,
) -> Tuple[List[int], Optional[List[int]]]:

if all(generator is None for generator in generators):
seed_indices: List[int] = []
non_seed_indices: Optional[List[int]] = None
else:
seed_indices, non_seed_indices = [], []
for i, generator in enumerate(generators):
if generator is None:
non_seed_indices.extend(range(k * i, k * (i + 1)))
else:
seed_indices.extend(range(k * i, k * (i + 1)))

return seed_indices, non_seed_indices


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
Expand All @@ -304,9 +282,7 @@ def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
generators: List[Optional[torch.Generator]],
seed_indices: List[int],
non_seed_indices: List[int],
seeded_seqs: Dict[int, torch.Generator],
) -> torch.Tensor:

if num_samples > 1:
Expand All @@ -315,13 +291,20 @@ def _multinomial(
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])

q = torch.empty_like(probs)
if len(seed_indices) == 0:
if not seeded_seqs:
q.exponential_(1.0)
else:
q[non_seed_indices].exponential_(1.0)
for idx in seed_indices:
q[idx].exponential_(1.0, generator=generators[idx // k])
non_seeded_indices: List[int] = []
start = 0
for idx in range(len(q) // k):
end = start + k
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.extend(list(range(start, end)))
else:
q[start:end].exponential_(1.0, generator=generator)
start = end
q[non_seeded_indices].exponential_(1.0)

return probs.div_(q).argmax(dim=1).view(-1, num_samples)
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/spec_decode_base_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Optional
from typing import Dict, Optional

import torch
import torch.jit
Expand Down Expand Up @@ -237,6 +237,6 @@ def forward(
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
raise NotImplementedError
Loading