 import torch
 import torch.nn.functional as F
 import torch.nn as nn
+from pkg_resources import packaging
+from importlib.metadata import version
 
 from .norms import get_norm
 from megatron import mpu
@@ -412,6 +414,14 @@ def __init__(
         self.rope_fusion = neox_args.rope_fusion
         self.attention_type = neox_args.attention_config[layer_number]
         self.use_flash_attention = self.attention_type == "flash"
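+        # AliBi is handled natively by flash-attn only from 2.4.0.post1 onwards
+        # (via alibi_slopes); older installs fall back to the Triton kernel.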
+        self.use_triton = (
+            self.use_flash_attention
+            and self.pos_emb == "alibi"
+            and (
+                not packaging.version.Version(version("flash-attn"))
+                >= packaging.version.Version("2.4.0.post1")
+            )
+        )
         self.sparse = self.attention_type not in ("global", "flash")
 
         if self.gqa:
@@ -578,7 +588,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
             key_layer.size(0),
         )
 
-        if self.pos_emb != "alibi":
+        if self.use_flash_attention and not self.use_triton:
 
             # [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn]
             key_layer = key_layer.transpose(0, 1).reshape(
@@ -588,41 +598,46 @@ def flash_attention(self, query_layer, key_layer, value_layer):
                 output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
             )
 
-            batch_size = output_size[0]
-            max_seqlen_q = output_size[2]
-            max_seqlen_k = output_size[3]
-
-            cu_seqlens_q = torch.arange(
-                0,
-                (batch_size + 1) * max_seqlen_q,
-                step=max_seqlen_q,
-                dtype=torch.int32,
-                device=query_layer.device,
-            )
-
-            cu_seqlens_k = torch.arange(
-                0,
-                (batch_size + 1) * max_seqlen_k,
-                step=max_seqlen_k,
-                dtype=torch.int32,
-                device=key_layer.device,
-            )
-
             # [sq, b, np, hn] -> [b, sq, np, hn]
             query_layer = query_layer.transpose(0, 1).reshape(
                 output_size[0], output_size[2], output_size[1], -1
             )
 
-            # only pass in window_size kwarg to flash-attn
-            # if we use Sliding Window Attention.
+            # only pass in window_size or alibi_slopes kwarg
+            # if we use Sliding Window Attention / AliBi.
             # Flash attn defaults to (-1,-1), or
             # does not have this kwarg prior to v2.3.0
             extra_kwargs = (
                 {"window_size": (self.sliding_window_width, -1)}
                 if self.sliding_window_width is not None
                 else {}
             )
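+            # flash-attn takes the AliBi slopes through its alibi_slopes kwarg and
+            # expects them as an fp32 tensor on the same device as the query.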
+            if self.pos_emb == "alibi":
+                extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to(
+                    query_layer.device
+                ).to(torch.float32)
+
             if not self.training:
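+                # the cumulative / max sequence lengths are only consumed by the
+                # inference path below, so they are now built inside this branch.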
+                batch_size = output_size[0]
+                max_seqlen_q = output_size[2]
+                max_seqlen_k = output_size[3]
+
+                cu_seqlens_q = torch.arange(
+                    0,
+                    (batch_size + 1) * max_seqlen_q,
+                    step=max_seqlen_q,
+                    dtype=torch.int32,
+                    device=query_layer.device,
+                )
+
+                cu_seqlens_k = torch.arange(
+                    0,
+                    (batch_size + 1) * max_seqlen_k,
+                    step=max_seqlen_k,
+                    dtype=torch.int32,
+                    device=key_layer.device,
+                )
+
                 q_shape = query_layer.shape
                 k_shape = key_layer.shape
                 v_shape = value_layer.shape
@@ -662,6 +677,8 @@ def flash_attention(self, query_layer, key_layer, value_layer):
             matmul_result = matmul_result.transpose(1, 2)
 
         else:
+            # we still use Triton if using AliBi with flash-attn<2.4.0.post1.
+
             # [sq, b, np, hn] -> [b, sq, np, hn]
             sq = query_layer.size(0)
             b = query_layer.size(1)