diff --git a/docs/features/spec_decode.md b/docs/features/spec_decode.md
index 25c308a6ff20..885b457ea620 100644
--- a/docs/features/spec_decode.md
+++ b/docs/features/spec_decode.md
@@ -130,6 +130,43 @@ matching n-grams in the prompt. For more information read [this thread.](https:/
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
     ```
 
+## Speculating using Suffix Decoding
+
+The following code configures vLLM to use speculative decoding where proposals are generated using Suffix Decoding ([technical report](https://arxiv.org/abs/2411.04975)).
+
+Like n-gram, Suffix Decoding can generate draft tokens by pattern-matching using the last `n` generated tokens. Unlike n-gram, Suffix Decoding (1) can pattern-match against both the prompt and previous generations, (2) uses frequency counts to propose the most likely continuations, and (3) speculates an adaptive number of tokens for each request at each iteration to get better acceptance rates.
+
+Suffix Decoding can achieve better performance for tasks with high repetition, such as code-editing, agentic loops (e.g. self-reflection, self-consistency), and RL rollouts.
+
+!!! tip "Install Arctic Inference"
+    Suffix Decoding requires [Arctic Inference](https://github.com/snowflakedb/ArcticInference). You can install it with `pip install arctic-inference`.
+
+??? code
+
+    ```python
+    from vllm import LLM, SamplingParams
+
+    prompts = [
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+    llm = LLM(
+        model="facebook/opt-6.7b",
+        tensor_parallel_size=1,
+        speculative_config={
+            "method": "suffix",
+            "num_speculative_tokens": 16,
+        },
+    )
+    outputs = llm.generate(prompts, sampling_params)
+
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    ```
+
 ## Speculating using MLP speculators
 
 The following code configures vLLM to use speculative decoding where proposals are generated by
diff --git a/requirements/test.in b/requirements/test.in
index ef21d6db5b4f..cbc14ed95cf4 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -48,6 +48,7 @@
 buildkite-test-collector==0.1.9
 genai_perf==0.0.8
 tritonclient==2.51.0
+arctic-inference == 0.0.9 # Required for suffix decoding test
 numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
 numba == 0.61.2; python_version > '3.9'
 numpy
diff --git a/requirements/test.txt b/requirements/test.txt
index 9cab85ce0ef6..4f475c184967 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -39,6 +39,8 @@ anyio==4.6.2.post1
     # via
     #   httpx
     #   starlette
+arctic-inference==0.0.9
+    # via -r requirements/test.in
 argcomplete==3.5.1
     # via datamodel-code-generator
 arrow==1.3.0
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index fbbbd0389c26..457447a9b567 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -77,7 +77,18 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
 
 
-def test_ngram_correctness(
+@pytest.mark.parametrize("speculative_config",
+                         [{
+                             "method": "ngram",
+                             "prompt_lookup_max": 5,
+                             "prompt_lookup_min": 3,
+                             "num_speculative_tokens": 3,
+                         }, {
+                             "method": "suffix",
+                             "suffix_decoding_max_spec_factor": 2.0,
+                         }])
+def test_ngram_and_suffix_correctness(
+    speculative_config: dict,
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
     model_name: str,
@@ -96,12 +107,7 @@ def test_ngram_correctness(
 
     spec_llm = LLM(
         model=model_name,
-        speculative_config={
-            "method": "ngram",
-            "prompt_lookup_max": 5,
-            "prompt_lookup_min": 3,
-            "num_speculative_tokens": 3,
-        },
+        speculative_config=speculative_config,
         max_model_len=1024,
     )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
@@ -123,6 +129,66 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
+def test_suffix_decoding_acceptance(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    '''
+    Check that suffix decoding caching takes effect and improves acceptance
+    lengths and acceptance rates over multiple runs of the same prompts.
+    '''
+    test_prompts = get_test_prompts(mm_enabled=False)
+
+    spec_llm = LLM(
+        model=model_name,
+        speculative_config={
+            "method": "suffix",
+            "suffix_decoding_max_spec_factor": 2.0,
+            "suffix_decoding_max_cached_requests": 1000,
+        },
+        max_model_len=1024,
+        disable_log_stats=False,
+    )
+
+    # Run several times and check that the accepted tokens increase.
+    spec_llm.chat(test_prompts, sampling_config)
+    num_draft = []
+    num_accept = []
+    for i in range(10):  # Run multiple times to warm up the cache.
+        spec_llm.chat(test_prompts, sampling_config)
+        # Collect draft and acceptance stats.
+        metrics = spec_llm.get_metrics()
+        for metric in metrics:
+            if metric.name == "vllm:spec_decode_num_draft_tokens":
+                num_draft.append(metric.value)
+            if metric.name == "vllm:spec_decode_num_accepted_tokens":
+                num_accept.append(metric.value)
+
+    # Calculate the acceptance rates for the first and last runs.
+    first_accept_tokens = num_accept[0]
+    first_draft_tokens = num_draft[0]
+    first_accept_rate = first_accept_tokens / first_draft_tokens
+
+    # Take the diff since the stats are cumulative.
+    last_accept_tokens = num_accept[-1] - num_accept[-2]
+    last_draft_tokens = num_draft[-1] - num_draft[-2]
+    last_accept_rate = last_accept_tokens / last_draft_tokens
+
+    # Expect the acceptance length to improve.
+    assert first_accept_tokens < last_accept_tokens
+
+    # Expect the acceptance rate to improve.
+    assert first_accept_rate < last_accept_rate
+
+    # Heuristic: expect at least 85% acceptance rate at the end.
+    assert last_accept_rate > 0.85
+
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+
 @pytest.mark.parametrize(
     ["model_setup", "mm_enabled"],
     [
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index aa0c07cf62a3..bbc6e55c8729 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -13,7 +13,7 @@
 from vllm.config.parallel import ParallelConfig
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.utils import LazyLoader
+from vllm.utils import LazyLoader, has_arctic_inference
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -43,6 +43,7 @@
     "mimo_mtp",
     "longcat_flash_mtp",
     "mtp",
+    "suffix",
 ]
 MTP_MODEL_TYPES = (
     "deepseek_mtp",
@@ -140,6 +141,27 @@ class SpeculativeConfig:
     draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
     """The parallel configuration for the draft model initialized internal."""
 
+    # Suffix decoding configuration
+    suffix_decoding_max_tree_depth: int = 64
+    """The maximum depth of the suffix decoding global and prompt trees. The
+    tree depth limits the sum of the prefix match and speculation lengths."""
+
+    suffix_decoding_max_cached_requests: int = 10000
+    """The maximum number of requests to cache in the global suffix tree. If
+    exceeded, will trigger eviction in FIFO order. If set to 0, the global
+    suffix tree is disabled and past responses are not cached (prompt trees
+    are still used)."""
+
+    suffix_decoding_max_spec_factor: float = 1.0
+    """The maximum spec factor for suffix decoding. The spec factor controls
+    speculation lengths based on the prefix match length: max_spec_tokens =
+    max_spec_factor * prefix_match_length."""
+
+    suffix_decoding_min_token_prob: float = 0.1
+    """The minimum token probability for suffix decoding. Will only speculate
+    tokens with estimated probability (based on frequency counts) greater than
+    or equal to this value."""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -247,6 +269,8 @@ def __post_init__(self):
                 self.quantization = self.target_model_config.quantization
             elif self.method in ("ngram", "[ngram]"):
                 self.model = "ngram"
+            elif self.method == "suffix":
+                self.model = "suffix"
             else:
                 raise ValueError(
                     "num_speculative_tokens was provided but without speculative model."
@@ -294,6 +318,31 @@ def __post_init__(self):
                 # draft related config as None here.
                 self.draft_model_config = self.target_model_config
                 self.draft_parallel_config = self.target_parallel_config
+            elif self.method == "suffix":
+                if not has_arctic_inference():
+                    raise ImportError(
+                        "Arctic Inference is required for suffix decoding. "
+                        "Please install via `pip install arctic-inference`.")
+                if self.num_speculative_tokens is None:
+                    self.num_speculative_tokens = 32
+                # Validate values
+                if self.suffix_decoding_max_tree_depth < 1:
+                    raise ValueError(
+                        f"suffix_decoding_max_tree_depth="
+                        f"{self.suffix_decoding_max_tree_depth} must be >= 1")
+                if self.suffix_decoding_max_cached_requests < 0:
+                    raise ValueError(
+                        f"suffix_decoding_max_cached_requests="
+                        f"{self.suffix_decoding_max_cached_requests} must be >= 0")
+                if self.suffix_decoding_max_spec_factor < 0:
+                    raise ValueError(
+                        f"suffix_decoding_max_spec_factor="
+                        f"{self.suffix_decoding_max_spec_factor} must be >= 0")
+                if (self.suffix_decoding_min_token_prob < 0
+                        or self.suffix_decoding_min_token_prob > 1):
+                    raise ValueError(
+                        f"suffix_decoding_min_token_prob="
+                        f"{self.suffix_decoding_min_token_prob} must be in [0, 1]")
             else:
                 self.prompt_lookup_max = 0
                 self.prompt_lookup_min = 0
@@ -599,6 +648,9 @@ def use_eagle(self) -> bool:
 
     def __repr__(self) -> str:
         method = self.method
-        model = None if method == "ngram" else self.draft_model_config.model
+        if method in ("ngram", "suffix"):
+            model = None
+        else:
+            model = self.draft_model_config.model
         num_spec_tokens = self.num_speculative_tokens
         return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 4a6a79ad067b..85ac29881165 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -3308,9 +3308,15 @@ def has_tilelang() -> bool:
     return _has_module("tilelang")
 
 
-def set_process_title(
-    name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX
-) -> None:
+def has_arctic_inference() -> bool:
+    """Whether the optional `arctic_inference` package is available."""
+
+    return _has_module("arctic_inference")
+
+
+def set_process_title(name: str,
+                      suffix: str = "",
+                      prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None:
     """
     Set the current process title to a specific name with an optional suffix.
diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py
new file mode 100644
index 000000000000..37136e596a56
--- /dev/null
+++ b/vllm/v1/spec_decode/suffix_decoding.py
@@ -0,0 +1,99 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.config import VllmConfig
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+
+class SuffixDecodingProposer:
+
+    def __init__(self, vllm_config: VllmConfig):
+        config = vllm_config.speculative_config
+        self.num_speculative_tokens = config.num_speculative_tokens
+        self.max_tree_depth = config.suffix_decoding_max_tree_depth
+        self.max_spec_factor = config.suffix_decoding_max_spec_factor
+        self.min_token_prob = config.suffix_decoding_min_token_prob
+        self.max_model_len = vllm_config.model_config.max_model_len
+
+        # Lazy import to avoid error when Suffix Decoding is not used.
+        from arctic_inference.suffix_decoding import SuffixDecodingCache
+
+        self.suffix_cache = SuffixDecodingCache(
+            max_tree_depth=config.suffix_decoding_max_tree_depth,
+            max_cached_requests=config.suffix_decoding_max_cached_requests)
+
+    def update(
+        self,
+        input_batch: InputBatch,
+        sampled_token_ids: list[list[int]],
+    ):
+        seen_req_ids = set()
+        for i, sampled_ids in enumerate(sampled_token_ids):
+            req_id = input_batch.req_ids[i]
+            seen_req_ids.add(req_id)
+
+            if not sampled_ids:
+                continue
+
+            index = input_batch.req_id_to_index[req_id]
+            if req_id not in self.suffix_cache.active_requests:
+                if req_id in self.suffix_cache.cached_requests:
+                    # Reset the suffix cache for this request.
+                    self.suffix_cache.evict_cached_response(req_id)
+                num_prompt_tokens = input_batch.num_prompt_tokens[index]
+                prompt_token_ids = (
+                    input_batch.token_ids_cpu[index, :num_prompt_tokens])
+                prompt_token_ids = prompt_token_ids.tolist()
+                self.suffix_cache.start_request(req_id, prompt_token_ids)
+
+            self.suffix_cache.add_active_response(req_id, sampled_ids)
+
+        # Stop requests that are not seen
+        for req_id in list(self.suffix_cache.active_requests):
+            if req_id not in seen_req_ids:
+                self.suffix_cache.stop_request(req_id)
+
+    def propose(
+        self,
+        input_batch: InputBatch,
+        sampled_token_ids: list[list[int]],
+    ) -> list[list[int]]:
+        req_ids = input_batch.req_ids
+        draft_token_ids: list[list[int]] = []
+        for i, sampled_ids in enumerate(sampled_token_ids):
+            num_sampled_ids = len(sampled_ids)
+            if not num_sampled_ids:
+                # Skip speculative decoding.
+                draft_token_ids.append([])
+                continue
+
+            # Skip requests that require sampling parameters that are not
+            # supported with speculative decoding.
+            req_id = req_ids[i]
+            if req_id in input_batch.spec_decode_unsupported_reqs:
+                draft_token_ids.append([])
+                continue
+
+            num_tokens = input_batch.num_tokens_no_spec[i]
+            if num_tokens >= self.max_model_len:
+                # Skip requests that have already reached the max model length.
+                draft_token_ids.append([])
+                continue
+
+            start = max(0, num_tokens - self.max_tree_depth)
+            pattern = input_batch.token_ids_cpu[i, start:num_tokens]
+            pattern = pattern.tolist()
+            draft = self.suffix_cache.speculate(
+                req_id,
+                pattern,
+                max_spec_tokens=min(self.num_speculative_tokens,
+                                    self.max_model_len - num_tokens - 1),
+                max_spec_factor=self.max_spec_factor,
+                min_token_prob=self.min_token_prob)
+
+            draft_token_ids.append(draft.token_ids)
+
+        return draft_token_ids
+
+    def load_model(self, *args, **kwargs):
+        # No model to load.
+        pass
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index bd799c06c0eb..a5214f338a49 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -129,6 +129,7 @@
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
 from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
@@ -320,6 +321,9 @@ def __init__(
         if self.speculative_config and get_pp_group().is_last_rank:
             if self.speculative_config.method == "ngram":
                 self.drafter = NgramProposer(self.vllm_config)
+            elif self.speculative_config.method == "suffix":
+                self.drafter = SuffixDecodingProposer(
+                    self.vllm_config)  # type: ignore
             elif self.speculative_config.use_eagle():
                 self.drafter = EagleProposer(self.vllm_config, self.device, self)  # type: ignore
                 if self.speculative_config.method == "eagle3":
@@ -2346,6 +2350,10 @@ def _bookkeeping_sync(
             req_state = self.requests[req_id]
             req_state.output_token_ids.extend(sampled_ids)
 
+        if (self.speculative_config
+                and isinstance(self.drafter, SuffixDecodingProposer)):
+            self.drafter.update(self.input_batch, valid_sampled_token_ids)
+
         return (
             num_nans_in_logits,
             logprobs_lists,
@@ -2697,6 +2705,11 @@ def propose_draft_token_ids(
                 self.input_batch.token_ids_cpu,
                 self.input_batch.spec_decode_unsupported_reqs,
             )
+        elif self.speculative_config.method == "suffix":
+            assert isinstance(sampled_token_ids, list)
+            assert isinstance(self.drafter, SuffixDecodingProposer)
+            draft_token_ids = self.drafter.propose(self.input_batch,
+                                                   sampled_token_ids)
         elif self.speculative_config.method == "medusa":
             assert isinstance(sampled_token_ids, list)
             assert isinstance(self.drafter, MedusaProposer)
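
As a sketch of how the pieces above fit together: `SuffixDecodingProposer.propose()` caps each request's draft length at `min(num_speculative_tokens, max_model_len - num_tokens - 1)`, while the `SpeculativeConfig` docstrings describe an adaptive cap of `max_spec_factor * prefix_match_length` applied inside Arctic Inference. The snippet below is illustrative only: the helper name `speculation_budget`, the integer truncation, and the assumption that the two caps combine as a simple `min()` are not taken from this diff, and `prefix_match_length` is computed internally by `SuffixDecodingCache.speculate()` rather than exposed here.

```python
def speculation_budget(num_speculative_tokens: int,
                       max_model_len: int,
                       num_tokens: int,
                       max_spec_factor: float,
                       prefix_match_length: int) -> int:
    # Hard cap passed by SuffixDecodingProposer.propose() in this diff.
    hard_cap = min(num_speculative_tokens, max_model_len - num_tokens - 1)
    # Adaptive cap described in the SpeculativeConfig docstring:
    #   max_spec_tokens = max_spec_factor * prefix_match_length
    # (assumed here to truncate to an integer).
    adaptive_cap = int(max_spec_factor * prefix_match_length)
    return max(0, min(hard_cap, adaptive_cap))


# With the defaults added in this diff (num_speculative_tokens=32,
# max_spec_factor=1.0), a 10-token suffix match yields at most 10 drafts.
assert speculation_budget(32, 2048, 100, 1.0, 10) == 10
# Near the context limit, the max_model_len bound dominates instead.
assert speculation_budget(32, 2048, 2044, 2.0, 10) == 3
```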