Skip to content

Commit 593b48c

Browse files
orozerychoprahetarth
authored andcommitted
[KV offload][1/N] Introduce an offloading component (vllm-project#19848)
Signed-off-by: Or Ozeri <[email protected]>
1 parent c3998d4 commit 593b48c

File tree

5 files changed

+499
-0
lines changed

5 files changed

+499
-0
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ steps:
280280
# split the test to avoid interference
281281
- pytest -v -s v1/core
282282
- pytest -v -s v1/executor
283+
- pytest -v -s v1/offloading
283284
- pytest -v -s v1/sample
284285
- pytest -v -s v1/logits_processors
285286
- pytest -v -s v1/worker

tests/v1/offloading/test_worker.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from vllm.v1.offloading.abstract import LoadStoreSpec
4+
from vllm.v1.offloading.worker.worker import (OffloadingHandler,
5+
OffloadingWorker, TransferResult,
6+
TransferSpec)
7+
8+
9+
class LoadStoreSpec1(LoadStoreSpec):
    """
    Test spec for medium "1".

    Its flags are read by the test handlers in this file to decide how a
    transfer involving this spec behaves.
    """

    def __init__(self,
                 submit_success: bool = True,
                 async_success: bool = True,
                 exception: bool = False):
        # Outcome knobs consumed by the handlers:
        # - exception: the 1->2 handler raises on submission
        # - submit_success: value the 1->2 handler returns on submission
        # - async_success: result reported once the transfer is finished
        self.exception = exception
        self.submit_success = submit_success
        self.async_success = async_success
        # Set to True by the test to mark the transfer as completed.
        self.finished = False

    @staticmethod
    def medium() -> str:
        return "1"

    def __repr__(self):
        return f"{self.medium()}: {id(self)}"
26+
27+
28+
class LoadStoreSpec2(LoadStoreSpec):
    """
    Test spec for medium "2". Carries no state of its own.
    """

    @staticmethod
    def medium() -> str:
        return "2"

    def __repr__(self):
        return f"{self.medium()}: {id(self)}"
36+
37+
38+
class OffloadingHandler1To2(OffloadingHandler):
    """
    Test handler for transfers from a LoadStoreSpec1 to a LoadStoreSpec2.

    Submission and completion outcomes are driven by the flags of the
    source spec.
    """

    def __init__(self):
        # In-flight transfers: job ID -> the source spec being tracked.
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Submit a transfer.

        Raises if the source spec asks for an exception, returns False if
        it asks for a submission failure, otherwise tracks the job and
        returns True.
        """
        source, target = spec
        assert isinstance(source, LoadStoreSpec1)
        assert isinstance(target, LoadStoreSpec2)

        if source.exception:
            raise Exception("An expected exception. Don't worry!")

        if source.submit_success:
            self.transfers[job_id] = source
            return True
        return False

    def get_finished(self) -> list[TransferResult]:
        """
        Return (job_id, success) for every tracked transfer whose spec is
        marked finished, and stop tracking those transfers.
        """
        # Collect the IDs first so the dict is not mutated while iterated.
        done_ids = [
            job_id for job_id, tracked in self.transfers.items()
            if tracked.finished
        ]
        return [(job_id, self.transfers.pop(job_id).async_success)
                for job_id in done_ids]
63+
64+
65+
class OffloadingHandler2To1(OffloadingHandler):
    """
    Test handler for transfers from a LoadStoreSpec2 to a LoadStoreSpec1.

    Submission always succeeds; completion is driven by the flags of the
    destination spec.
    """

    def __init__(self):
        # In-flight transfers: job ID -> the destination spec being tracked.
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Submit a transfer: track the destination spec and return True.
        """
        source, target = spec
        assert isinstance(source, LoadStoreSpec2)
        assert isinstance(target, LoadStoreSpec1)

        self.transfers[job_id] = target
        return True

    def get_finished(self) -> list[TransferResult]:
        """
        Return (job_id, success) for every tracked transfer whose spec is
        marked finished, and stop tracking those transfers.
        """
        # Collect the IDs first so the dict is not mutated while iterated.
        done_ids = [
            job_id for job_id, tracked in self.transfers.items()
            if tracked.finished
        ]
        return [(job_id, self.transfers.pop(job_id).async_success)
                for job_id in done_ids]
85+
86+
87+
def test_offloading_worker():
    """
    Tests OffloadingWorker with 2 handlers.
    One handler performs 1->2 transfers, and the other handles 2->1.

    Covers: a handler raising on submission, a handler rejecting a
    submission, an asynchronous failure, asynchronous successes, and
    new submissions arriving while earlier results are still pending.
    """
    worker = OffloadingWorker()
    handler1to2 = OffloadingHandler1To2()
    handler2to1 = OffloadingHandler2To1()
    worker.register_handler(LoadStoreSpec1, LoadStoreSpec2, handler1to2)
    worker.register_handler(LoadStoreSpec2, LoadStoreSpec1, handler2to1)

    # 1st transfer 1->2 (exception)
    src1 = LoadStoreSpec1(exception=True)
    dst1 = LoadStoreSpec2()
    assert not worker.transfer_async(1, (src1, dst1))

    # 2nd transfer 1->2 (failure to submit)
    src2 = LoadStoreSpec1(submit_success=False)
    dst2 = LoadStoreSpec2()
    assert not worker.transfer_async(2, (src2, dst2))

    # 3rd transfer 1->2 (failure)
    src3 = LoadStoreSpec1(async_success=False)
    dst3 = LoadStoreSpec2()
    assert worker.transfer_async(3, (src3, dst3))

    # 4th transfer 1->2 (success)
    src4 = LoadStoreSpec1()
    dst4 = LoadStoreSpec2()
    assert worker.transfer_async(4, (src4, dst4))
    # only the successfully submitted transfers are tracked
    assert set(handler1to2.transfers.keys()) == {3, 4}

    # 5th transfer 2->1
    src5 = LoadStoreSpec2()
    dst5 = LoadStoreSpec1()
    assert worker.transfer_async(5, (src5, dst5))
    assert set(handler2to1.transfers.keys()) == {5}

    # no transfer completed yet
    assert worker.get_finished() == []

    # complete 3rd, 4th
    src3.finished = True
    src4.finished = True

    # 6th transfer 1->2
    src6 = LoadStoreSpec1()
    dst6 = LoadStoreSpec2()
    assert worker.transfer_async(6, (src6, dst6))

    # 7th transfer 2->1
    src7 = LoadStoreSpec2()
    dst7 = LoadStoreSpec1()
    assert worker.transfer_async(7, (src7, dst7))

    # 6th and 7th transfers started
    assert 6 in handler1to2.transfers
    assert 7 in handler2to1.transfers

    # verify result of 3rd and 4th transfers
    assert sorted(worker.get_finished()) == [(3, False), (4, True)]

    # complete 6th and 7th transfers
    src6.finished = True
    dst7.finished = True
    assert sorted(worker.get_finished()) == [(6, True), (7, True)]

vllm/v1/offloading/abstract.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
OffloadingManager class for managing KV data offloading in vLLM v1
5+
6+
This class runs in the scheduler, tracks which blocks are offloaded
7+
and their address.
8+
9+
The class provides the following primitives:
10+
lookup() - find the length of the maximal series of blocks,
11+
starting from the first one, that are all offloaded.
12+
prepare_load() - prepare given blocks to be read.
13+
The given blocks will be protected from eviction.
14+
This function returns a LoadSpec which encapsulates
15+
information required for performing the load.
16+
touch() - marks the given blocks as recently used. Can be used
17+
to track block's LRU. This function is separated from the
18+
prepare_load function to allow setting block recency even
19+
for blocks which do not need reading from the cache, such as
20+
blocks that are cached by the GPU prefix cache.
21+
complete_load() - mark blocks which were previously prepared to be
22+
loaded as done loading. This is to re-allow their eviction.
23+
prepare_store() - prepare the given blocks to be written.
24+
Returns a StoreSpec encapsulating offloading information,
25+
as well as a list of blocks that were evicted as a result.
26+
complete_store() - marks a previous store as completed.
27+
Following this call, the given blocks will become loadable.
28+
"""
29+
30+
from abc import ABC, abstractmethod
31+
from collections.abc import Iterable
32+
from dataclasses import dataclass
33+
from typing import Optional
34+
35+
from vllm.v1.core.kv_cache_utils import BlockHash
36+
37+
38+
class LoadStoreSpec(ABC):
    """
    Abstract descriptor carrying the information a worker needs in order
    to load, and optionally also to store, blocks of KV data.
    """

    @staticmethod
    @abstractmethod
    def medium() -> str:
        """
        Return a string naming the medium type that this
        store/load spec targets.
        """
        ...
52+
53+
54+
@dataclass
class PrepareStoreOutput:
    """
    Result of preparing blocks to be stored (see
    OffloadingManager.prepare_store).
    """
    # hashes of the blocks which need storing
    block_hashes_to_store: list[BlockHash]
    # where to store them
    store_spec: LoadStoreSpec
    # hashes of the blocks that were evicted as a result
    block_hashes_evicted: list[BlockHash]
59+
60+
61+
@dataclass
class OffloadingEvent:
    """
    A record of blocks being stored or removed, as handed out by
    OffloadingManager.take_events().
    """
    # hashes of the affected blocks
    block_hashes: list[BlockHash]
    block_size: int
    # presumably matches LoadStoreSpec.medium() - TODO confirm
    medium: str
    # True if blocks are removed, False if stored
    removed: bool
68+
69+
70+
class OffloadingManager(ABC):
    """
    Abstract interface for tracking which KV blocks are offloaded.

    Subclasses must implement lookup(), prepare_load() and
    prepare_store(). The remaining methods are optional hooks whose
    default implementations do nothing.
    """

    @abstractmethod
    def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
        """
        Finds the length of the maximal series of blocks, starting from the
        first one, that are all offloaded.

        Args:
            block_hashes: the hashes identifying the blocks to lookup.

        Returns:
            An integer representing the maximal number of blocks that
            are currently offloaded.
        """
        pass

    @abstractmethod
    def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec:
        """
        Prepare the given blocks to be read.
        The given blocks will be protected from eviction until
        complete_load is called.
        It assumes all given blocks are offloaded.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A LoadStoreSpec that can be used by a worker to locate and load
            the actual offloaded KV data.
        """
        pass

    def touch(self, block_hashes: Iterable[BlockHash]) -> None:
        """
        Mark the given blocks as recently used.
        This could in practice mean moving them to the end of an LRU list.

        The default implementation does nothing.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    def complete_load(self, block_hashes: Iterable[BlockHash]) -> None:
        """
        Marks blocks that were previously prepared to load as done loading.

        The default implementation does nothing.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    @abstractmethod
    def prepare_store(
            self,
            block_hashes: Iterable[BlockHash]) -> Optional[PrepareStoreOutput]:
        """
        Prepare the given blocks to be offloaded.
        The given blocks will be protected from eviction until
        complete_store is called.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A PrepareStoreOutput indicating which blocks need storing,
            where to store them (LoadStoreSpec), and list of blocks that
            were evicted as a result.
            None is returned if the blocks cannot be stored.
        """
        pass

    def complete_store(self,
                       block_hashes: Iterable[BlockHash],
                       success: bool = True) -> None:
        """
        Marks blocks which were previously prepared to be stored, as stored.
        Following this call, the blocks become loadable.
        If success is False, blocks that were not marked as stored will be
        removed.

        The default implementation does nothing.

        Args:
            block_hashes: the hashes identifying the blocks.
            success: whether the blocks were stored successfully.
        """
        return

    def take_events(self) -> Iterable[OffloadingEvent]:
        """
        Take the offloading events from the manager.

        The default implementation returns an empty iterable.

        Returns:
            New OffloadingEvents collected since the last call.
        """
        return ()

vllm/v1/offloading/mediums.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from abc import ABC
4+
5+
import numpy as np
6+
7+
from vllm.v1.offloading.abstract import LoadStoreSpec
8+
9+
10+
class BlockIDsLoadStoreSpec(LoadStoreSpec, ABC):
    """
    Load/store spec that addresses KV blocks by their block numbers.
    """

    def __init__(self, block_ids: list[int]):
        # The block numbers are kept as an int64 numpy array.
        self.block_ids = np.array(block_ids, dtype=np.int64)

    def __repr__(self) -> str:
        # Delegate to the array's own representation.
        return repr(self.block_ids)
20+
21+
22+
class GPULoadStoreSpec(BlockIDsLoadStoreSpec):
    """
    Block-ID based spec for loading/storing KV blocks to GPU memory.
    """

    @staticmethod
    def medium() -> str:
        """Return the medium name, "GPU"."""
        return "GPU"
30+
31+
32+
class CPULoadStoreSpec(BlockIDsLoadStoreSpec):
    """
    Block-ID based spec for loading/storing KV blocks to CPU memory.
    """

    @staticmethod
    def medium() -> str:
        """Return the medium name, "CPU"."""
        return "CPU"

0 commit comments

Comments
 (0)