Commit ec55021

refactor

Signed-off-by: Chen Zhang <[email protected]>

1 parent: 4f81b65

6 files changed: +350 -164 lines changed


tests/v1/e2e/test_correctness_sliding_window.py
Lines changed: 2 additions & 2 deletions

@@ -17,15 +17,15 @@ class TestConfig:
 
 model_config = {
     "bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)),
-    "google/gemma-2-2b-it": TestConfig(4096, (400, 800)),
+    "google/gemma-3-1b-it": TestConfig(4096, (400, 800)),
 }
 
 
 @pytest.mark.parametrize(
     "model",
     [
         "bigcode/starcoder2-3b",  # sliding window only
-        "google/gemma-2-2b-it",  # sliding window + full attention
+        "google/gemma-3-1b-it",  # sliding window + full attention
     ])
 @pytest.mark.parametrize("batch_size", [5])
 @pytest.mark.parametrize("seed", [1])

vllm/v1/core/block_pool.py
Lines changed: 53 additions & 42 deletions

@@ -7,7 +7,7 @@
                                         BlockStored, KVCacheEvent)
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
-                                         GroupedKVCacheBlock, KVCacheBlock,
+                                         KVCacheBlock, KVCacheBlockBundle,
                                          generate_block_hash_extra_keys,
                                          hash_block_tokens)
 from vllm.v1.request import Request
@@ -49,19 +49,20 @@ def __init__(
         # enabled).
         self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
 
-        # TODO: update comment
-        # {manager_id: {block_hash: {block ID: GroupedKVCacheBlock}}}. A cached
-        # block is a full block with a block hash that can be used for prefix
-        # caching.
+        # {manager_id: {block_hash: {block ID: KVCacheBlockBundle}}}.
+        # A cached block is a full block with a block hash that can be used for
+        # prefix caching.
         # The cached block may be used by running requests or in the
         # free_block_queue that could potentially be evicted.
+        # Use KVCacheBlockBundle to make sure different kv cache groups managed
+        # by the same single_type_manager are cached & evicted together.
         # NOTE: We currently don't de-duplicate the blocks in the cache,
         # meaning that if a block becomes full and is cached, we don't check
         # if there is already an identical block in the cache. This is because
         # we want to make sure the allocated block IDs won't change so that
         # block tables are append-only.
         self.cached_block_hash_to_block: list[dict[BlockHashType, dict[
-            int, GroupedKVCacheBlock]]] = [
+            int, KVCacheBlockBundle]]] = [
                 defaultdict(dict) for _ in range(num_single_type_managers)
             ]
         # To represent a placeholder block with block_id=0.
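
The KVCacheBlockBundle type used throughout this file is defined in vllm/v1/core/kv_cache_utils.py, which is not part of the hunks shown here. Below is a minimal, hypothetical sketch of the interface block_pool.py relies on (from_kv_cache_blocks, master_block_id, block_hash_is_none, init_block_hash, reset_hash, incr_ref/decr_ref); the field layout and method bodies are assumptions for illustration only, not the actual vllm implementation.

# Hypothetical sketch of the KVCacheBlockBundle interface used by block_pool.py.
# Field names and method bodies are assumptions; the real class lives in
# vllm/v1/core/kv_cache_utils.py and may differ.
from dataclasses import dataclass
from typing import Optional


@dataclass
class KVCacheBlock:
    """Simplified stand-in for one physical KV cache block."""
    block_id: int
    block_hash: Optional[int] = None
    manager_id: Optional[int] = None


@dataclass
class KVCacheBlockBundle:
    """Blocks from different kv cache groups that share one single_type_manager.

    They are allocated, cached, and evicted together, so the block hash and the
    reference count are tracked once per bundle instead of once per block.
    """
    blocks: tuple[KVCacheBlock, ...]
    block_hash: Optional[int] = None
    manager_id: Optional[int] = None
    ref_cnt: int = 0

    @classmethod
    def from_kv_cache_blocks(
            cls, blocks: tuple[KVCacheBlock, ...]) -> "KVCacheBlockBundle":
        return cls(blocks=blocks)

    @property
    def master_block_id(self) -> int:
        # The first block's ID serves as the key in cached_block_hash_to_block.
        return self.blocks[0].block_id

    def block_hash_is_none(self) -> bool:
        return (self.block_hash is None
                and all(b.block_hash is None for b in self.blocks))

    def init_block_hash(self, block_hash: int, manager_id: int) -> None:
        # Mirrors the per-block loop removed from cache_full_blocks below.
        self.block_hash = block_hash
        self.manager_id = manager_id
        for b in self.blocks:
            b.block_hash = block_hash
            b.manager_id = manager_id

    def reset_hash(self) -> None:
        self.block_hash = None
        for b in self.blocks:
            b.block_hash = None

    def incr_ref(self) -> None:
        self.ref_cnt += 1

    def decr_ref(self) -> None:
        self.ref_cnt -= 1

With the hash and the reference count carried by the bundle, the per-block loops in cache_full_blocks, touch, and free_blocks collapse into single calls, which is what the remaining hunks do.
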
@@ -74,7 +75,7 @@ def __init__(
         self.kv_event_queue: list[KVCacheEvent] = []
 
     def get_cached_block(self, block_hash: BlockHashType,
-                         manager_id: int) -> Optional[GroupedKVCacheBlock]:
+                         manager_id: int) -> Optional[KVCacheBlockBundle]:
         """Get a cached block by the block hash, or None if cache miss.
         If there are duplicated blocks, we return the first block in the cache.
 
@@ -95,7 +96,7 @@ def get_cached_block(self, block_hash: BlockHashType,
     def cache_full_blocks(
         self,
         request: Request,
-        blocks: list[GroupedKVCacheBlock],
+        blocks: list[KVCacheBlockBundle],
         block_hashes: list[BlockHashType],
         num_cached_blocks: int,
         num_full_blocks: int,
@@ -141,15 +142,14 @@ def cache_full_blocks(
         new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
                                            else None)
         for i, blk in enumerate(new_full_blocks):
-            assert all(b.block_hash is None for b in blk.blocks)
-            assert blk.block_hash is None
+            assert blk.block_hash_is_none()
 
             if i < len(new_block_hashes):
                 # The block hash may already be computed in
                 # "get_computed_blocks" if the tokens are not generated by
                 # this request (either the prompt tokens or the previously
-                # generated tokens with preemption).
-                # TODO: or other groups with the same block_size
+                # generated tokens with preemption), or by other
+                # single_type_managers with the same block_size.
                 # In this case we simply reuse the block hash.
                 block_hash = new_block_hashes[i]
             else:
@@ -177,10 +177,7 @@ def cache_full_blocks(
             block_hashes.append(block_hash)
 
             # Update and added the full block to the cache.
-            for b in blk.blocks:
-                b.block_hash = block_hash
-                b.manager_id = manager_id
-            blk.block_hash = block_hash
+            blk.init_block_hash(block_hash, manager_id)
             self.cached_block_hash_to_block[manager_id][block_hash][
                 blk.master_block_id] = blk
             if new_hashes is not None:
@@ -200,37 +197,46 @@ def cache_full_blocks(
                     if request.lora_request else None,
                 ))
 
-    def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
+    def get_new_blocks(self, num_block_bundle: int,
+                       bundle_size: int) -> list[KVCacheBlockBundle]:
         """Get new blocks from the free block pool.
 
         Note that we do not check block cache in this function.
 
         Args:
-            num_blocks: The number of blocks to allocate.
+            num_block_bundle: The number of KVCacheBlockBundle to allocate.
+            bundle_size: The number of blocks in each KVCacheBlockBundle.
 
         Returns:
             A list of new block.
         """
-        if num_blocks > self.get_num_free_blocks():
+        num_total_blocks = num_block_bundle * bundle_size
+        if num_total_blocks > self.get_num_free_blocks():
             raise ValueError(
-                f"Cannot get {num_blocks} free blocks from the pool")
+                f"Cannot get {num_total_blocks} free blocks from the pool")
 
-        ret: list[KVCacheBlock] = []
+        flat_new_blocks: list[KVCacheBlock] = []
         idx = 0
-        while idx < num_blocks:
+        while idx < num_total_blocks:
             # First allocate blocks.
             curr_block = self.free_block_queue.popleft()
-            assert curr_block.ref_cnt == 0
 
             # If the block is cached, evict it.
            if self.enable_caching:
                 self._maybe_evict_cached_block(curr_block)
 
-            curr_block.incr_ref()
-            ret.append(curr_block)
+            assert curr_block.block_hash is None
+            flat_new_blocks.append(curr_block)
             idx += 1
 
-        return ret
+        new_blocks = []
+        for i in range(num_block_bundle):
+            blocks = flat_new_blocks[i * bundle_size:(i + 1) * bundle_size]
+            block_bundle = KVCacheBlockBundle.from_kv_cache_blocks(
+                tuple(blocks))
+            block_bundle.incr_ref()
+            new_blocks.append(block_bundle)
+        return new_blocks
 
     def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
         """
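
The allocation change in get_new_blocks boils down to popping num_block_bundle * bundle_size blocks from the free queue and then slicing the flat list with a fixed stride. Here is a self-contained sketch of just that chunking step, using plain integers as stand-in block IDs and a dict as a stand-in bundle; the names are illustrative, not vllm's actual types.

# Self-contained sketch of the chunking arithmetic in the new get_new_blocks.
from collections import deque


def get_new_block_bundles(free_queue: deque, num_block_bundle: int,
                          bundle_size: int) -> list[dict]:
    num_total_blocks = num_block_bundle * bundle_size
    if num_total_blocks > len(free_queue):
        raise ValueError(
            f"Cannot get {num_total_blocks} free blocks from the pool")

    # Pop the flat list of blocks first: num_block_bundle * bundle_size of them.
    flat_new_blocks = [free_queue.popleft() for _ in range(num_total_blocks)]

    # Then group them with a fixed stride: bundle i holds the slice
    # [i * bundle_size, (i + 1) * bundle_size) and carries one reference
    # count for the whole bundle.
    new_blocks = []
    for i in range(num_block_bundle):
        blocks = tuple(flat_new_blocks[i * bundle_size:(i + 1) * bundle_size])
        new_blocks.append({"blocks": blocks, "ref_cnt": 1})
    return new_blocks


free_queue = deque(range(10))          # block IDs 0..9 are free
bundles = get_new_block_bundles(free_queue, num_block_bundle=2, bundle_size=3)
assert [b["blocks"] for b in bundles] == [(0, 1, 2), (3, 4, 5)]
assert len(free_queue) == 4            # 6 of the 10 blocks were consumed
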
@@ -249,8 +255,11 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
                 manager_id]:
             cached_blocks = (
                 self.cached_block_hash_to_block[manager_id][block_hash])
-            assert block.block_id in cached_blocks
-            cached_blocks[block.block_id].reset_hash()
+            cached_block = cached_blocks[block.block_id]
+            # TODO: add notes
+            assert cached_block.master_block_id == block.block_id
+            assert cached_block.ref_cnt == 0
+            cached_block.reset_hash()
             del cached_blocks[block.block_id]
             if len(cached_blocks) == 0:
                 del self.cached_block_hash_to_block[manager_id][block_hash]
@@ -260,26 +269,26 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
             return True
         return False
 
-    def touch(self, blocks: list[list[GroupedKVCacheBlock]]) -> None:
+    def touch(self, blocks: list[list[KVCacheBlockBundle]]) -> None:
         """Touch a block increases its reference count by 1, and may remove
         the block from the free queue. This is used when a block is hit by
         another request with the same prefix.
 
         Args:
             blocks: A list of blocks to touch.
         """
-        # TODO: check whether we should manage ref_cnt at grouped_block level
         for blocks_one_manager in blocks:
-            for grouped_block in blocks_one_manager:
-                for block in grouped_block.blocks:
-                    # ref_cnt=0 means this block is in the free list (i.e.
-                    # eviction candidate), so remove it.
-                    if block.ref_cnt == 0 and block != self.null_block:
-                        self.free_block_queue.remove(block)
-                    block.incr_ref()
+            for block_bundle in blocks_one_manager:
+                if block_bundle.ref_cnt == 0:
+                    # ref_cnt=0 means the blocks are in the free list (i.e.
+                    # eviction candidate), so remove them.
+                    for block in block_bundle.blocks:
+                        if block != self.null_block:
+                            self.free_block_queue.remove(block)
+                block_bundle.incr_ref()
 
     def free_blocks(self,
-                    ordered_blocks: Iterable[GroupedKVCacheBlock]) -> None:
+                    ordered_blocks: Iterable[KVCacheBlockBundle]) -> None:
         """Free a list of blocks. The blocks should be ordered by their
         eviction priority, where the first block will be evicted first.
 
@@ -288,11 +297,13 @@ def free_blocks(self,
                priority.
         """
         # TODO: make sure blocks in the first group are evicted first
-        for blk in ordered_blocks:
-            for block in blk.blocks:
-                block.decr_ref()
+        for block_bundle in ordered_blocks:
+            block_bundle.decr_ref()
+            if block_bundle.ref_cnt > 0:
+                continue
+            for block in block_bundle.blocks:
                 # null_block should not be added to the free list.
-                if block.ref_cnt == 0 and block != self.null_block:
+                if block != self.null_block:
                     self.free_block_queue.append(block)
 
     def reset_prefix_cache(self) -> bool:
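
touch and free_blocks now manage the reference count at the bundle level: blocks re-enter the free queue only once the whole bundle's ref_cnt reaches zero. Below is a self-contained sketch of that behavior with simplified stand-in types; the class name, the dict-free Bundle, and the deque-based free queue are assumptions for illustration, not vllm's actual implementation.

# Self-contained sketch of bundle-level reference counting as done by
# touch() and free_blocks() in this commit.
from collections import deque
from dataclasses import dataclass


@dataclass
class Bundle:
    blocks: tuple[int, ...]   # block IDs, stand-ins for KVCacheBlock
    ref_cnt: int = 0


def touch(free_queue: deque, bundles: list) -> None:
    for bundle in bundles:
        if bundle.ref_cnt == 0:
            # ref_cnt=0 means the blocks are eviction candidates; pull them
            # out of the free queue before handing them to another request.
            for block in bundle.blocks:
                free_queue.remove(block)
        bundle.ref_cnt += 1


def free_blocks(free_queue: deque, ordered_bundles: list) -> None:
    for bundle in ordered_bundles:
        bundle.ref_cnt -= 1
        if bundle.ref_cnt > 0:
            continue   # still referenced by another request; keep it out
        for block in bundle.blocks:
            free_queue.append(block)


free_queue = deque()
bundle = Bundle(blocks=(7, 8), ref_cnt=1)   # allocated by request A
touch(free_queue, [bundle])                 # request B hits the same prefix
free_blocks(free_queue, [bundle])           # A finishes: bundle still in use
assert list(free_queue) == []
free_blocks(free_queue, [bundle])           # B finishes: bundle is evictable
assert list(free_queue) == [7, 8]
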
