
Commit 8903a96

integrated flash attention 2 (#1035)
1 parent 70af6e8 commit 8903a96

File tree

3 files changed: +332 −7 lines


megatron/model/flash_attention.py

Lines changed: 323 additions & 1 deletion
@@ -5,7 +5,7 @@
 import torch.nn.functional as F
 
 from flash_attn import flash_attn_triton
-import flash_attn_cuda
+import flash_attn_2_cuda as flash_attn_cuda  # For flash_attn version 2.1.1
 
 
 def flash_attn_unpadded_unpacked_func_triton(
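
Note: flash-attn 2.x ships its CUDA extension as flash_attn_2_cuda, whereas 1.x exposed flash_attn_cuda; aliasing the new module to the old name keeps the rest of this file unchanged. A minimal, hypothetical fallback import (not part of this commit) that tolerates either major version could look like:

    # Hypothetical compatibility shim, not part of this commit: prefer the
    # flash-attn 2.x extension module, fall back to the 1.x name if absent.
    try:
        import flash_attn_2_cuda as flash_attn_cuda  # flash-attn >= 2.0
    except ImportError:
        import flash_attn_cuda  # flash-attn 1.x
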
@@ -462,3 +462,325 @@ def flash_attn_unpadded_func_cuda(
         causal,
         return_attn_probs,
     )
+
+
+# For flash-attention 2 integration
+def _flash_attn_varlen_forward(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p,
+    softmax_scale,
+    causal,
+    return_softmax,
+):
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
+        q,
+        k,
+        v,
+        None,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        False,
+        causal,
+        return_softmax,
+        None,
+    )
+    # if out.isnan().any() or softmax_lse.isnan().any():
+    #     breakpoint()
+    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
+
+
+def _flash_attn_varlen_backward(
+    dout,
+    q,
+    k,
+    v,
+    out,
+    softmax_lse,
+    dq,
+    dk,
+    dv,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p,
+    softmax_scale,
+    causal,
+    rng_state=None,
+):
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    # dq, dk, dv are allocated by us so they should already be contiguous
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
+    dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq,
+        dk,
+        dv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        False,
+        causal,
+        None,
+        rng_state,
+    )
+    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
+    #     breakpoint()
+    return dq, dk, dv, softmax_d
+
+
+class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
+        if softmax_scale is None:
+            softmax_scale = qkv.shape[-1] ** (-0.5)
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
+            qkv[:, 0],
+            qkv[:, 1],
+            qkv[:, 2],
+            cu_seqlens,
+            cu_seqlens,
+            max_seqlen,
+            max_seqlen,
+            dropout_p,
+            softmax_scale,
+            causal=causal,
+            return_softmax=return_softmax and dropout_p > 0,
+        )
+        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
+        ctx.dropout_p = dropout_p
+        ctx.max_seqlen = max_seqlen
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        return out if not return_softmax else (out, softmax_lse, S_dmask)
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
+        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
+        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
+        _flash_attn_varlen_backward(
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            dqkv[:, 0],
+            dqkv[:, 1],
+            dqkv[:, 2],
+            cu_seqlens,
+            cu_seqlens,
+            ctx.max_seqlen,
+            ctx.max_seqlen,
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+            rng_state=rng_state,
+        )
+        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
+        return dqkv, None, None, None, None, None, None
+
+
+def flash_attn_varlen_qkvpacked_func(
+    qkv,
+    cu_seqlens,
+    max_seqlen,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    return_attn_probs=False,
+):
+    """dropout_p should be set to 0.0 during evaluation
+    If Q, K, V are already stacked into 1 tensor, this function will be faster than
+    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
+    of the gradients of Q, K, V.
+    For multi-query and grouped-query attention (MQA/GQA), please see
+    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
+
+    Arguments:
+        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
+        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into qkv.
+        max_seqlen: int. Maximum sequence length in the batch.
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (total, nheads, headdim).
+        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
+            The output of softmax (possibly with different scaling). It also encodes the dropout
+            pattern (negative means that location was dropped, nonnegative means it was kept).
+    """
+    return FlashAttnVarlenQKVPackedFunc.apply(
+        qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
+    )
+
+
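As a usage sketch (not part of the commit): flash_attn_varlen_qkvpacked_func expects a packed fp16/bf16 QKV tensor on GPU plus int32 cumulative sequence lengths, per the docstring above. The shapes and values below are illustrative only:

    # Illustrative call of the varlen QKV-packed path; assumes a CUDA device
    # and flash-attn 2.x installed. Two sequences of lengths 5 and 3 are packed
    # into a single (total, 3, nheads, headdim) tensor.
    import torch
    from megatron.model.flash_attention import flash_attn_varlen_qkvpacked_func

    seqlens = [5, 3]
    total, nheads, headdim = sum(seqlens), 8, 64
    qkv = torch.randn(total, 3, nheads, headdim, dtype=torch.float16, device="cuda")
    cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")  # cumulative lengths

    out = flash_attn_varlen_qkvpacked_func(
        qkv, cu_seqlens, max(seqlens), dropout_p=0.0, causal=True
    )
    # out has shape (total, nheads, headdim)
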
+class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        q,
+        kv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        return_softmax,
+    ):
+        if softmax_scale is None:
+            softmax_scale = q.shape[-1] ** (-0.5)
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
+            q,
+            kv[:, 0],
+            kv[:, 1],
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            softmax_scale,
+            causal=causal,
+            return_softmax=return_softmax and dropout_p > 0,
+        )
+        ctx.save_for_backward(
+            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
+        )
+        ctx.dropout_p = dropout_p
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        return out if not return_softmax else (out, softmax_lse, S_dmask)
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+        dq = torch.empty_like(q)
+        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
+        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
+        _flash_attn_varlen_backward(
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            dq,
+            dkv[:, 0],
+            dkv[:, 1],
+            cu_seqlens_q,
+            cu_seqlens_k,
+            ctx.max_seqlen_q,
+            ctx.max_seqlen_k,
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+            rng_state=rng_state,
+        )
+        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
+        dkv = dkv[..., : dout.shape[-1]]
+        return dq, dkv, None, None, None, None, None, None, None, None
+
+
+def flash_attn_varlen_kvpacked_func(
+    q,
+    kv,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    return_attn_probs=False,
+):
+    """dropout_p should be set to 0.0 during evaluation
+    If K, V are already stacked into 1 tensor, this function will be faster than
+    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
+    of the gradients of K, V.
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
+    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    Arguments:
+        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
+        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into q.
+        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into kv.
+        max_seqlen_q: int. Maximum query sequence length in the batch.
+        max_seqlen_k: int. Maximum key sequence length in the batch.
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (total, nheads, headdim).
+        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
+            The output of softmax (possibly with different scaling). It also encodes the dropout
+            pattern (negative means that location was dropped, nonnegative means it was kept).
+    """
+    return FlashAttnVarlenKVPackedFunc.apply(
+        q,
+        kv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        return_attn_probs,
+    )
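
The KV-packed variant supports MQA/GQA by giving KV fewer heads than Q, as described in the docstring above. A hypothetical sketch (not part of the commit), with 8 query heads sharing 2 KV heads:

    # Illustrative GQA call of the varlen KV-packed path; assumes a CUDA device
    # and flash-attn 2.x installed. nheads (8) must be divisible by nheads_k (2).
    import torch
    from megatron.model.flash_attention import flash_attn_varlen_kvpacked_func

    q_lens = k_lens = [4, 6]
    total_q, total_k = sum(q_lens), sum(k_lens)
    nheads, nheads_k, headdim = 8, 2, 64

    q = torch.randn(total_q, nheads, headdim, dtype=torch.float16, device="cuda")
    kv = torch.randn(total_k, 2, nheads_k, headdim, dtype=torch.float16, device="cuda")
    cu_seqlens_q = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")
    cu_seqlens_k = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")

    out = flash_attn_varlen_kvpacked_func(
        q, kv, cu_seqlens_q, cu_seqlens_k, max(q_lens), max(k_lens),
        dropout_p=0.0, causal=True,
    )
    # out has shape (total_q, nheads, headdim)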

megatron/model/transformer.py

Lines changed: 8 additions & 5 deletions
@@ -345,14 +345,17 @@ def __init__(
         else:
             if self.use_flash_attention:
                 from megatron.model.flash_attention import (
-                    flash_attn_unpadded_qkvpacked_func_cuda,
-                    flash_attn_unpadded_kvpacked_func_cuda,
-                    flash_attn_unpadded_unpacked_func_triton,
+                    # flash_attn_unpadded_qkvpacked_func_cuda,
+                    # flash_attn_unpadded_kvpacked_func_cuda,
+                    # Change of function names going from flash attention 1 -> flash attention 2
+                    flash_attn_varlen_qkvpacked_func,
+                    flash_attn_varlen_kvpacked_func,
+                    flash_attn_unpadded_unpacked_func_triton
                 )
 
                 self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
-                self.flash_qkv_fn = flash_attn_unpadded_qkvpacked_func_cuda
-                self.flash_kv_fn = flash_attn_unpadded_kvpacked_func_cuda
+                self.flash_qkv_fn = flash_attn_varlen_qkvpacked_func
+                self.flash_kv_fn = flash_attn_varlen_kvpacked_func
             else:
                 self.scale_mask_softmax = FusedScaleMaskSoftmax(
                     input_in_fp16=self.fp16,
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-flash-attn==0.2.2
+flash-attn==2.2.1
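
The dependency pin moves from flash-attn 0.2.2 to 2.2.1. A hypothetical runtime sanity check (not part of this commit) for catching a stale 1.x install before training starts:

    # Hypothetical version guard, not part of this commit.
    import flash_attn
    from packaging import version

    assert version.parse(flash_attn.__version__) >= version.parse("2.0.0"), (
        f"flash-attn>=2.0 is required (flash_attn_2_cuda); found {flash_attn.__version__}"
    )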
