6 changes: 6 additions & 0 deletions README.md
@@ -16,6 +16,12 @@ Easy, fast, and cheap LLM serving for everyone

---

## What is the purpose of this fork?

This is a fork of vLLM that we are using to develop support for *span semantics* (block-attention over KV-cache spans).

---

*Latest News* 🔥

- [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA).
192 changes: 192 additions & 0 deletions examples/offline_inference/spans/spans.py
@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time

# to ensure deterministic behaviour
os.environ["TOKENIZERS_PARALLELISM"] = "False"

# standard imports
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt


# helper functions
def pad(toklist, padtok):
    # pad to a multiple of the 16-token KV-cache block size by inserting
    # pad tokens just before the final token
    return toklist[:-1] + [padtok] * ((16 - len(toklist)) % 16) + toklist[-1:]


def avg(list_of_numbers):
    return sum(list_of_numbers) / max(len(list_of_numbers), 1)


def wrap(prompt):
    if isinstance(prompt[0], list):
        return [TokensPrompt(prompt_token_ids=p) for p in prompt]
    return TokensPrompt(prompt_token_ids=prompt)


def initialize_vllm(
    model, temp=0.6, logprobs=None, max_toks=32768, max_generated_toks=1
):
    # boot up vLLM
    samp_params_preload = SamplingParams(temperature=temp, max_tokens=1)
    samp_params_generate = SamplingParams(
        temperature=temp, max_tokens=max_generated_toks, logprobs=logprobs
    )
    llm = LLM(
        model=model,
        gpu_memory_utilization=0.9,
        enforce_eager=True,  # <- so it boots faster
        block_size=16,
    )
    tok = llm.get_tokenizer()
    tok_fun = lambda x: tok.convert_tokens_to_ids(tok.tokenize(x))
    return samp_params_preload, samp_params_generate, tok_fun, llm


def main():
    model_names = [
        "ldsjmdy/Tulu3-Block-FT",  # <- finetuned to handle block-attention
        "ldsjmdy/Tulu3-RAG",  # <- baseline
    ]
    model_name = model_names[0]

    # tokens that need to be set to perform block-attention
    PAD_TOK = 27  # <- "<"
    SPAN_TOK_PLUS = 10  # <- "+"
    SPAN_TOK_CROSS = 31  # <- "@"

    # vLLM-specific env vars

    # enables block attention
    # -> when this line is not commented, we expect a speedup
    #    in the execution of the last two .generate calls
    os.environ["VLLM_V1_SPANS_ENABLED"] = "True"

    # the token that tells vLLM "this is the beginning of a span"
    os.environ["VLLM_V1_SPANS_TOKEN_PLUS"] = str(SPAN_TOK_PLUS)

    # the token that tells vLLM:
    # "from here on, recompute KV vectors if any previous tokens differ"
    os.environ["VLLM_V1_SPANS_TOKEN_CROSS"] = str(SPAN_TOK_CROSS)

    # will print every step of the span process if set to true
    os.environ["VLLM_V1_SPANS_DEBUG"] = "True"

    # will disable the adjustment of positional encodings when a KV cache
    # block is loaded to a different position than it was stored
    # -> when this line is not commented,
    #    spans overlap in their positional encodings
    os.environ["VLLM_V1_SPANS_DISABLE_REPOSITION"] = "True"

    # general env vars

    # now we instantiate the model
    samp_params_preload, samp_params_generate, tok, llm = initialize_vllm(
        model_name, max_generated_toks=128, max_toks=10_000, temp=0.0
    )

    # components of the prompt template
    prefix = pad(
        tok(
            "<|system|>\nYou are an intelligent AI assistant. "
            "Please answer questions based on the user's instructions. "
            "Below are some reference documents that may help you in "
            "answering the user's question."
        ),
        PAD_TOK,
    )
    midfx = [SPAN_TOK_CROSS] + tok(
        "<|user|>\nPlease write a high-quality answer for the "
        "given question using only the provided search documents "
        "(some of which might be irrelevant).\nQuestion: "
    )
    postfx = tok("""\n<|assistant|>\n""")

    print("---->", postfx)

    # task-specific documents
    doc_a = pad(
        [SPAN_TOK_PLUS]
        + tok(
            "[0] The Template-Assisted "
            "Selective Epitaxy (TASE) method, developed at "
            "IBM Research Europe – Zurich, permits to "
            "create a homogeneous integration route for "
            "various semiconductor materials which is "
            "compatible with the CMOS process."
        ),
        PAD_TOK,
    )

    doc_b = pad(
        [SPAN_TOK_PLUS]
        + tok(
            "[1] The dominant sequence transduction "
            "models are based on complex recurrent or "
            "convolutional neural networks in an encoder-decoder "
            "configuration. "
        ),
        PAD_TOK,
    )

    # alt-docs (purely to check performance on longer documents)
    """
    a_toks = tok("Sequence Transduction Models")
    b_toks = tok("Template-Assisted Selective Epitaxy")
    doc_a = pad(
        [SPAN_TOK_PLUS]
        + [a_toks[idx % len(a_toks)] for idx in range(10_000)],
        PAD_TOK,
    )
    doc_b = pad(
        [SPAN_TOK_PLUS]
        + [b_toks[idx % len(b_toks)] for idx in range(10_000)],
        PAD_TOK,
    )
    """

    # user query
    query = (
        midfx
        + tok(
            "Tell me which one concerns deep learning. "
            "Indicate your answer with a number in brackets."
        )
        + postfx
    )

    # preload documents
    ts_pre = time.time()
    llm.generate(
        [wrap(doc_a), wrap(doc_b), wrap(prefix)], sampling_params=samp_params_preload
    )
    te_pre = time.time() - ts_pre

    ts_gen = time.time()

    # this now will load prefix, doc_a, doc_b
    # from the KV cache regardless of the order
    model_response_1 = llm.generate(
        wrap(prefix + doc_a + doc_b + query),
        sampling_params=samp_params_generate,
        use_tqdm=False,
    )

    # this should also run faster:
    model_response_2 = llm.generate(
        wrap(prefix + doc_b + doc_a + query),
        sampling_params=samp_params_generate,
        use_tqdm=False,
    )

    te_gen = time.time() - ts_gen

    print(f"doc preload time / TTFT : {te_pre:.4f} / {te_gen:.4f} (s)")
    print("model output 1 was:", model_response_1[0].outputs[0].text)
    print("model output 2 was:", model_response_2[0].outputs[0].text)


if __name__ == "__main__":
    main()
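
The `pad()` helper above aligns every span to the 16-token KV-cache block size by inserting pad tokens just before the final token. A minimal worked example of that behaviour, with made-up token ids (27 is the "<" pad token hard-coded in the script):

```python
def pad(toklist, padtok):
    return toklist[:-1] + [padtok] * ((16 - len(toklist)) % 16) + toklist[-1:]

toks = [101, 7, 8, 9, 102]           # hypothetical token ids
padded = pad(toks, 27)               # 27 is the "<" pad token used in the script
assert len(padded) == 16             # 5 tokens are padded up to one 16-token block
assert padded[:4] == [101, 7, 8, 9]  # leading tokens are kept in order
assert padded[-1] == 102             # the final token stays last, after the padding
```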
31 changes: 31 additions & 0 deletions vllm/envs.py
@@ -172,6 +172,12 @@
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
# spans vars
VLLM_V1_SPANS_ENABLED: bool = False
VLLM_V1_SPANS_DEBUG: bool = False
VLLM_V1_SPANS_TOKEN_PLUS: int = -1
VLLM_V1_SPANS_TOKEN_CROSS: int = -1
VLLM_V1_SPANS_DISABLE_REPOSITION: bool = False


def get_default_cache_root():
@@ -1221,6 +1227,31 @@ def get_vllm_port() -> Optional[int]:
    # raw bytes. Defaults to True for backward compatibility.
    "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
    lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))),

    # whether to enable block-attention (span detection, fan-in, repositioning)
    "VLLM_V1_SPANS_ENABLED":
    lambda: os.environ.get("VLLM_V1_SPANS_ENABLED", "False") == "True",

    # whether to print details of the block-attention implementation
    "VLLM_V1_SPANS_DEBUG":
    lambda: os.environ.get("VLLM_V1_SPANS_DEBUG", "False") == "True",

    # for block-attention, the token used to mark the beginning of a span
    # (required for block-attention to work)
    "VLLM_V1_SPANS_TOKEN_PLUS":
    lambda: int(os.environ.get("VLLM_V1_SPANS_TOKEN_PLUS", "-1")),

    # for block-attention, a token that signals the beginning of a span
    # which needs to depend on all previous tokens
    "VLLM_V1_SPANS_TOKEN_CROSS":
    lambda: int(os.environ.get("VLLM_V1_SPANS_TOKEN_CROSS", "-1")),

    # for block-attention, detected spans will be loaded but not repositioned
    "VLLM_V1_SPANS_DISABLE_REPOSITION":
    lambda: os.environ.get("VLLM_V1_SPANS_DISABLE_REPOSITION", "False") == "True",
}

# --8<-- [end:env-vars-definition]
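
A minimal sketch of how these flags could be toggled from a driver script, mirroring `examples/offline_inference/spans/spans.py`. It assumes the usual `vllm.envs` behaviour of evaluating the lambdas above lazily on attribute access, so the variables are set before the first lookup; the token ids are the ones hard-coded in the example script.

```python
import os

# set before vllm.envs is consulted (the lookups above read os.environ)
os.environ["VLLM_V1_SPANS_ENABLED"] = "True"
os.environ["VLLM_V1_SPANS_TOKEN_PLUS"] = "10"   # "+" in the example script
os.environ["VLLM_V1_SPANS_TOKEN_CROSS"] = "31"  # "@" in the example script

import vllm.envs as envs

assert envs.VLLM_V1_SPANS_ENABLED is True
assert envs.VLLM_V1_SPANS_TOKEN_PLUS == 10
assert envs.VLLM_V1_SPANS_DISABLE_REPOSITION is False  # default
```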
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/rotary_embedding/base.py
@@ -63,6 +63,7 @@ def forward_native(
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
        invert_rotation_angle: bool = False,  # <- to un-RoPE cached KVs
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
@@ -71,6 +72,8 @@ def forward_native(
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)
        if invert_rotation_angle:
            # negating sin rotates by -theta, i.e. applies the inverse rotation
            sin = -sin

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
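
For intuition on `invert_rotation_angle`: applying the rotation with `sin` negated rotates by `-theta`, the inverse of the original rotation, so cached K vectors can be "un-roped" before being re-roped at a new position. A self-contained sketch of that identity (plain interleaved 2-D pairs, not vLLM's exact rotate-half layout):

```python
import torch


def rotate_pairs(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # apply a 2-D rotation to interleaved (even, odd) pairs of x
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


theta = torch.rand(4)                     # one angle per pair
cos, sin = torch.cos(theta), torch.sin(theta)
x = torch.randn(8)                        # four interleaved pairs

roped = rotate_pairs(x, cos, sin)         # rotate by +theta ("rope")
unroped = rotate_pairs(roped, cos, -sin)  # negated sin -> rotate by -theta
assert torch.allclose(unroped, x, atol=1e-5)
```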
1 change: 1 addition & 0 deletions vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -230,6 +230,7 @@ def forward_native(
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
        invert_rotation_angle: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

53 changes: 52 additions & 1 deletion vllm/v1/core/block_pool.py
@@ -4,6 +4,7 @@
from collections.abc import Iterable
from typing import Optional

import vllm.envs as envs
from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
                                        BlockRemoved, BlockStored,
                                        KVCacheEvent)
@@ -145,6 +146,8 @@ def cache_full_blocks(
            if new_hashes is not None:
                new_hashes.append(maybe_convert_block_hash(block_hash))

        self._set_block_positions(new_full_blocks, blocks, request)

        if self.enable_kv_cache_events:
            if num_cached_blocks == 0:
                parent_block_hash: Optional[ExternalBlockHash] = None
@@ -167,6 +170,47 @@
                        medium=MEDIUM_GPU,
                    ))

    def _set_block_positions(self, new_full_blocks: list[KVCacheBlock],
                             blocks: list[KVCacheBlock], request: Request):
        """Sets the positions of new full blocks in the KV cache.

        This function assigns positions to newly filled blocks based
        on their order within the provided block list. The position
        corresponds to the location embedded in K vectors (if using RoPE)
        in the KV cache and is critical for maintaining correct alignment,
        especially when prompt positions differ between requests.

        Args:
            new_full_blocks: List of KVCacheBlock objects that have been newly
                filled and require position assignment.
            blocks: List of all blocks associated with the current request,
                used to determine the order in which positions are assigned.
            request: The Request object containing token information for
                debugging purposes.

        Note:
            When VLLM_V1_SPANS_DEBUG is enabled, this function includes
            debug logging that prints each block's tokens, to help
            debug span-related workflows.
        """
        pos = 0
        for blk in blocks:
            if blk in new_full_blocks:
                blk.position = pos
                if envs.VLLM_V1_SPANS_DEBUG:
                    # this prints the tokens assigned to a new block
                    # in the KV cache
                    blk_tks = request.all_token_ids[pos:pos + 16]
                    assert blk.block_hash is not None
                    bhash = (str(abs(blk.block_hash.block_hash.hash_value))[:4]
                             if blk.block_hash.block_hash else None)
                    print('[SPANS -> block_pool] assigning to pos', pos,
                          'with hash', bhash, 'block: ', blk_tks)
            pos += 16
        if envs.VLLM_V1_SPANS_DEBUG:
            print('[SPANS -> block_pool] assigned block count now ->',
                  len([b for b in self.blocks if b._block_hash]))

    def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
        """Get new blocks from the free block pool.

@@ -261,8 +305,15 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
        blocks_list = list(ordered_blocks)
        for block in blocks_list:
            block.ref_cnt -= 1
        # remove duplicates (blocks can now appear twice)
        block_ids = set()
        blocks_list_filtered = []
        for block in blocks_list:
            if block.block_id not in block_ids:
                blocks_list_filtered.append(block)
                block_ids.add(block.block_id)
        self.free_block_queue.append_n([
            block for block in blocks_list_filtered
            if block.ref_cnt == 0 and not block.is_null
        ])
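
With spans enabled, the same physical block can appear more than once in a request's block list, so `free_blocks` now keeps only the first occurrence before returning blocks to the free queue, presumably to avoid appending the same block twice. A toy illustration of the first-occurrence filter, with hypothetical block ids:

```python
blocks_list = ["blk0", "blk1", "blk0"]  # hypothetical ids; "blk0" appears twice

seen: set[str] = set()
filtered: list[str] = []
for block_id in blocks_list:
    if block_id not in seen:
        filtered.append(block_id)
        seen.add(block_id)

assert filtered == ["blk0", "blk1"]  # first occurrence kept, order preserved
```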
