@@ -147,14 +147,11 @@ async def main():
147147 asyncio .run (main ())
148148
149149
150- def test_llm_kv_events_api ():
151- llm = create_llm ()
152- sampling_params = SamplingParams (max_tokens = 6 , temperature = 0.01 )
153-
154- requests = []
155- for i in range (3 ):
156- input_tokens = list (range (127 + i ))[i :]
157- requests .append (input_tokens )
150+ def check_events (llm ,
151+ requests ,
152+ sampling_params ,
153+ scheduling_params = None ,
154+ attention_dp_rank = None ):
158155
159156 _ = llm .generate (requests [0 ], sampling_params = sampling_params )
160157 events1 = llm .get_kv_cache_events (5 )
@@ -163,52 +160,95 @@ def test_llm_kv_events_api():
163160 event = events1 .pop (0 ) # created event
164161 while events1 :
165162 event = events1 .pop (0 )
163+ print ("event1:" , event )
166164 if event :
167165 assert event ["event_id" ] == 1
168166 assert event ["data" ]["type" ] == "stored"
169167 assert len (event ["data" ]["blocks" ]) == 5
168+ if attention_dp_rank :
169+ assert event ["data" ]["attention_dp_rank" ] == attention_dp_rank
170170
171171 _ = llm .generate (requests [1 ], sampling_params = sampling_params )
172172 events2 = llm .get_kv_cache_events (5 )
173173
174174 while events2 :
175175 event = events2 .pop (0 )
176+ print ("event2:" , event )
176177 if event :
177178 if event ["event_id" ] == 2 :
178179 # 2 removed events needed
179180 # should be a removed event to make space for context block
180181 assert event ["data" ]["type" ] == "removed"
181182 assert event ["data" ]["block_hashes" ]
183+ if attention_dp_rank :
184+ assert event ["data" ][
185+ "attention_dp_rank" ] == attention_dp_rank
182186 elif event ["event_id" ] == 3 :
183187 assert event ["data" ]["type" ] == "removed"
184188 assert event ["data" ]["block_hashes" ]
189+ if attention_dp_rank :
190+ assert event ["data" ][
191+ "attention_dp_rank" ] == attention_dp_rank
185192 # stored event for 2nd request
186193 elif event ["event_id" ] == 4 :
187194 assert event ["data" ]["type" ] == "stored"
188195 assert len (event ["data" ]["blocks" ]) == 5
196+ if attention_dp_rank :
197+ assert event ["data" ][
198+ "attention_dp_rank" ] == attention_dp_rank
189199
190200 _ = llm .generate (requests [2 ], sampling_params = sampling_params )
191201 events3 = llm .get_kv_cache_events (5 )
192202
193203 while events3 :
194204 event = events3 .pop (0 )
205+ print ("event3:" , event )
195206 if event :
196207 if event ["event_id" ] == 5 :
197208 assert event ["data" ]["type" ] == "removed"
198209 assert event ["data" ]["block_hashes" ]
210+ if attention_dp_rank :
211+ assert event ["data" ][
212+ "attention_dp_rank" ] == attention_dp_rank
199213 elif event ["event_id" ] == 6 :
200214 assert event ["data" ]["type" ] == "removed"
201215 assert event ["data" ]["block_hashes" ]
216+ if attention_dp_rank :
217+ assert event ["data" ][
218+ "attention_dp_rank" ] == attention_dp_rank
202219 elif event ["event_id" ] == 7 :
203220 assert event ["data" ]["type" ] == "stored"
204221 assert len (event ["data" ]["blocks" ]) == 5
222+ if attention_dp_rank :
223+ assert event ["data" ][
224+ "attention_dp_rank" ] == attention_dp_rank
205225
206226 # no more events after request is finished
207227 assert not llm .get_kv_cache_events (5 )
208228
209229
230+ def test_llm_kv_events_api ():
231+ llm = create_llm ()
232+ sampling_params = SamplingParams (max_tokens = 6 , temperature = 0.01 )
233+
234+ requests = []
235+ for i in range (3 ):
236+ input_tokens = list (range (127 + i ))[i :]
237+ requests .append (input_tokens )
238+
239+ check_events (llm , requests , sampling_params )
240+
241+
210242@skip_single_gpu
211243def test_llm_api_attention_dp_kv_events ():
244+
245+ kvcache_config = KvCacheConfig (free_gpu_memory_fraction = 0.4 ,
246+ event_buffer_max_size = 1024 ,
247+ attention_dp_events_gather_period_ms = 10 ,
248+ enable_block_reuse = True ,
249+ onboard_blocks = True ,
250+ max_tokens = 256 )
251+
212252 llm = LLM (model = llama_model_path ,
213253 tensor_parallel_size = 2 ,
214254 enable_attention_dp = True ,
@@ -217,59 +257,16 @@ def test_llm_api_attention_dp_kv_events():
217257
218258 sampling_params = SamplingParams (max_tokens = 6 , temperature = 0.01 )
219259
220- requests = []
221- for i in range (3 ):
222- input_tokens = list (range (127 + i ))[i :]
223- requests .append (input_tokens )
224-
225- _ = llm .generate (requests [0 ], sampling_params = sampling_params )
226- events1 = llm .get_kv_cache_events (5 )
227-
228- # Should have 1 stored event and 1 created event
229- event = events1 .pop (0 ) # created event
230- while events1 :
231- event = events1 .pop (0 )
232- if event :
233- assert event ["event_id" ] == 1
234- assert event ["data" ]["type" ] == "stored"
235- assert event ["attention_dp_rank" ] == 0
236- assert event ["window_size" ] == 32
237- assert len (event ["data" ]["blocks" ]) == 5
260+ for attention_dp_rank in range (2 ):
261+ requests = []
262+ for i in range (3 ):
263+ input_tokens = list (range (127 + i ))[i :]
264+ requests .append (input_tokens )
238265
239- _ = llm . generate ( requests [ 1 ], sampling_params = sampling_params )
240- events2 = llm . get_kv_cache_events ( 5 )
266+ scheduling_params = SchedulingParams (
267+ attention_dp_rank = attention_dp_rank , attention_dp_relax = False )
241268
242- while events2 :
243- event = events2 .pop (0 )
244- if event :
245- if event ["event_id" ] == 2 :
246- # 2 removed events needed
247- # should be a removed event to make space for context block
248- assert event ["data" ]["type" ] == "removed"
249- assert event ["data" ]["block_hashes" ]
250- elif event ["event_id" ] == 3 :
251- assert event ["data" ]["type" ] == "removed"
252- assert event ["data" ]["block_hashes" ]
253- # stored event for 2nd request
254- elif event ["event_id" ] == 4 :
255- assert event ["data" ]["type" ] == "stored"
256- assert len (event ["data" ]["blocks" ]) == 5
269+ check_events (llm , requests , sampling_params , scheduling_params ,
270+ attention_dp_rank )
257271
258- #_ = llm.generate(requests[2], sampling_params=sampling_params)
259- #events3 = llm.get_kv_cache_events(5)
260-
261- #while events3:
262- # event = events3.pop(0)
263- # if event:
264- # if event["event_id"] == 5:
265- # assert event["data"]["type"] == "removed"
266- # assert event["data"]["block_hashes"]
267- # elif event["event_id"] == 6:
268- # assert event["data"]["type"] == "removed"
269- # assert event["data"]["block_hashes"]
270- # elif event["event_id"] == 7:
271- # assert event["data"]["type"] == "stored"
272- # assert len(event["data"]["blocks"]) == 5
273-
274- ## no more events after request is finished
275- #assert not llm.get_kv_cache_events(5)
272+ time .sleep (5 )
0 commit comments