5 | 5 | import torch.nn.functional as F |
6 | 6 |
7 | 7 | from flash_attn import flash_attn_triton |
8 | | -import flash_attn_cuda |
| 8 | +import flash_attn_2_cuda as flash_attn_cuda # For flash_attn version 2.1.1 |
9 | 9 |
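# Editorial sketch, not part of the diff above: the CUDA extension module was renamed in
# flash-attn 2.x, so the aliased import only resolves against flash-attn >= 2.0 (pinned to
# 2.1.1 in the comment). If both major versions had to be supported, a guarded import along
# these lines would work; the fallback branch assumes a flash-attn 1.x install.
try:
    import flash_attn_2_cuda as flash_attn_cuda  # extension name in flash-attn 2.x
except ImportError:
    import flash_attn_cuda  # extension name in flash-attn 1.x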
10 | 10 |
11 | 11 | def flash_attn_unpadded_unpacked_func_triton( |
@@ -462,3 +462,325 @@ def flash_attn_unpadded_func_cuda( |
462 | 462 | causal, |
463 | 463 | return_attn_probs, |
464 | 464 | ) |
| 465 | + |
| 466 | + |
| 467 | +# For flash-attention 2 integration |
| 468 | +def _flash_attn_varlen_forward( |
| 469 | + q, |
| 470 | + k, |
| 471 | + v, |
| 472 | + cu_seqlens_q, |
| 473 | + cu_seqlens_k, |
| 474 | + max_seqlen_q, |
| 475 | + max_seqlen_k, |
| 476 | + dropout_p, |
| 477 | + softmax_scale, |
| 478 | + causal, |
| 479 | + return_softmax, |
| 480 | +): |
| 481 | + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x |
| 482 | + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] |
| 483 | + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd( |
| 484 | + q, |
| 485 | + k, |
| 486 | + v, |
| 487 | + None, |
| 488 | + cu_seqlens_q, |
| 489 | + cu_seqlens_k, |
| 490 | + max_seqlen_q, |
| 491 | + max_seqlen_k, |
| 492 | + dropout_p, |
| 493 | + softmax_scale, |
| 494 | + False, |
| 495 | + causal, |
| 496 | + return_softmax, |
| 497 | + None, |
| 498 | + ) |
| 499 | + # if out.isnan().any() or softmax_lse.isnan().any(): |
| 500 | + # breakpoint() |
| 501 | + return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state |
| 502 | + |
| 503 | + |
| 504 | +def _flash_attn_varlen_backward( |
| 505 | + dout, |
| 506 | + q, |
| 507 | + k, |
| 508 | + v, |
| 509 | + out, |
| 510 | + softmax_lse, |
| 511 | + dq, |
| 512 | + dk, |
| 513 | + dv, |
| 514 | + cu_seqlens_q, |
| 515 | + cu_seqlens_k, |
| 516 | + max_seqlen_q, |
| 517 | + max_seqlen_k, |
| 518 | + dropout_p, |
| 519 | + softmax_scale, |
| 520 | + causal, |
| 521 | + rng_state=None, |
| 522 | +): |
| 523 | + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x |
| 524 | + # dq, dk, dv are allocated by us so they should already be contiguous |
| 525 | + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] |
| 526 | + dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd( |
| 527 | + dout, |
| 528 | + q, |
| 529 | + k, |
| 530 | + v, |
| 531 | + out, |
| 532 | + softmax_lse, |
| 533 | + dq, |
| 534 | + dk, |
| 535 | + dv, |
| 536 | + cu_seqlens_q, |
| 537 | + cu_seqlens_k, |
| 538 | + max_seqlen_q, |
| 539 | + max_seqlen_k, |
| 540 | + dropout_p, |
| 541 | + softmax_scale, |
| 542 | + False, |
| 543 | + causal, |
| 544 | + None, |
| 545 | + rng_state, |
| 546 | + ) |
| 547 | + # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): |
| 548 | + # breakpoint() |
| 549 | + return dq, dk, dv, softmax_d |
| 550 | + |
| 551 | + |
| 552 | +class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): |
| 553 | + @staticmethod |
| 554 | + def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax): |
| 555 | + if softmax_scale is None: |
| 556 | + softmax_scale = qkv.shape[-1] ** (-0.5) |
| 557 | + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( |
| 558 | + qkv[:, 0], |
| 559 | + qkv[:, 1], |
| 560 | + qkv[:, 2], |
| 561 | + cu_seqlens, |
| 562 | + cu_seqlens, |
| 563 | + max_seqlen, |
| 564 | + max_seqlen, |
| 565 | + dropout_p, |
| 566 | + softmax_scale, |
| 567 | + causal=causal, |
| 568 | + return_softmax=return_softmax and dropout_p > 0, |
| 569 | + ) |
| 570 | + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) |
| 571 | + ctx.dropout_p = dropout_p |
| 572 | + ctx.max_seqlen = max_seqlen |
| 573 | + ctx.softmax_scale = softmax_scale |
| 574 | + ctx.causal = causal |
| 575 | + return out if not return_softmax else (out, softmax_lse, S_dmask) |
| 576 | + |
| 577 | + @staticmethod |
| 578 | + def backward(ctx, dout, *args): |
| 579 | + q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors |
| 580 | + qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) |
| 581 | + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) |
| 582 | + _flash_attn_varlen_backward( |
| 583 | + dout, |
| 584 | + q, |
| 585 | + k, |
| 586 | + v, |
| 587 | + out, |
| 588 | + softmax_lse, |
| 589 | + dqkv[:, 0], |
| 590 | + dqkv[:, 1], |
| 591 | + dqkv[:, 2], |
| 592 | + cu_seqlens, |
| 593 | + cu_seqlens, |
| 594 | + ctx.max_seqlen, |
| 595 | + ctx.max_seqlen, |
| 596 | + ctx.dropout_p, |
| 597 | + ctx.softmax_scale, |
| 598 | + ctx.causal, |
| 599 | + rng_state=rng_state, |
| 600 | + ) |
| 601 | + dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension |
| 602 | + return dqkv, None, None, None, None, None, None |
| 603 | + |
| 604 | + |
| 605 | +def flash_attn_varlen_qkvpacked_func( |
| 606 | + qkv, |
| 607 | + cu_seqlens, |
| 608 | + max_seqlen, |
| 609 | + dropout_p=0.0, |
| 610 | + softmax_scale=None, |
| 611 | + causal=False, |
| 612 | + return_attn_probs=False, |
| 613 | +): |
| 614 | + """dropout_p should be set to 0.0 during evaluation |
| 615 | + If Q, K, V are already stacked into 1 tensor, this function will be faster than |
| 616 | + calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation |
| 617 | + of the gradients of Q, K, V. |
| 618 | + For multi-query and grouped-query attention (MQA/GQA), please see |
| 619 | + flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. |
| 620 | + |
| 621 | + Arguments: |
| 622 | + qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. |
| 623 | + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths |
| 624 | + of the sequences in the batch, used to index into qkv. |
| 625 | + max_seqlen: int. Maximum sequence length in the batch. |
| 626 | + dropout_p: float. Dropout probability. |
| 627 | + softmax_scale: float. The scaling of QK^T before applying softmax. |
| 628 | + Defaults to 1 / sqrt(headdim). |
| 629 | + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). |
| 630 | + return_attn_probs: bool. Whether to return the attention probabilities. This option is for |
| 631 | + testing only. The returned probabilities are not guaranteed to be correct |
| 632 | + (they might not have the right scaling). |
| 633 | + Return: |
| 634 | + out: (total, nheads, headdim). |
| 635 | + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The |
| 636 | + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax |
| 637 | + normalization factor). |
| 638 | + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). |
| 639 | + The output of softmax (possibly with different scaling). It also encodes the dropout |
| 640 | + pattern (negative means that location was dropped, nonnegative means it was kept). |
| 641 | + """ |
| 642 | + return FlashAttnVarlenQKVPackedFunc.apply( |
| 643 | + qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs |
| 644 | + ) |
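# Editorial usage sketch, not part of the diff: calling flash_attn_varlen_qkvpacked_func as
# its docstring above describes. The sequence lengths, head count, and head dimension are
# made-up example values; cu_seqlens is the int32 prefix sum over the per-sequence lengths.
import torch

seqlens = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")  # lengths of 3 packed sequences
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)  # -> [0, 3, 8, 10]
total, nheads, headdim = int(seqlens.sum()), 8, 64

# Packed QKV of shape (total, 3, nheads, headdim); the CUDA kernels expect fp16/bf16 inputs.
qkv = torch.randn(total, 3, nheads, headdim, dtype=torch.float16, device="cuda")

out = flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    int(seqlens.max()),  # max_seqlen
    dropout_p=0.0,       # evaluation setting, per the docstring
    causal=True,
)
# out: (total, nheads, headdim), one row per packed token.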
| 645 | + |
| 646 | + |
| 647 | +class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): |
| 648 | + @staticmethod |
| 649 | + def forward( |
| 650 | + ctx, |
| 651 | + q, |
| 652 | + kv, |
| 653 | + cu_seqlens_q, |
| 654 | + cu_seqlens_k, |
| 655 | + max_seqlen_q, |
| 656 | + max_seqlen_k, |
| 657 | + dropout_p, |
| 658 | + softmax_scale, |
| 659 | + causal, |
| 660 | + return_softmax, |
| 661 | + ): |
| 662 | + if softmax_scale is None: |
| 663 | + softmax_scale = q.shape[-1] ** (-0.5) |
| 664 | + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( |
| 665 | + q, |
| 666 | + kv[:, 0], |
| 667 | + kv[:, 1], |
| 668 | + cu_seqlens_q, |
| 669 | + cu_seqlens_k, |
| 670 | + max_seqlen_q, |
| 671 | + max_seqlen_k, |
| 672 | + dropout_p, |
| 673 | + softmax_scale, |
| 674 | + causal=causal, |
| 675 | + return_softmax=return_softmax and dropout_p > 0, |
| 676 | + ) |
| 677 | + ctx.save_for_backward( |
| 678 | + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state |
| 679 | + ) |
| 680 | + ctx.dropout_p = dropout_p |
| 681 | + ctx.max_seqlen_q = max_seqlen_q |
| 682 | + ctx.max_seqlen_k = max_seqlen_k |
| 683 | + ctx.softmax_scale = softmax_scale |
| 684 | + ctx.causal = causal |
| 685 | + return out if not return_softmax else (out, softmax_lse, S_dmask) |
| 686 | + |
| 687 | + @staticmethod |
| 688 | + def backward(ctx, dout, *args): |
| 689 | + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors |
| 690 | + dq = torch.empty_like(q) |
| 691 | + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) |
| 692 | + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) |
| 693 | + _flash_attn_varlen_backward( |
| 694 | + dout, |
| 695 | + q, |
| 696 | + k, |
| 697 | + v, |
| 698 | + out, |
| 699 | + softmax_lse, |
| 700 | + dq, |
| 701 | + dkv[:, 0], |
| 702 | + dkv[:, 1], |
| 703 | + cu_seqlens_q, |
| 704 | + cu_seqlens_k, |
| 705 | + ctx.max_seqlen_q, |
| 706 | + ctx.max_seqlen_k, |
| 707 | + ctx.dropout_p, |
| 708 | + ctx.softmax_scale, |
| 709 | + ctx.causal, |
| 710 | + rng_state=rng_state, |
| 711 | + ) |
| 712 | + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension |
| 713 | + dkv = dkv[..., : dout.shape[-1]] |
| 714 | + return dq, dkv, None, None, None, None, None, None, None, None |
| 715 | + |
| 716 | + |
| 717 | +def flash_attn_varlen_kvpacked_func( |
| 718 | + q, |
| 719 | + kv, |
| 720 | + cu_seqlens_q, |
| 721 | + cu_seqlens_k, |
| 722 | + max_seqlen_q, |
| 723 | + max_seqlen_k, |
| 724 | + dropout_p=0.0, |
| 725 | + softmax_scale=None, |
| 726 | + causal=False, |
| 727 | + return_attn_probs=False, |
| 728 | +): |
| 729 | + """dropout_p should be set to 0.0 during evaluation |
| 730 | + If K, V are already stacked into 1 tensor, this function will be faster than |
| 731 | + calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation |
| 732 | + of the gradients of K, V. |
| 733 | + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads |
| 734 | + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. |
| 735 | + For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head |
| 736 | + 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V. |
| 737 | + |
| 738 | + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. |
| 739 | + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: |
| 740 | + 1 1 1 1 0 |
| 741 | + 1 1 1 1 1 |
| 742 | + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: |
| 743 | + 0 0 |
| 744 | + 0 0 |
| 745 | + 0 0 |
| 746 | + 1 0 |
| 747 | + 1 1 |
| 748 | + If a row of the mask is all zero, the corresponding row of the output will be zero. |
| 749 | + |
| 750 | + Arguments: |
| 751 | + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. |
| 752 | + kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. |
| 753 | + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths |
| 754 | + of the sequences in the batch, used to index into q. |
| 755 | + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths |
| 756 | + of the sequences in the batch, used to index into kv. |
| 757 | + max_seqlen_q: int. Maximum query sequence length in the batch. |
| 758 | + max_seqlen_k: int. Maximum key sequence length in the batch. |
| 759 | + dropout_p: float. Dropout probability. |
| 760 | + softmax_scale: float. The scaling of QK^T before applying softmax. |
| 761 | + Defaults to 1 / sqrt(headdim). |
| 762 | + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). |
| 763 | + return_attn_probs: bool. Whether to return the attention probabilities. This option is for |
| 764 | + testing only. The returned probabilities are not guaranteed to be correct |
| 765 | + (they might not have the right scaling). |
| 766 | + Return: |
| 767 | + out: (total_q, nheads, headdim). |
| 768 | + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The |
| 769 | + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax |
| 770 | + normalization factor). |
| 771 | + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). |
| 772 | + The output of softmax (possibly with different scaling). It also encodes the dropout |
| 773 | + pattern (negative means that location was dropped, nonnegative means it was kept). |
| 774 | + """ |
| 775 | + return FlashAttnVarlenKVPackedFunc.apply( |
| 776 | + q, |
| 777 | + kv, |
| 778 | + cu_seqlens_q, |
| 779 | + cu_seqlens_k, |
| 780 | + max_seqlen_q, |
| 781 | + max_seqlen_k, |
| 782 | + dropout_p, |
| 783 | + softmax_scale, |
| 784 | + causal, |
| 785 | + return_attn_probs, |
| 786 | + ) |
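# Editorial usage sketch, not part of the diff: flash_attn_varlen_kvpacked_func with
# grouped-query attention (fewer KV heads than Q heads) and unequal query/key lengths, so
# the bottom-right-aligned causal mask described in the docstring applies. All sizes are
# illustrative.
import torch

q_lens = torch.tensor([2, 4], dtype=torch.int32, device="cuda")
k_lens = torch.tensor([5, 5], dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(len(q_lens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.zeros(len(k_lens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
cu_seqlens_k[1:] = torch.cumsum(k_lens, dim=0)

nheads, nheads_k, headdim = 8, 2, 64  # 8 query heads share 2 KV heads (GQA); 8 % 2 == 0
q = torch.randn(int(q_lens.sum()), nheads, headdim, dtype=torch.float16, device="cuda")
kv = torch.randn(int(k_lens.sum()), 2, nheads_k, headdim, dtype=torch.float16, device="cuda")

out = flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=int(q_lens.max()),
    max_seqlen_k=int(k_lens.max()),
    dropout_p=0.0,
    causal=True,  # mask is aligned to the bottom-right corner, per the docstring
)
# out: (total_q, nheads, headdim)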