 import pytest
 import torch

+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, bind_kv_cache,
                         deprecate_kwargs, get_open_port, memory_profiling,
                         merge_async_iterators, supports_kw)
@@ -323,11 +324,11 @@ def test_bind_kv_cache():
         torch.zeros((1, )),
         torch.zeros((1, )),
     ]
-    bind_kv_cache(ctx, kv_cache)
-    assert ctx['layers.0.self_attn'].kv_cache is kv_cache[0]
-    assert ctx['layers.1.self_attn'].kv_cache is kv_cache[1]
-    assert ctx['layers.2.self_attn'].kv_cache is kv_cache[2]
-    assert ctx['layers.3.self_attn'].kv_cache is kv_cache[3]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]

 def test_bind_kv_cache_non_attention():
     from vllm.attention import Attention
@@ -341,9 +342,9 @@ def test_bind_kv_cache_non_attention():
         torch.zeros((1, )),
         torch.zeros((1, )),
     ]
-    bind_kv_cache(ctx, kv_cache)
-    assert ctx['model.layers.20.attn'].kv_cache is kv_cache[0]
-    assert ctx['model.layers.28.attn'].kv_cache is kv_cache[1]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]


 def test_bind_kv_cache_encoder_decoder():
@@ -364,7 +365,24 @@ def test_bind_kv_cache_encoder_decoder():
     ]
     encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache

-    bind_kv_cache(ctx, kv_cache)
+    bind_kv_cache(ctx, [kv_cache])
     assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
-    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache is kv_cache[0]
-    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache is kv_cache[0]
+    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]
+
+
+def test_bind_kv_cache_pp():
+    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
+    with set_current_vllm_config(cfg):
+        from vllm.attention import Attention
+
+        ctx = {
+            'layers.0.self_attn': Attention(32, 128, 0.1),
+        }
+        kv_cache = [
+            [torch.zeros((1, ))],
+            [torch.zeros((1, ))]
+        ]
+        bind_kv_cache(ctx, kv_cache)
+        assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
+        assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
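
A note on the calling convention these tests now exercise: the second argument to bind_kv_cache is a list of per-virtual-engine KV-cache lists (one inner list per pipeline-parallel stage), and each attention module's kv_cache attribute becomes a list indexed by virtual engine. The sketch below is not the actual vllm.utils.bind_kv_cache implementation; it is a minimal illustration consistent with the assertions above, and layer_index_of is a hypothetical helper.

import re

import torch


def layer_index_of(layer_name: str) -> int:
    # Hypothetical helper: pull the numeric layer index out of a module name
    # such as "layers.3.self_attn" or "model.layers.20.attn".
    return int(re.search(r"layers\.(\d+)\.", layer_name).group(1))


def bind_kv_cache_sketch(ctx: dict,
                         kv_caches: list[list[torch.Tensor]]) -> None:
    # Sketch assumption: encoder-only attention keeps whatever cache it
    # already holds; only layers needing a decoder KV cache are bound here.
    needs_cache = [name for name in ctx if 'encoder.layers' not in name]
    # Modules sharing a numeric layer index (e.g. decoder self-attention and
    # cross-attention of the same layer) share one cache slot.
    slots = sorted({layer_index_of(name) for name in needs_cache})
    for name in needs_cache:
        slot = slots.index(layer_index_of(name))
        # kv_cache[ve] is the tensor for virtual engine (pipeline stage) ve,
        # which is what the assertions in the tests above check.
        ctx[name].kv_cache = [ve_cache[slot] for ve_cache in kv_caches]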