@@ -121,6 +121,13 @@ def __init__(
         compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix
         self.attn_type = attn_type
+        # use a placeholder kv cache tensor during init, which will be replaced
+        # by bind_kv_cache
+        # this variable will not be accessed if use_direct_call is True
+        self.kv_cache = [
+            torch.tensor([]) for _ in range(get_current_vllm_config(
+            ).parallel_config.pipeline_parallel_size)
+        ]
 
     def forward(
         self,
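Each `Attention` layer now starts with one empty placeholder tensor per pipeline-parallel virtual engine. The comment above names `bind_kv_cache` as the step that later swaps these placeholders for the real cache tensors; the sketch below only illustrates that binding under assumed names (the helper signature and the `kv_caches` layout are illustrative, not the actual vLLM implementation):

```python
from typing import Dict, List

import torch


def bind_kv_cache_sketch(
        attn_layers: Dict[str, "Attention"],       # layer_name -> layer module
        kv_caches: List[Dict[str, torch.Tensor]],  # one dict per virtual engine
) -> None:
    """Illustrative only: replace each layer's placeholder with its real cache."""
    for virtual_engine, caches in enumerate(kv_caches):
        for layer_name, cache in caches.items():
            # Overwrites the torch.tensor([]) placeholder created in __init__,
            # so later forward passes can index self.kv_cache[virtual_engine].
            attn_layers[layer_name].kv_cache[virtual_engine] = cache
```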
@@ -148,11 +155,11 @@ def forward(
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
             torch.ops.vllm.unified_attention_with_output(
-                query, key, value, output, kv_cache, self.layer_name)
+                query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
             return torch.ops.vllm.unified_attention(query, key, value,
-                                                    kv_cache, self.layer_name)
+                                                    self.layer_name)
 
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
@@ -230,12 +237,12 @@ def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.attn_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
     return self.impl.forward(query, key, value, kv_cache, attn_metadata,
                              self._k_scale, self._v_scale)
 
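`unified_attention` no longer receives `kv_cache` as an argument; it resolves everything from the ambient forward context instead. A minimal sketch of the fields that context must carry, based only on the attributes read in this diff (`attn_metadata`, `attn_layers`, `virtual_engine`); the real `ForwardContext` in `vllm/forward_context.py` may hold more:

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ForwardContextSketch:
    # Per-step attention metadata, set by the model runner before each forward.
    attn_metadata: Any
    # Registry mapping layer_name -> Attention module, filled at init time.
    attn_layers: Dict[str, Any] = field(default_factory=dict)
    # Index of the current pipeline-parallel virtual engine, used to pick the
    # matching entry from each layer's kv_cache list.
    virtual_engine: int = 0
```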
@@ -244,7 +251,6 @@ def unified_attention_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
     return torch.empty_like(query).contiguous()
@@ -253,7 +259,7 @@ def unified_attention_fake(
 direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
-    mutates_args=["kv_cache"],
+    mutates_args=[],
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
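`mutates_args` may only name parameters of the registered op, so once `kv_cache` is dropped from the signature it has to disappear from this list as well; the cache is still written in place, but now as module state reached through the forward context rather than as a traced argument. A toy example using plain `torch.library.custom_op` (PyTorch 2.4+, not vLLM's `direct_register_custom_op`) shown only to illustrate that constraint:

```python
import torch


# A tensor can appear in mutates_args only if it is an argument of the op;
# state reached some other way (e.g. module attributes) is not declared here.
@torch.library.custom_op("demo::inplace_scale", mutates_args=("out",))
def inplace_scale(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x * 2)
```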
@@ -264,12 +270,12 @@ def unified_attention_with_output(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.attn_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(query,
                       key,
                       value,
@@ -285,7 +291,6 @@ def unified_attention_with_output_fake(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> None:
     return
@@ -294,7 +299,7 @@ def unified_attention_with_output_fake(
 direct_register_custom_op(
     op_name="unified_attention_with_output",
     op_func=unified_attention_with_output,
-    mutates_args=["kv_cache", "output"],
+    mutates_args=["output"],
     fake_impl=unified_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
 )
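For these ops to find the cache, the model runner has to populate the forward context before invoking the model. A rough sketch of that call site, assuming a `set_forward_context` context manager like the one in `vllm/forward_context.py`; the exact parameters shown here, in particular the `virtual_engine` keyword, are assumptions rather than something this diff confirms:

```python
import torch


def run_model_step(model, input_ids, positions, attn_metadata, vllm_config,
                   virtual_engine: int = 0) -> torch.Tensor:
    # Hypothetical wiring: set_forward_context is assumed to install a
    # ForwardContext whose attn_metadata, attn_layers and virtual_engine
    # fields are exactly what unified_attention / unified_attention_with_output
    # read above.
    from vllm.forward_context import set_forward_context  # assumed import path

    with set_forward_context(attn_metadata, vllm_config,
                             virtual_engine=virtual_engine):
        return model(input_ids, positions)
```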