@@ -209,6 +209,95 @@ def update(
         return k_out, v_out


+class SDPA(nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        use_sdpa_with_kv_cache_op: bool,
+        dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
+        self.dim = dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        if not self.use_sdpa_with_kv_cache_op:
+            return self._forward_default(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+        else:
+            return self._forward_custom(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+
+    def _forward_custom(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache  # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+    def _forward_default(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = self.mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
@@ -229,7 +318,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

-        self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
         self.layer_id = layer_id

         causal_mask = torch.tril(
@@ -250,6 +338,13 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 self.head_dim,
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op, don't transpose the cache; expect untransposed q, k, v
             )
+            self.SDPA = SDPA(
+                self.kv_cache,
+                self.mask,
+                args.use_sdpa_with_kv_cache_op,
+                self.dim,
+                self.n_rep,
+            )

     def forward(
         self,
@@ -272,41 +367,8 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-
-            if not self.use_sdpa_with_kv_cache_op:
-
-                q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-                k = k.transpose(1, 2)
-                v = v.transpose(1, 2)
-
-                k, v = self.kv_cache.update(input_pos, k, v)
-                mask = self.mask[None, None, input_pos]
-
-                k = k.repeat_interleave(self.n_rep, dim=1)
-                v = v.repeat_interleave(self.n_rep, dim=1)
-                y = F.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0
-                )
-
-                y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-                y = self.wo(y)
-                return y
-            else:
-                from .custom_ops import sdpa_with_kv_cache  # noqa
-
-                output = torch.ops.llama.sdpa_with_kv_cache(
-                    q,
-                    k,
-                    v,
-                    self.kv_cache.k_cache,
-                    self.kv_cache.v_cache,
-                    input_pos[-1].item(),
-                    seqlen,
-                )
-                output = output.view(bsz, seqlen, -1)
-                output = self.wo(output)
-                return output
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            return self.wo(output)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
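
For reference, here is a minimal sketch (not part of the diff) of how the default, non-custom-op path of the new SDPA module can be exercised on its own. It assumes the SDPA class from the diff above is in scope; the TinyKVCache stand-in, the tensor shapes, and the toy sizes are illustrative assumptions that mirror the update() contract and the (bsz, n_heads, seqlen, head_dim) layout _forward_default expects.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyKVCache(nn.Module):
    # Illustrative stand-in for KVCache with the assumed update() contract:
    # caches are laid out as (bsz, n_kv_heads, max_seq_len, head_dim).
    def __init__(self, bsz, max_seq_len, n_kv_heads, head_dim):
        super().__init__()
        cache_shape = (bsz, n_kv_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape))
        self.register_buffer("v_cache", torch.zeros(cache_shape))

    def update(self, input_pos, k_val, v_val):
        # Write the new tokens at input_pos and return the full caches.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


bsz, seqlen, max_seq_len = 1, 1, 16
n_heads, n_kv_heads, head_dim = 4, 2, 8
dim, n_rep = n_heads * head_dim, n_heads // n_kv_heads

kv_cache = TinyKVCache(bsz, max_seq_len, n_kv_heads, head_dim)
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

# use_sdpa_with_kv_cache_op=False selects _forward_default.
sdpa = SDPA(kv_cache, causal_mask, False, dim, n_rep)

# One decode step at position 3; q/k/v arrive untransposed as (bsz, seqlen, heads, head_dim).
input_pos = torch.tensor([3])
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
v = torch.randn(bsz, seqlen, n_kv_heads, head_dim)

out = sdpa(input_pos, q, k, v, bsz, seqlen)
print(out.shape)  # torch.Size([1, 1, 32])

The custom-op path takes the same untransposed q/k/v but hands the caches and the current position to torch.ops.llama.sdpa_with_kv_cache, which is why Attention.forward no longer needs to branch on use_sdpa_with_kv_cache_op itself.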