37 changes: 37 additions & 0 deletions docs/features/spec_decode.md
@@ -131,6 +131,43 @@ matching n-grams in the prompt. For more information read [this thread.](https:/
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

## Speculating using Suffix Decoding

The following code configures vLLM to use speculative decoding where proposals are generated by Suffix Decoding ([technical report](https://arxiv.org/abs/2411.04975)).

Like n-gram, Suffix Decoding generates draft tokens by pattern-matching on the last `n` generated tokens. Unlike n-gram, Suffix Decoding (1) pattern-matches against both the prompt and previously generated tokens, (2) uses frequency counts to propose the most likely continuations, and (3) speculates an adaptive number of tokens per request at each iteration, yielding better acceptance rates.

Suffix Decoding can achieve better performance on tasks with high repetition, such as code editing, agentic loops (e.g. self-reflection, self-consistency), and RL rollouts.
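
To make the idea concrete, the following simplified sketch (an illustration only, not the Arctic Inference implementation) counts the continuations observed after each suffix and proposes the most frequent continuation of the longest matching suffix:

```python
# Toy illustration of frequency-based suffix speculation (not Arctic Inference).
from collections import Counter, defaultdict


def build_suffix_index(tokens: list[int], max_depth: int = 8) -> dict:
    """Map each suffix (up to max_depth tokens) to counts of the next token."""
    index: dict[tuple[int, ...], Counter] = defaultdict(Counter)
    for end in range(1, len(tokens)):
        for depth in range(1, min(max_depth, end) + 1):
            index[tuple(tokens[end - depth:end])][tokens[end]] += 1
    return index


def propose_draft(index: dict, context: list[int], max_spec_tokens: int = 4,
                  min_token_prob: float = 0.1) -> list[int]:
    """Greedily extend the context with the most frequent continuations."""
    draft: list[int] = []
    tokens = list(context)
    for _ in range(max_spec_tokens):
        # Try the longest matching suffix first, then back off to shorter ones.
        for depth in range(min(8, len(tokens)), 0, -1):
            counts = index.get(tuple(tokens[-depth:]))
            if counts:
                token, count = counts.most_common(1)[0]
                if count / sum(counts.values()) >= min_token_prob:
                    draft.append(token)
                    tokens.append(token)
                    break
        else:
            break  # No suffix matched; stop speculating early.
    return draft


# Repetition in earlier text makes the continuation predictable.
history = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2]
print(propose_draft(build_suffix_index(history), history))  # [3, 4, 1, 2]
```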

!!! tip "Install Arctic Inference"
Suffix Decoding requires [Arctic Inference](https://github.com/snowflakedb/ArcticInference). You can install it with `pip install arctic-inference`.

??? code

```python
from vllm import LLM, SamplingParams

prompts = [
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
speculative_config={
"method": "suffix",
"num_speculative_tokens": 16,
},
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
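
The suffix-specific parameters introduced in `SpeculativeConfig` can also be set through `speculative_config`. The following sketch spells them out using their default values for illustration; they are not tuned recommendations:

```python
from vllm import LLM

# Sketch: suffix decoding with its tuning parameters written out explicitly.
# Values mirror the SpeculativeConfig defaults and are illustrative only.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_config={
        "method": "suffix",
        "num_speculative_tokens": 32,
        # Maximum depth of the global and per-prompt suffix trees.
        "suffix_decoding_max_tree_depth": 64,
        # Number of past responses cached in the global tree (0 disables it).
        "suffix_decoding_max_cached_requests": 10000,
        # max_spec_tokens = max_spec_factor * prefix_match_length
        "suffix_decoding_max_spec_factor": 1.0,
        # Only speculate tokens with estimated probability >= this threshold.
        "suffix_decoding_min_token_prob": 0.1,
    },
)
```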

## Speculating using MLP speculators

The following code configures vLLM to use speculative decoding where proposals are generated by
1 change: 1 addition & 0 deletions requirements/test.in
@@ -48,6 +48,7 @@ buildkite-test-collector==0.1.9
genai_perf==0.0.8
tritonclient==2.51.0

arctic-inference == 0.0.9 # Required for suffix decoding test
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
numba == 0.61.2; python_version > '3.9'
numpy
2 changes: 2 additions & 0 deletions requirements/test.txt
@@ -39,6 +39,8 @@ anyio==4.6.2.post1
# via
# httpx
# starlette
arctic-inference==0.0.9
# via -r requirements/test.in
argcomplete==3.5.1
# via datamodel-code-generator
arrow==1.3.0
80 changes: 73 additions & 7 deletions tests/v1/e2e/test_spec_decode.py
@@ -79,7 +79,18 @@ def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"


def test_ngram_correctness(
@pytest.mark.parametrize("speculative_config",
[{
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
}, {
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
}])
def test_ngram_and_suffix_correctness(
speculative_config: dict,
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_name: str,
@@ -98,12 +109,7 @@ def test_ngram_correctness(

spec_llm = LLM(
model=model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
speculative_config=speculative_config,
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
@@ -125,6 +131,66 @@ cleanup_dist_env_and_memory()
cleanup_dist_env_and_memory()


def test_suffix_decoding_acceptance(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_name: str,
):
'''
Check that suffix decoding caching takes effect and improves acceptance
lengths and acceptance rates over multiple runs of the same prompts.
'''
test_prompts = get_test_prompts(mm_enabled=False)

spec_llm = LLM(
model=model_name,
speculative_config={
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
"suffix_decoding_max_cached_requests": 1000,
},
max_model_len=1024,
disable_log_stats=False,
)

# Run several times and check that the accepted tokens increase.
spec_llm.chat(test_prompts, sampling_config)
num_draft = []
num_accept = []
for i in range(10): # Run multiple times to warm up the cache.
spec_llm.chat(test_prompts, sampling_config)
# Collect draft and acceptance stats.
metrics = spec_llm.get_metrics()
for metric in metrics:
if metric.name == "vllm:spec_decode_num_draft_tokens":
num_draft.append(metric.value)
if metric.name == "vllm:spec_decode_num_accepted_tokens":
num_accept.append(metric.value)

# Calculate the acceptance rates for the first and last runs.
first_accept_tokens = num_accept[0]
first_draft_tokens = num_draft[0]
first_accept_rate = first_accept_tokens / first_draft_tokens

# Take the diff since the stats are cumulative.
last_accept_tokens = num_accept[-1] - num_accept[-2]
last_draft_tokens = num_draft[-1] - num_draft[-2]
last_accept_rate = last_accept_tokens / last_draft_tokens

# Expect the acceptance length to improve.
assert first_accept_tokens < last_accept_tokens

# Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate

# Heuristic: expect at least 85% acceptance rate at the end.
assert last_accept_rate > 0.85

del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@pytest.mark.parametrize(
["model_setup", "mm_enabled"],
[
57 changes: 54 additions & 3 deletions vllm/config/speculative.py
@@ -13,7 +13,7 @@
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import LazyLoader
from vllm.utils import LazyLoader, has_arctic_inference

if TYPE_CHECKING:
from transformers import PretrainedConfig
@@ -32,7 +32,7 @@
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
"longcat_flash_mtp", "mtp"]
"longcat_flash_mtp", "mtp", "suffix"]
MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
"qwen3_next_mtp", "longcat_flash_mtp")

Expand Down Expand Up @@ -123,6 +123,27 @@ class SpeculativeConfig:
ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model initialized internal."""

# Suffix decoding configuration
suffix_decoding_max_tree_depth: int = 64
"""The maximum depth of the suffix decoding global and prompt trees. The
tree depth limits the sum of the prefix match and speculation lengths."""

suffix_decoding_max_cached_requests: int = 10000
"""The maximum number of requests to cache in the global suffix tree. If
exceeded, will trigger eviction in FIFO order. If set to 0, the global
suffix tree is disabled and past responses are not cached (prompt trees
are still used)."""

suffix_decoding_max_spec_factor: float = 1.0
"""The maximum spec factor for suffix decoding. The spec factor controls
speculation lengths based on the prefix match length: max_spec_tokens =
max_spec_factor * prefix_match_length."""

suffix_decoding_min_token_prob: float = 0.1
"""The minimum token probability for suffix decoding. Will only speculate
tokens with estimated probability (based on frequency counts) greater than
or equal to this value."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -227,6 +248,8 @@ def __post_init__(self):
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
else:
raise ValueError(
"num_speculative_tokens was provided but without "
@@ -271,6 +294,31 @@ def __post_init__(self):
# draft related config as None here.
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
if not has_arctic_inference():
raise ImportError(
"Arctic Inference is required for suffix decoding. "
"Please install via `pip install arctic-inference`.")
if self.num_speculative_tokens is None:
self.num_speculative_tokens = 32
# Validate values
if self.suffix_decoding_max_tree_depth < 1:
raise ValueError(
f"suffix_decoding_max_tree_depth="
f"{self.suffix_decoding_max_tree_depth} must be >= 1")
if self.suffix_decoding_max_cached_requests < 0:
raise ValueError(
f"suffix_decoding_max_cached_requests="
f"{self.suffix_decoding_max_cached_requests} must be >= 0")
if self.suffix_decoding_max_spec_factor < 0:
raise ValueError(
f"suffix_decoding_max_spec_factor="
f"{self.suffix_decoding_max_spec_factor} must be >= 0")
if (self.suffix_decoding_min_token_prob < 0
or self.suffix_decoding_min_token_prob > 1):
raise ValueError(
f"suffix_decoding_min_token_prob="
f"{self.suffix_decoding_min_token_prob} must be in [0, 1]")
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
@@ -557,6 +605,9 @@ def use_eagle(self) -> bool:

def __repr__(self) -> str:
method = self.method
model = None if method == "ngram" else self.draft_model_config.model
if method in ("ngram", "suffix"):
model = None
else:
model = self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
6 changes: 6 additions & 0 deletions vllm/utils/__init__.py
@@ -3440,6 +3440,12 @@ def has_tilelang() -> bool:
return _has_module("tilelang")


def has_arctic_inference() -> bool:
"""Whether the optional `arctic_inference` package is available."""

return _has_module("arctic_inference")


def set_process_title(name: str,
suffix: str = "",
prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None:
99 changes: 99 additions & 0 deletions vllm/v1/spec_decode/suffix_decoding.py
@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch


class SuffixDecodingProposer:

def __init__(self, vllm_config: VllmConfig):
config = vllm_config.speculative_config
self.num_speculative_tokens = config.num_speculative_tokens
self.max_tree_depth = config.suffix_decoding_max_tree_depth
self.max_spec_factor = config.suffix_decoding_max_spec_factor
self.min_token_prob = config.suffix_decoding_min_token_prob
self.max_model_len = vllm_config.model_config.max_model_len

# Lazy import to avoid error when Suffix Decoding is not used.
from arctic_inference.suffix_decoding import SuffixDecodingCache

self.suffix_cache = SuffixDecodingCache(
max_tree_depth=config.suffix_decoding_max_tree_depth,
max_cached_requests=config.suffix_decoding_max_cached_requests)

def update(
self,
input_batch: InputBatch,
sampled_token_ids: list[list[int]],
):
seen_req_ids = set()
for i, sampled_ids in enumerate(sampled_token_ids):
req_id = input_batch.req_ids[i]
seen_req_ids.add(req_id)

if not sampled_ids:
continue

index = input_batch.req_id_to_index[req_id]
if req_id not in self.suffix_cache.active_requests:
if req_id in self.suffix_cache.cached_requests:
# Reset the suffix cache for this request.
self.suffix_cache.evict_cached_response(req_id)
num_prompt_tokens = input_batch.num_prompt_tokens[index]
prompt_token_ids = (
input_batch.token_ids_cpu[index, :num_prompt_tokens])
prompt_token_ids = prompt_token_ids.tolist()
self.suffix_cache.start_request(req_id, prompt_token_ids)

self.suffix_cache.add_active_response(req_id, sampled_ids)

# Stop requests that are not seen
for req_id in list(self.suffix_cache.active_requests):
if req_id not in seen_req_ids:
self.suffix_cache.stop_request(req_id)

def propose(
self,
input_batch: InputBatch,
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
req_ids = input_batch.req_ids
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
draft_token_ids.append([])
continue

# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = req_ids[i]
if req_id in input_batch.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue

num_tokens = input_batch.num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue

start = max(0, num_tokens - self.max_tree_depth)
pattern = input_batch.token_ids_cpu[i, start:num_tokens]
pattern = pattern.tolist()
draft = self.suffix_cache.speculate(
req_id,
pattern,
max_spec_tokens=min(self.num_speculative_tokens,
self.max_model_len - num_tokens - 1),
max_spec_factor=self.max_spec_factor,
min_token_prob=self.min_token_prob)

draft_token_ids.append(draft.token_ids)

return draft_token_ids

def load_model(self, *args, **kwargs):
# No model to load.
pass
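
For orientation, the sketch below shows the per-step call pattern a runner would follow with this proposer. `FakeInputBatch` is a hypothetical stand-in that exposes only the `InputBatch` fields the proposer reads; it is not part of vLLM, and the real integration lives in the GPU model runner.

```python
# Hypothetical stand-in for the InputBatch fields SuffixDecodingProposer uses.
from dataclasses import dataclass, field

import numpy as np


@dataclass
class FakeInputBatch:
    req_ids: list[str]
    token_ids_cpu: np.ndarray          # shape: [num_reqs, max_model_len]
    num_prompt_tokens: list[int]
    num_tokens_no_spec: list[int]
    spec_decode_unsupported_reqs: set[str] = field(default_factory=set)
    req_id_to_index: dict[str, int] = field(default_factory=dict)

    def __post_init__(self):
        # Index lookup mirroring what the proposer expects.
        self.req_id_to_index = {r: i for i, r in enumerate(self.req_ids)}


# Per engine step (simplified):
#   proposer.update(batch, sampled_token_ids)            # grow the suffix cache
#   drafts = proposer.propose(batch, sampled_token_ids)  # per-request drafts
# where sampled_token_ids is a list[list[int]] of newly sampled tokens per
# request and drafts is a list[list[int]] of proposed draft token ids.
```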