@@ -220,9 +220,19 @@ def forward(
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            cache_kwargs = {
+                "sin": sin,
+                "cos": cos,
+                "cache_position": cache_position,
+                "sliding_window": self.sliding_window,
+            }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
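As a side note on the FA2 slice above: the static/hybrid cache pre-allocates `max_cache_len` key/value slots, so `past_key_value.update` returns tensors longer than the number of tokens actually processed, and flash_attention_2 (which only receives a 2D padding mask) would otherwise attend over uninitialized slots. A minimal sketch with toy shapes (illustrative tensors only, not the real cache objects):

```python
import torch

# Illustrative shapes only: batch 1, 2 KV heads, head_dim 4, a static cache pre-allocated
# to max_cache_len=8, with just 3 tokens processed so far (the 2D mask has 3 columns).
key_states = torch.zeros(1, 2, 8, 4)   # [bs, num_kv_heads, max_cache_len, head_dim]
value_states = torch.zeros(1, 2, 8, 4)
attention_mask = torch.ones(1, 3)      # FA2-style 2D mask over processed tokens only

seq_len = attention_mask.shape[-1]
key_states = key_states[:, :, :seq_len, :]
value_states = value_states[:, :, :seq_len, :]
print(key_states.shape, value_states.shape)  # torch.Size([1, 2, 3, 4]) twice
```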
@@ -276,20 +286,30 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: int = 0,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
+            # In prefill, we may be larger than the sliding window
+            effective_seq_len = max(cache_position.shape[0], self.sliding_window)
+            # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
+            # thus we must slice from the right (at most `effective_seq_len` elements)
             if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
+                attention_mask = attention_mask[:, -effective_seq_len:]
+            # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len], thus we must slice
+            # from the left, with an offset if we are beyond the sliding window
             else:
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
+                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
+                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
+                offset = last_cache_position - effective_seq_len
+                # Should only be used when beyond the sliding window (i.e. offset > 0)
+                offset = max(0, offset)
+                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
 
         residual = hidden_states
 
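To make the mask-slicing arithmetic above concrete, here is a self-contained sketch of the non-FA2 branch on a toy 4D mask; `sliding_window_slice` is a standalone helper written for illustration, not part of the model:

```python
import torch

sliding_window, max_cache_len = 4, 16

def sliding_window_slice(attention_mask, cache_position, last_cache_position):
    # Mirrors the 4D-mask branch above: keep at most `effective_seq_len` cache columns,
    # shifted right by `offset` once we are beyond the sliding window.
    effective_seq_len = max(cache_position.shape[0], sliding_window)
    offset = max(0, last_cache_position - effective_seq_len)
    return attention_mask[:, :, :, offset : offset + effective_seq_len]

# Decoding one token once 10 tokens are already cached: offset = 11 - 4 = 7,
# so the query sees cache columns 7..10 (the last `sliding_window` positions).
mask = torch.zeros(1, 1, 1, max_cache_len)   # query_len == 1 during decode
cache_position = torch.tensor([10])
print(sliding_window_slice(mask, cache_position, last_cache_position=11).shape)  # [1, 1, 1, 4]

# Prefill with 6 tokens (longer than the window): offset stays 0, keep the first 6 columns.
mask = torch.zeros(1, 1, 6, max_cache_len)
cache_position = torch.arange(6)
print(sliding_window_slice(mask, cache_position, last_cache_position=6).shape)  # [1, 1, 6, 6]
```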
@@ -305,6 +325,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states
@@ -549,6 +570,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: Optional[int] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -589,6 +611,16 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
+        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
+        # (retrieving the same value from `cache_position` later on would crash dynamo)
+        if last_cache_position is None:
+            last_cache_position = 0
+            if attention_mask is not None:
+                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
+                # It will break dynamo tracing but there is no way around it (and it should never happen in practice)
+                last_cache_position = (
+                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
+                )
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
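The fallback above boils down to a small decision that can be sketched as a standalone function (hypothetical helper, for illustration only): a 2D padding mask already encodes how many tokens have been seen as a plain Python int, while a directly-passed 4D mask forces a data-dependent read of `cache_position`, which is exactly what breaks dynamo.

```python
import torch

def infer_last_cache_position(attention_mask, cache_position):
    # Illustrative stand-in for the fallback in the forward pass above.
    if attention_mask is None:
        return 0
    if attention_mask.dim() == 2:
        # Width of the 2D padding mask == number of tokens seen so far (a static int).
        return attention_mask.shape[-1]
    # 4D mask passed without `generate`: data-dependent read, breaks dynamo tracing.
    return cache_position[-1].item()

print(infer_last_cache_position(torch.ones(1, 7), torch.arange(7)))  # 7
```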
@@ -624,6 +656,7 @@ def forward(
                     output_attentions,
                     use_cache,
                     cache_position,
+                    last_cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -635,6 +668,7 @@ def forward(
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    last_cache_position=last_cache_position,
                     **flash_attn_kwargs,
                 )
 
@@ -850,6 +884,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            **loss_kwargs,
         )
 
         hidden_states = outputs[0]
@@ -918,6 +953,10 @@ def prepare_inputs_for_generation(
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
+        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
+        # (retrieving the same value from `cache_position` later on would crash dynamo)
+        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
+
         if (
             isinstance(past_key_values, HybridCache)
             and attention_mask.ndim == 2
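End to end, these changes mean `generate` computes `last_cache_position` once per step from the 2D attention mask and threads it down to every decoder layer, so the sliding-window mask slicing stays free of data-dependent indexing under torch.compile. A hedged usage sketch (the checkpoint name and generation settings are placeholders, not part of this diff):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"  # placeholder checkpoint that uses this modeling code
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("Sliding-window attention lets the model", return_tensors="pt").to(model.device)
# "hybrid" selects the static + sliding-window HybridCache that these slicing changes target.
out = model.generate(**inputs, max_new_tokens=32, cache_implementation="hybrid")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```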