@@ -193,6 +193,95 @@ def update(
         return k_out, v_out


+class SDPA(nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        use_sdpa_with_kv_cache_op: bool,
+        dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
+        self.dim = dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        if not self.use_sdpa_with_kv_cache_op:
+            return self._forward_default(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+        else:
+            return self._forward_custom(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+
+    def _forward_custom(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache  # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+    def _forward_default(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = self.mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
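
For reference, a minimal, self-contained sketch of what the eager _forward_default path computes (toy shapes assumed for illustration, not taken from this diff): the cached K/V heads are expanded with repeat_interleave so each group of query heads attends over its shared KV head, and the result goes through F.scaled_dot_product_attention.

import torch
import torch.nn.functional as F

# Toy shapes (assumed): 8 query heads sharing 2 KV heads (grouped-query attention),
# head_dim 64, one new token attending over 16 cached positions.
bsz, n_heads, n_kv_heads, head_dim, cache_len = 1, 8, 2, 64, 16
n_rep = n_heads // n_kv_heads  # each KV head serves 4 query heads

q = torch.randn(bsz, n_heads, 1, head_dim)              # (bs, n_local_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, cache_len, head_dim)   # as returned by kv_cache.update
v = torch.randn(bsz, n_kv_heads, cache_len, head_dim)

# Expand K/V along the head dimension, mirroring _forward_default.
k = k.repeat_interleave(n_rep, dim=1)
v = v.repeat_interleave(n_rep, dim=1)

# Stand-in for self.mask[None, None, input_pos]: attend to every cached position.
mask = torch.ones(1, 1, 1, cache_len, dtype=torch.bool)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(y.shape)  # torch.Size([1, 8, 1, 64])
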
@@ -213,7 +302,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

-        self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
         self.layer_id = layer_id

         causal_mask = torch.tril(
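
The mask = self.mask[None, None, input_pos] line in _forward_default simply selects rows of this precomputed triangular mask for the positions currently being decoded. A tiny sketch with an assumed toy max_seq_len (not from the diff):

import torch

max_seq_len = 6
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

# Decoding the token at position 3: pick that row and add broadcast dims for batch and heads.
input_pos = torch.tensor([3])
mask = causal_mask[None, None, input_pos]  # shape (1, 1, 1, 6)
print(mask)
# tensor([[[[ True,  True,  True,  True, False, False]]]])
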
@@ -234,6 +322,13 @@ def __init__(self, args: ModelArgs, layer_id: int):
             self.head_dim,
             not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op, don't transpose the cache; expect untransposed q, k, v
         )
+        self.SDPA = SDPA(
+            self.kv_cache,
+            self.mask,
+            args.use_sdpa_with_kv_cache_op,
+            self.dim,
+            self.n_rep,
+        )

     def forward(
         self,
@@ -256,41 +351,8 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-
-            if not self.use_sdpa_with_kv_cache_op:
-
-                q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-                k = k.transpose(1, 2)
-                v = v.transpose(1, 2)
-
-                k, v = self.kv_cache.update(input_pos, k, v)
-                mask = self.mask[None, None, input_pos]
-
-                k = k.repeat_interleave(self.n_rep, dim=1)
-                v = v.repeat_interleave(self.n_rep, dim=1)
-                y = F.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0
-                )
-
-                y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-                y = self.wo(y)
-                return y
-            else:
-                from .custom_ops.sdpa_with_kv_cache import sdpa_with_kv_cache  # noqa
-
-                output = torch.ops.llama.sdpa_with_kv_cache(
-                    q,
-                    k,
-                    v,
-                    self.kv_cache.k_cache,
-                    self.kv_cache.v_cache,
-                    input_pos[-1].item(),
-                    seqlen,
-                )
-                output = output.view(bsz, seqlen, -1)
-                output = self.wo(output)
-                return output
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            return self.wo(output)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
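
One consequence of this refactor (the rationale is not stated in the diff, so treat this as an assumption): with the attention math behind a single submodule, the whole SDPA computation can be swapped for a backend-specific implementation without touching Attention.forward. A hypothetical helper, with all names invented for illustration:

import torch.nn as nn

def replace_sdpa_modules(model: nn.Module, make_replacement) -> None:
    # Recursively swap every SDPA child for whatever make_replacement builds from it.
    for name, child in model.named_children():
        if type(child).__name__ == "SDPA":
            setattr(model, name, make_replacement(child))
        else:
            replace_sdpa_modules(child, make_replacement)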