@@ -197,14 +197,14 @@ class SDPA(nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
+        head_dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim
+        self.head_dim = head_dim
         self.n_rep = n_rep

     def forward(
@@ -215,17 +215,18 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

         k, v = self.kv_cache.update(input_pos, k, v)
-        mask = self.mask[None, None, input_pos]
+        attn_mask = mask[None, None, input_pos]

         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

@@ -271,10 +272,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
             not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
         )
         self.SDPA = SDPA(
-            self.kv_cache,
-            self.mask,
-            self.dim,
-            self.n_rep,
+            kv_cache=self.kv_cache,
+            dim=self.dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
        )

     def forward(
@@ -298,7 +299,7 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
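For reference, a minimal usage sketch of the refactored signature (not part of the patch): the mask is now supplied per call instead of being stored on the module. DummyKVCache below is a hypothetical stand-in for the real KVCache, assumed only to expose update(input_pos, k, v) returning the full cached k/v; the sketch assumes it runs in the same module as the SDPA class above, and all shapes are illustrative.

import torch
from torch import nn

class DummyKVCache(nn.Module):
    """Hypothetical stand-in for KVCache: writes k/v at input_pos, returns the full cache."""

    def __init__(self, bsz, n_kv_heads, max_seq_len, head_dim):
        super().__init__()
        self.register_buffer("k", torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim))
        self.register_buffer("v", torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim))

    def update(self, input_pos, k_val, v_val):
        # k_val, v_val arrive already transposed: (bsz, n_kv_heads, seqlen, head_dim)
        self.k[:, :, input_pos] = k_val
        self.v[:, :, input_pos] = v_val
        return self.k, self.v

bsz, max_seq_len, n_heads, n_kv_heads, head_dim = 1, 16, 8, 4, 64
dim = n_heads * head_dim

sdpa = SDPA(
    kv_cache=DummyKVCache(bsz, n_kv_heads, max_seq_len, head_dim),
    dim=dim,
    head_dim=head_dim,
    n_rep=n_heads // n_kv_heads,
)

# The causal mask now travels with each call rather than being frozen into the module.
mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

seqlen = 1
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
v = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
input_pos = torch.tensor([0])

out = sdpa(input_pos, q, k, v, bsz, seqlen, mask)  # -> (bsz, seqlen, dim)

Taking the mask as a forward argument keeps SDPA stateless apart from the cache, which should make it easier to swap the module for an alternative implementation without re-plumbing its constructor.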