@@ -213,14 +213,14 @@ class SDPA(nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
+        head_dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim
+        self.head_dim = head_dim
         self.n_rep = n_rep

     def forward(
@@ -231,17 +231,18 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

         k, v = self.kv_cache.update(input_pos, k, v)
-        mask = self.mask[None, None, input_pos]
+        attn_mask = mask[None, None, input_pos]

         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

@@ -287,10 +288,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
             )
             self.SDPA = SDPA(
-                self.kv_cache,
-                self.mask,
-                self.dim,
-                self.n_rep,
+                kv_cache=self.kv_cache,
+                dim=self.dim,
+                head_dim=self.head_dim,
+                n_rep=self.n_rep,
             )

     def forward(
@@ -314,7 +315,7 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
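
For context, a small standalone sketch of the mask handling this change introduces: the causal mask now arrives as a forward argument and is sliced by input_pos inside SDPA, instead of being captured at construction time. All shapes, sizes, and dummy tensors below are illustrative assumptions, not values taken from the model.

import torch
import torch.nn.functional as F

bsz, n_heads, max_seq_len, head_dim = 1, 4, 8, 16
seqlen = 1                      # decoding a single new token
input_pos = torch.tensor([3])   # its position in the cache

# Causal mask owned by the caller (the Attention module's registered buffer).
mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, max_seq_len, head_dim)
v = torch.randn(bsz, n_heads, max_seq_len, head_dim)

# As in the refactored SDPA.forward: slice the caller-supplied mask by
# input_pos and pass it to scaled_dot_product_attention.
attn_mask = mask[None, None, input_pos]   # (1, 1, seqlen, max_seq_len)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
print(y.shape)                  # torch.Size([1, 4, 1, 16])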