@@ -317,7 +317,7 @@ def forward(
         # for the decoder
         is_cross_attention = key_value_states is not None
 
-        bsz, tgt_len, _ = hidden_states.size()
+        bsz, _, _ = hidden_states.size()
 
         # get query proj
         query_states = self.q_proj(hidden_states)
@@ -351,13 +351,15 @@ def forward(
             # if encoder bi-directional self-attention `past_key_value` is always `None`
             past_key_value = (key_states, value_states)
 
+        query_length = query_states.shape[1]
+        tgt_len = key_states.shape[-2]
+
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
-        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
+        query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
         key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
         value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
 
-        _, query_length, _, _ = query_states.shape
 
         attn_dropout = self.dropout if self.training else 0.0
 
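Taken together, the hunks change where the two sequence lengths come from: the query length is now read from query_states and the key/value length (tgt_len) from key_states, so the reshapes stay correct when a decoder KV cache makes the key/value sequence longer than the current queries. Below is a minimal sketch of that shape arithmetic, using hypothetical sizes rather than the model's real tensors:

import torch

# Hypothetical sizes; in the model they come from the config and the cache.
bsz, num_heads, head_dim = 2, 4, 8
past_len, cur_len = 5, 1  # five cached decoding steps, one new token

# Projected queries for the new token only: (bsz, cur_len, num_heads * head_dim)
query_states = torch.randn(bsz, cur_len, num_heads * head_dim)

# Cached keys concatenated with the keys for the new token:
# (bsz, num_heads, past_len + cur_len, head_dim)
past_key = torch.randn(bsz, num_heads, past_len, head_dim)
new_key = torch.randn(bsz, num_heads, cur_len, head_dim)
key_states = torch.cat([past_key, new_key], dim=2)

query_length = query_states.shape[1]  # 1: length of the current queries
tgt_len = key_states.shape[-2]        # 6: length of the cached + current keys/values
assert query_length == 1 and tgt_len == past_len + cur_len

# Reshaping the queries with tgt_len (as the old code did) would raise here,
# because bsz * tgt_len * num_heads * head_dim no longer matches query_states.numel().
query_states = query_states.view(bsz, query_length, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(bsz, tgt_len, num_heads, head_dim)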