
Commit f70c54d

Merge branch 'main' into tp-mamba-neox
2 parents: 4bc39c3 + 03186de

File tree

3 files changed: +51 −28 lines

configs/neox_arguments.md

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ Logging Arguments
 
 - **git_hash**: str
 
-    Default = 696454f
+    Default = fdac107
 
     current git hash of repository
 

megatron/model/transformer.py

Lines changed: 40 additions & 23 deletions
@@ -21,6 +21,8 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
+from pkg_resources import packaging
+from importlib.metadata import version
 
 from .norms import get_norm
 from megatron import mpu
@@ -412,6 +414,14 @@ def __init__(
         self.rope_fusion = neox_args.rope_fusion
         self.attention_type = neox_args.attention_config[layer_number]
         self.use_flash_attention = self.attention_type == "flash"
+        self.use_triton = (
+            self.use_flash_attention
+            and self.pos_emb == "alibi"
+            and (
+                not packaging.version.Version(version("flash-attn"))
+                >= packaging.version.Version("2.4.0.post1")
+            )
+        )
         self.sparse = self.attention_type not in ("global", "flash")
 
         if self.gqa:
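The new `use_triton` flag keys off how `packaging` orders the installed `flash-attn` version against `2.4.0.post1`, the version this commit treats as the first with native ALiBi support via `alibi_slopes`. A minimal sketch of that comparison, assuming `flash-attn` is installed; the helper name `flash_attn_supports_alibi` is illustrative and not part of this commit:

from importlib.metadata import version
from pkg_resources import packaging


def flash_attn_supports_alibi() -> bool:
    # Illustrative helper (not in the diff): installs older than
    # 2.4.0.post1 fall back to the Triton backend for ALiBi.
    installed = packaging.version.Version(version("flash-attn"))
    return installed >= packaging.version.Version("2.4.0.post1")


# PEP 440 ordering: post-releases sort after the base release.
assert packaging.version.Version("2.4.0.post1") > packaging.version.Version("2.4.0")
assert packaging.version.Version("2.3.6") < packaging.version.Version("2.4.0.post1")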
@@ -578,7 +588,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
             key_layer.size(0),
         )
 
-        if self.pos_emb != "alibi":
+        if self.use_flash_attention and not self.use_triton:
 
             # [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn]
             key_layer = key_layer.transpose(0, 1).reshape(
@@ -588,41 +598,46 @@ def flash_attention(self, query_layer, key_layer, value_layer):
                 output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
             )
 
-            batch_size = output_size[0]
-            max_seqlen_q = output_size[2]
-            max_seqlen_k = output_size[3]
-
-            cu_seqlens_q = torch.arange(
-                0,
-                (batch_size + 1) * max_seqlen_q,
-                step=max_seqlen_q,
-                dtype=torch.int32,
-                device=query_layer.device,
-            )
-
-            cu_seqlens_k = torch.arange(
-                0,
-                (batch_size + 1) * max_seqlen_k,
-                step=max_seqlen_k,
-                dtype=torch.int32,
-                device=key_layer.device,
-            )
-
             # [sq, b, np, hn] -> [b, sq, np, hn]
             query_layer = query_layer.transpose(0, 1).reshape(
                 output_size[0], output_size[2], output_size[1], -1
             )
 
-            # only pass in window_size kwarg to flash-attn
-            # if we use Sliding Window Attention.
+            # only pass in window_size or alibi_slopes kwarg
+            # if we use Sliding Window Attention / AliBi.
             # Flash attn defaults to (-1,-1), or
             # does not have this kwarg prior to v2.3.0
             extra_kwargs = (
                 {"window_size": (self.sliding_window_width, -1)}
                 if self.sliding_window_width is not None
                 else {}
             )
+            if self.pos_emb == "alibi":
+                extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to(
+                    query_layer.device
+                ).to(torch.float32)
+
             if not self.training:
+                batch_size = output_size[0]
+                max_seqlen_q = output_size[2]
+                max_seqlen_k = output_size[3]
+
+                cu_seqlens_q = torch.arange(
+                    0,
+                    (batch_size + 1) * max_seqlen_q,
+                    step=max_seqlen_q,
+                    dtype=torch.int32,
+                    device=query_layer.device,
+                )
+
+                cu_seqlens_k = torch.arange(
+                    0,
+                    (batch_size + 1) * max_seqlen_k,
+                    step=max_seqlen_k,
+                    dtype=torch.int32,
+                    device=key_layer.device,
+                )
+
                 q_shape = query_layer.shape
                 k_shape = key_layer.shape
                 v_shape = value_layer.shape
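The `cu_seqlens_*` tensors that this hunk moves under `if not self.training:` are the cumulative sequence-start offsets expected by flash-attn's varlen interface. A small sketch of their layout for a batch of equal-length (padded) sequences, with sizes picked only for illustration:

import torch

# Hypothetical sizes, for illustration only.
batch_size, max_seqlen_q = 3, 4

# With a fixed-size slot per sequence, the cumulative offsets are just
# multiples of the padded sequence length.
cu_seqlens_q = torch.arange(
    0,
    (batch_size + 1) * max_seqlen_q,
    step=max_seqlen_q,
    dtype=torch.int32,
)
print(cu_seqlens_q)  # tensor([ 0,  4,  8, 12], dtype=torch.int32)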
@@ -662,6 +677,8 @@ def flash_attention(self, query_layer, key_layer, value_layer):
             matmul_result = matmul_result.transpose(1, 2)
 
         else:
+            # we still use Triton if using AliBi with flash-attn<2.4.0.post1.
+
             # [sq, b, np, hn] -> [b, sq, np, hn]
             sq = query_layer.size(0)
             b = query_layer.size(1)
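For reference, the `alibi_slopes` passed to flash-attn in the earlier hunk come from `self.alibi_embed.slopes`; the standard ALiBi recipe assigns one geometric slope per head. A rough sketch for the power-of-two head-count case, which may differ in detail from the repository's own ALiBi implementation:

import torch


def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Standard ALiBi slopes for a power-of-two head count:
    # head i (1-indexed) gets slope 2 ** (-8 * i / num_heads).
    # (The non-power-of-two case interleaves a second sequence; omitted here.)
    assert (num_heads & (num_heads - 1)) == 0, "sketch assumes power-of-two heads"
    exponents = torch.arange(1, num_heads + 1, dtype=torch.float32)
    return 2.0 ** (-8.0 * exponents / num_heads)


print(alibi_slopes(8))
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])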

megatron/neox_arguments/arguments.py

Lines changed: 10 additions & 4 deletions
@@ -1092,11 +1092,17 @@ def calculate_derived(self):
             self.num_kv_heads % self.model_parallel_size == 0
         ), "Number of KV heads must be at least model_parallel_size for now!"
         # Flash attention version >=2.3.0 required to combine Flash + Sliding Window Attention
-        if self.sliding_window_width is not None and "flash" in self.attention_config:
+        if "flash" in self.attention_config:
             _flash_version = packaging.version.Version(version("flash-attn"))
-            assert _flash_version >= packaging.version.Version(
-                "2.3.0"
-            ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention."
+            if self.sliding_window_width is not None:
+                assert _flash_version >= packaging.version.Version(
+                    "2.3.0"
+                ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention."
+            if self.pos_emb == "alibi":
+                if not _flash_version >= packaging.version.Version("2.4.0.post1"):
+                    print(
+                        f"Warning: Flash-Attention version ({str(_flash_version)}) must be >= 2.4.0.post1 to support AliBi. Falling back to flash-attn triton backend, but version 2.4.0.post1 or later will be required in future."
+                    )
 
         # Adding equal dataset weights if none are provided
         if self.train_data_paths and (self.train_data_weights is None):
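Net effect of this hunk: whenever the flash backend is configured, `calculate_derived` inspects the installed flash-attn version, keeps the hard >= 2.3.0 assertion for sliding window attention, and only warns (falling back to the Triton kernel) when ALiBi is requested on flash-attn older than 2.4.0.post1. A standalone sketch of that decision, with the function name and return values chosen purely for illustration:

from pkg_resources import packaging


def pick_flash_backend(flash_version: str, pos_emb: str, sliding_window_width):
    # Illustrative restatement of the checks in calculate_derived().
    v = packaging.version.Version(flash_version)
    if sliding_window_width is not None:
        assert v >= packaging.version.Version("2.3.0"), "need flash-attn >= 2.3.0"
    if pos_emb == "alibi" and v < packaging.version.Version("2.4.0.post1"):
        return "triton"  # warned fallback path
    return "flash"


print(pick_flash_backend("2.3.6", "alibi", None))        # triton
print(pick_flash_backend("2.4.0.post1", "alibi", None))  # flash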
