diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py index 148c3e3fc08a..88044ade69cf 100644 --- a/colossalai/kernel/kernel_loader.py +++ b/colossalai/kernel/kernel_loader.py @@ -6,7 +6,7 @@ CpuAdamX86Extension, FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, - FlashAttentionXformersCudaExtension, + FlashAttentionSdpaCudaExtension, FusedOptimizerCudaExtension, LayerNormCudaExtension, MoeCudaExtension, @@ -65,9 +65,9 @@ def load(self, ext_name: str = None): else: usable_exts = [] for ext in exts: - if ext.is_hardware_available(): + if ext.is_available(): # make sure the machine is compatible during kernel loading - ext.assert_hardware_compatible() + ext.assert_compatible() usable_exts.append(ext) assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine." @@ -106,4 +106,12 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader): class FlashAttentionLoader(KernelLoader): - REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension] + REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionSdpaCudaExtension] + + +class FlashAttentionWithPaddingMaskLoader(KernelLoader): + REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension] + + +class FlashAttentionWithCustomMaskLoader(KernelLoader): + REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension] diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 56e8b08c4e4a..c9b4317a6f17 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,3 +1,4 @@ +from .attn import AttnMaskType, ColoAttention from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row @@ -23,4 +24,6 @@ "FusedRMSNorm", "FusedLinear1D_Col", "ParallelModule", + "AttnMaskType", + "ColoAttention", ] diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py new file mode 100644 index 000000000000..ab2fb8bed8e8 --- /dev/null +++ b/colossalai/shardformer/layer/attn.py @@ -0,0 +1,234 @@ +from enum import Enum +from typing import Callable, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F + +from colossalai.kernel.kernel_loader import ( + FlashAttentionLoader, + FlashAttentionWithCustomMaskLoader, + FlashAttentionWithPaddingMaskLoader, +) + +__all__ = [ + "AttnMaskType", + "ColoAttention", +] + + +class AttnMaskType(Enum): + CUSTOM = 0 + PADDED = 1 + CAUSAL = 2 + PADDED_CAUSAL = 3 + + +def invert_mask(mask: torch.Tensor) -> torch.Tensor: + """Invert the mask tensor. + + Args: + mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv] + + Returns: + torch.Tensor: Inverted mask tensor. + """ + inverted_mask = 1.0 - mask + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min) + + +# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]: + """Get padding information from padding mask. + + Args: + padding_mask (torch.Tensor): Padding mask tensor. 
Shape should be [B, S]
+
+    Returns:
+        Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices)
+    """
+    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return max_seqlen_in_batch, cu_seqlens, indices
+
+
+class ColoAttention:
+    # these attrs are initialized in the first call of the attention() method
+    _flash_attn_func: Optional[Callable] = None
+    _flash_attn_with_custom_mask_func: Optional[Callable] = None
+    _flash_attn_with_padding_mask_func: Optional[Callable] = None
+
+    @staticmethod
+    def _init_flash_attn_func():
+        if ColoAttention._flash_attn_func is None:
+            ColoAttention._flash_attn_func = FlashAttentionLoader().load()
+        if ColoAttention._flash_attn_with_custom_mask_func is None:
+            ColoAttention._flash_attn_with_custom_mask_func = FlashAttentionWithCustomMaskLoader().load()
+        if ColoAttention._flash_attn_with_padding_mask_func is None:
+            ColoAttention._flash_attn_with_padding_mask_func = FlashAttentionWithPaddingMaskLoader().load()
+
+    @staticmethod
+    def prepare_attn_kwargs(
+        shape_4d: Tuple[int],
+        dtype: torch.dtype,
+        device: torch.device,
+        q_padding_mask: Optional[torch.Tensor] = None,
+        kv_padding_mask: Optional[torch.Tensor] = None,
+        is_causal: bool = False,
+    ) -> Dict[str, torch.Tensor]:
+        """Return a dictionary of keyword arguments for the attention function. It supports 4 mask types.
+            1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves.
+            2. padded mask: receives padding mask and is_causal=False, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
+            3. causal mask: no padding mask and is_causal=True, return {attention_mask, attention_mask_type}.
+            4. padded causal mask: receives padding mask and is_causal=True, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
+
+        Args:
+            shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv)
+            dtype (torch.dtype): Dtype of attention mask, generally should be ``hidden_states.dtype``
+            device (torch.device): Device of attention mask, generally should be ``hidden_states.device``
+            q_padding_mask (Optional[torch.Tensor], optional): Padding mask of query. It should be a long or int tensor.
+                The shape should be [B, Sq]. ``1`` means valid token, and ``0`` means padding token. Defaults to None.
+            kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of key and value. It should be a long or int tensor.
+                The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
+                If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
+            is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.
+
+        Returns:
+            Dict[str, torch.Tensor]: Dictionary of keyword arguments for the attention function.
+ """ + if q_padding_mask is None and not is_causal: + return {} + assert len(shape_4d) == 4 and shape_4d[1] == 1 + b, _, s_q, s_kv = shape_4d + outputs = {} + if q_padding_mask is not None: + if kv_padding_mask is None: + kv_padding_mask = q_padding_mask + assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (b, s_kv) + attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device) + max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask) + max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask) + outputs.update( + { + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_kv": cu_seqlens_kv, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_kv": max_seqlen_kv, + "q_indices": q_indices, + "kv_indices": kv_indices, + } + ) + if is_causal: + outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL + attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0) + else: + outputs["attention_mask_type"] = AttnMaskType.PADDED + else: + assert is_causal + outputs["attention_mask_type"] = AttnMaskType.CAUSAL + attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv) + attention_mask = invert_mask(attention_mask).unsqueeze(1) + outputs["attention_mask"] = attention_mask + return outputs + + @staticmethod + def attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + ) -> torch.Tensor: + """Flash Attention function. It supports 4 mask type. + 1. custom mask: recv attention_mask + 2. padded mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices + 3. causal mask: recv attention_mask, attention_mask_type + 4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices + + Args: + q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D] + k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D] + v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D] + attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None. + attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM. + cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths + of the sequences in the batch, used to index into q. + Shape should be [B+1]. Defaults to None. + cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + Shape should be [B+1]. Defaults to None. + max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None. + max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None. + indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened input sequence. + Shape should be [NUM_TOKENS]. Defaults to None. + dropout_p (float, optional): Dropout probability. Defaults to 0.0. 
+ scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None. + + Returns: + torch.Tensor: Output tensor. Shape should be [B, N, Sq, D] + """ + ColoAttention._init_flash_attn_func() + # known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan + # this case is usaul when padding mask is used and self attention is performed + # thus, we don't use sdpa when padding mask is used + # sanity check + if attention_mask is not None: + assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor." + if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL): + assert ( + cu_seqlens_q is None + and cu_seqlens_kv is None + and max_seqlen_q is None + and max_seqlen_kv is None + and q_indices is None + and kv_indices is None + ) + if attention_mask_type == AttnMaskType.CUSTOM: + assert not torch.all(attention_mask != 0, dim=-1).any() + elif attention_mask_type in (AttnMaskType.PADDED, AttnMaskType.PADDED_CAUSAL): + assert ( + cu_seqlens_q is not None + and cu_seqlens_kv is not None + and max_seqlen_q is not None + and max_seqlen_kv is not None + and q_indices is not None + and kv_indices is not None + ) + else: + # if attention_mask is None, attention_mask_type should be the default value + assert attention_mask_type == AttnMaskType.CUSTOM + # kernel dispatch + if attention_mask is not None and attention_mask_type == AttnMaskType.CUSTOM: + attn_func = ColoAttention._flash_attn_with_custom_mask_func + elif attention_mask_type in (AttnMaskType.PADDED, AttnMaskType.PADDED_CAUSAL): + attn_func = ColoAttention._flash_attn_with_padding_mask_func + else: + attn_func = ColoAttention._flash_attn_func + is_causal = attention_mask is not None and attention_mask_type in ( + AttnMaskType.CAUSAL, + AttnMaskType.PADDED_CAUSAL, + ) + return attn_func( + q, + k, + v, + dropout_p=dropout_p, + scale=scale, + attention_mask=attention_mask, + is_causal=is_causal, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + q_indices=q_indices, + kv_indices=kv_indices, + ) diff --git a/extensions/README.md b/extensions/README.md index 6f5feb55c2af..b9bde7742be9 100644 --- a/extensions/README.md +++ b/extensions/README.md @@ -101,13 +101,13 @@ class MyExtension(_Extension): self._support_jit = True self.priority = 10 - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: """ Return if the required hardware can be found. """ ... - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: """ Check if the hardware required by the kernel is compatible. 
""" diff --git a/extensions/__init__.py b/extensions/__init__.py index 9343cadda194..0dbadba81905 100644 --- a/extensions/__init__.py +++ b/extensions/__init__.py @@ -1,9 +1,5 @@ from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension -from .flash_attention import ( - FlashAttentionDaoCudaExtension, - FlashAttentionNpuExtension, - FlashAttentionXformersCudaExtension, -) +from .flash_attention import FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension from .layernorm import LayerNormCudaExtension from .moe import MoeCudaExtension from .optimizer import FusedOptimizerCudaExtension @@ -18,7 +14,7 @@ ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension, FlashAttentionDaoCudaExtension, - FlashAttentionXformersCudaExtension, + FlashAttentionSdpaCudaExtension, FlashAttentionNpuExtension, ] @@ -31,6 +27,6 @@ "ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension", "FlashAttentionDaoCudaExtension", - "FlashAttentionXformersCudaExtension", + "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension", ] diff --git a/extensions/base_extension.py b/extensions/base_extension.py index c815a7f2ac4a..0c79c0a9e9f5 100644 --- a/extensions/base_extension.py +++ b/extensions/base_extension.py @@ -58,13 +58,13 @@ def get_jit_extension_folder_path(): return cache_directory @abstractmethod - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: """ Check if the hardware required by the kernel is available. """ @abstractmethod - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: """ Check if the hardware required by the kernel is compatible. """ diff --git a/extensions/cpu_adam/cpu_adam_arm.py b/extensions/cpu_adam/cpu_adam_arm.py index 35bff3b55928..61c4f3ed0697 100644 --- a/extensions/cpu_adam/cpu_adam_arm.py +++ b/extensions/cpu_adam/cpu_adam_arm.py @@ -7,11 +7,11 @@ class CpuAdamArmExtension(_CppExtension): def __init__(self): super().__init__(name="cpu_adam_arm") - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: # only arm allowed return platform.machine() == "aarch64" - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: arch = platform.machine() assert ( arch == "aarch64" diff --git a/extensions/cpu_adam/cpu_adam_x86.py b/extensions/cpu_adam/cpu_adam_x86.py index a38194167b01..9bbc8d85126d 100644 --- a/extensions/cpu_adam/cpu_adam_x86.py +++ b/extensions/cpu_adam/cpu_adam_x86.py @@ -8,15 +8,15 @@ class CpuAdamX86Extension(_CudaExtension): def __init__(self): super().__init__(name="cpu_adam_x86") - def is_hardware_available(self) -> bool: - return platform.machine() == "x86_64" and super().is_hardware_available() + def is_available(self) -> bool: + return platform.machine() == "x86_64" and super().is_available() - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: arch = platform.machine() assert ( arch == "x86_64" ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}" - super().assert_hardware_compatible() + super().assert_compatible() # necessary 4 functions def sources_files(self): diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py index 842cd9713a99..f1e0095b29b6 100644 --- a/extensions/cuda_extension.py +++ b/extensions/cuda_extension.py @@ -22,7 +22,7 @@ def nvcc_flags(self) -> List[str]: This function should return a list of nvcc compilation flags for extensions. 
""" - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: # cuda extension can only be built if cuda is available try: import torch @@ -32,7 +32,7 @@ def is_hardware_available(self) -> bool: cuda_available = False return cuda_available - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: from torch.utils.cpp_extension import CUDA_HOME if not CUDA_HOME: diff --git a/extensions/flash_attention/__init__.py b/extensions/flash_attention/__init__.py index 18abb6191035..34c5bbfa6317 100644 --- a/extensions/flash_attention/__init__.py +++ b/extensions/flash_attention/__init__.py @@ -1,6 +1,6 @@ from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension from .flash_attention_npu import FlashAttentionNpuExtension -from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension +from .flash_attention_sdpa_cuda import FlashAttentionSdpaCudaExtension try: import flash_attention # noqa @@ -17,4 +17,4 @@ HAS_MEM_EFF_ATTN = False -__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"] +__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension"] diff --git a/extensions/flash_attention/flash_attention_dao_cuda.py b/extensions/flash_attention/flash_attention_dao_cuda.py index 1b7f8ac4736a..a2f2a52f1af4 100644 --- a/extensions/flash_attention/flash_attention_dao_cuda.py +++ b/extensions/flash_attention/flash_attention_dao_cuda.py @@ -5,17 +5,20 @@ class FlashAttentionDaoCudaExtension(_Extension): def __init__(self): super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10) - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: # cuda extension can only be built if cuda is available try: import torch + from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func # noqa + from flash_attn.bert_padding import index_first_axis, pad_input # noqa + cuda_available = torch.cuda.is_available() except: cuda_available = False return cuda_available - def assert_hardware_compatible(self) -> bool: + def assert_compatible(self) -> bool: pass def build_aot(self) -> None: @@ -29,65 +32,65 @@ def build_jit(self) -> None: ) def load(self): - try: - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func - except ImportError: - raise ModuleNotFoundError( - ( - "We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'" - ) - ) - from typing import Optional import torch + from einops import rearrange + from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func + from flash_attn.bert_padding import index_first_axis, pad_input + + def _unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor): + return index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices) def flash_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - seq_len_info_q: "SeqLenInfo", - seq_len_info_kv: "SeqLenInfo", - origin_attn_mask: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, ): - """ - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - # check if the input is in allowed dtypes - if padded: - if seq_len_info_kv == None: - seq_len_info_kv = seq_len_info_q - - attn_out = flash_attn_varlen_func( + # [B, N, S, D] -> [B, S, N, D] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + b, s_q = q.shape[:2] + if cu_seqlens_q is not None: + # padded / padded causal + # unpad input: [B, S, N, D] -> [T, N, D] + q = _unpad_input(q, q_indices) + kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices) + attn_output = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + ) + # pad output: [T, N, D] -> [B, S, N, D] + attn_output = pad_input(attn_output, q_indices, b, s_q) + else: + # causal / no attn mask + attn_output = flash_attn_func( q, k, v, - seq_len_info_q.cu_seqlens, - seq_len_info_kv.cu_seqlens, - seq_len_info_q.max_seqlen, - seq_len_info_kv.max_seqlen, - dropout_p, - scale, - causal, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, ) - else: - attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) - return attn_out + # [B, S, N, D] -> [B, N, S, D] + return attn_output.transpose(1, 2) return flash_attention diff --git a/extensions/flash_attention/flash_attention_npu.py b/extensions/flash_attention/flash_attention_npu.py index 58d0f9306e3d..0e01cefa1112 100644 --- a/extensions/flash_attention/flash_attention_npu.py +++ b/extensions/flash_attention/flash_attention_npu.py @@ -5,15 +5,15 @@ class FlashAttentionNpuExtension(_Extension): def __init__(self): super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False) - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: try: - import torch_npu # noqa + import torch_npu - return True + return hasattr(torch_npu, "npu_fusion_attention") except: return False - def assert_hardware_compatible(self) -> bool: + def assert_compatible(self) -> bool: pass def build_aot(self) -> None: @@ -27,47 +27,36 @@ def build_jit(self) -> None: ) def load(self): + from typing import Optional + import torch - from einops import rearrange + import torch_npu - def npu_sdpa_attention( + def flash_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - seq_len_info_q=None, - 
seq_len_info_kv=None, - origin_attn_mask: torch.Tensor = None, dropout_p: float = 0.0, - scale: float = 1.0, - causal=None, - padded=None, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, ): - """ - The scaled dot product attention. - - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - scale: float. The scaling of QK^T before applying softmax. - Default to 1. - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)] - output = torch.nn.functional.scaled_dot_product_attention( + num_heads = q.size(1) + return torch_npu.npu_fusion_attention( q, k, v, - attn_mask=origin_attn_mask, - dropout_p=dropout_p, - is_causal=origin_attn_mask is None, + num_heads, + "BNSD", + atten_mask=attention_mask.bool(), scale=scale, - ) - output = rearrange(output, "b h s d -> b s (h d)") - return output + keep_prob=1 - dropout_p, + )[0] - return npu_sdpa_attention + return flash_attention diff --git a/extensions/flash_attention/flash_attention_sdpa_cuda.py b/extensions/flash_attention/flash_attention_sdpa_cuda.py new file mode 100644 index 000000000000..d3323a6aae27 --- /dev/null +++ b/extensions/flash_attention/flash_attention_sdpa_cuda.py @@ -0,0 +1,56 @@ +from ..base_extension import _Extension + + +class FlashAttentionSdpaCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_sdpa_cuda", support_aot=False, support_jit=False) + + def is_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError("Flash attention SDPA does not require ahead-of-time compilation.") + + def build_jit(self) -> None: + raise NotImplementedError("Flash attention SDPA does not require just-in-time compilation.") + + def load(self): + from typing import Optional + + import torch + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + ): + return torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + dropout_p=dropout_p, + scale=scale, + ) + + return flash_attention diff --git a/extensions/flash_attention/flash_attention_xformers_cuda.py b/extensions/flash_attention/flash_attention_xformers_cuda.py deleted file mode 100644 index 27cd823de14b..000000000000 --- a/extensions/flash_attention/flash_attention_xformers_cuda.py +++ /dev/null @@ -1,94 +0,0 @@ -from ..base_extension import _Extension - - -class 
FlashAttentionXformersCudaExtension(_Extension): - def __init__(self): - super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False) - - def is_hardware_available(self) -> bool: - # cuda extension can only be built if cuda is available - try: - import torch - - cuda_available = torch.cuda.is_available() - except: - cuda_available = False - return cuda_available - - def assert_hardware_compatible(self) -> bool: - pass - - def build_aot(self) -> None: - raise NotImplementedError( - "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." - ) - - def build_jit(self) -> None: - raise NotImplementedError( - "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." - ) - - def load(self): - try: - from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention - from xformers.ops.fmha.attn_bias import ( - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMask, - LowerTriangularMaskWithTensorBias, - ) - except ImportError: - raise ModuleNotFoundError( - ( - "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." - ) - ) - from typing import Optional - - import torch - - allow_alibi = True - for op in MemoryEfficientAttentionCutlassOp: - allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - - def mem_eff_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: "SeqLenInfo", - seq_len_info_kv: "SeqLenInfo", - origin_attn_mask: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - attn_bias = None - if padded: # bert style - if not causal: - attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - else: - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - elif causal: # gpt style - attn_bias = LowerTriangularMask() - - if bias is not None: # alibi / relative position embedding - assert allow_alibi, "flash attention with bias is not supported in this system." - assert causal, "attention with bias is only supported for causal attention so far." 
- attn_bias = attn_bias.add_bias(bias) - - if padded: - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - - out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) - - # shape: (b*s, n, d) - if padded: - out = out.squeeze(0) - - return out - - return mem_eff_attention diff --git a/setup.py b/setup.py index ef89481e6b1e..c16709ad1c1c 100644 --- a/setup.py +++ b/setup.py @@ -80,8 +80,8 @@ def get_version() -> str: for ext_cls in ALL_EXTENSIONS: ext = ext_cls() - if ext.support_aot and ext.is_hardware_available(): - ext.assert_hardware_compatible() + if ext.support_aot and ext.is_available(): + ext.assert_compatible() op_names.append(ext.name) ext_modules.append(ext.build_aot()) diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index 3ec1700045e3..f9eab132f6f6 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -1,167 +1,147 @@ import math +from copy import copy -import pytest import torch -from einops import rearrange - -from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN +from torch.testing import assert_close + +from colossalai.kernel.kernel_loader import ( + FlashAttentionLoader, + FlashAttentionWithCustomMaskLoader, + FlashAttentionWithPaddingMaskLoader, +) +from colossalai.shardformer.layer import AttnMaskType, ColoAttention +from colossalai.shardformer.layer.attn import invert_mask from colossalai.testing import clear_cache_before_run, parameterize +from colossalai.utils import get_current_device, set_seed -if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - -DTYPE = [torch.float16, torch.bfloat16, torch.float32] - +DTYPE = [torch.float16, torch.bfloat16] +B, N, S, D = 2, 8, 256, 32 -def attention_ref(q, k, v, attn_mask=None, causal=False): - """ - attention output of the control group - """ - dtype_og = q.dtype - seqlen_q, seqlen_k = q.shape[1], k.shape[1] - d = q.shape[-1] - scale = 1.0 / math.sqrt(d) - scores = torch.einsum("bthd,bshd->bhts", q * scale, k) +TOL_MAP = { + torch.float16: {"atol": 5e-4, "rtol": 2e-3}, + torch.bfloat16: {}, +} - if attn_mask is not None: - scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf")) - if causal: - causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) - scores.masked_fill_(causal_mask, float("-inf")) - attention = torch.softmax(scores, dim=-1) - - output = torch.einsum("bhts,bshd->bthd", attention, v) - output = rearrange(output, "b s h d -> b s (h d)") - # Modify the data at the positions of the mask to 0 +def attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0): + head_dim = q.size(-1) + attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim) if attn_mask is not None: - output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0) - - return output.to(dtype=dtype_og) - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") -@clear_cache_before_run() -@parameterize("proj_shape", [(6, 8, 4, 16)]) -@parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_attention_gpt(proj_shape, dtype, dropout): - (B, S, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - 
v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - - mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)] - mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True) - - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal) - - assert list(y.shape) == [B, S, D] - - out_ref = attention_ref(q, k, v, mask, causal=True) - - # check gradients - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") -@clear_cache_before_run() -@parameterize("proj_shape", [(6, 8, 4, 16)]) -@parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_attention_bert(proj_shape, dtype, dropout): - (B, S, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + attn_weights = attn_weights + attn_mask + attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float).to(q.dtype) + attn_weights = torch.dropout(attn_weights, p=dropout_p, train=True) + attn_output = torch.matmul(attn_weights, v) + return attn_output + + +def gen_padded_kwargs(dtype: torch.dtype): + padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device()) + padding_mask[0, : S // 4] = 0 + return ( + ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask), + padding_mask, + ) + + +def gen_padded_causal_kwargs(dtype: torch.dtype): + padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device()) + padding_mask[0, S // 2 :] = 0 + return ( + ColoAttention.prepare_attn_kwargs( + (B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True + ), + padding_mask, + ) + + +def gen_causal_kwargs(dtype: torch.dtype): + return ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, get_current_device(), is_causal=True), None + + +def gen_custom_kwargs(dtype: torch.dtype): + attn_mask = torch.ones((B, S, S), dtype=dtype, device=get_current_device()) + attn_mask[0, : S // 2, S // 2 :] = 0 + attn_mask[0, S // 2 :, : S // 2] = 0 + attn_mask[1, :, S // 4 :] = 0 + attn_mask = invert_mask(attn_mask).unsqueeze(1) + assert not torch.all(attn_mask != 0, dim=-1).any() + return {"attention_mask": attn_mask}, None + + +def post_process_kwargs_for_raw_attn(attn_kwargs: dict): + if "attention_mask_type" in attn_kwargs: + attn_kwargs = copy(attn_kwargs) + mask_type = attn_kwargs.pop("attention_mask_type") + attn_kwargs["is_causal"] = mask_type in (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL) + return attn_kwargs + + +def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_mask=None): + tols = TOL_MAP[dtype] + q = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True) + k = 
torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True) + v = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True) + q_flash = q.clone().detach().requires_grad_(True) + k_flash = k.clone().detach().requires_grad_(True) + v_flash = v.clone().detach().requires_grad_(True) + attn_mask = attn_kwargs.get("attention_mask", None) + ref_output = attention_ref(q, k, v, attn_mask) + output = attn_func(q_flash, k_flash, v_flash, **attn_kwargs) + if padding_mask is not None: + # [B, Sq] -> [B, 1, Sq, 1] + padding_mask = padding_mask[:, None, :, None].logical_not() + ref_output = ref_output.masked_fill(padding_mask, 0) + output = output.masked_fill(padding_mask, 0) + assert_close(output, ref_output, **tols) + output.mean().backward() + ref_output.mean().backward() + assert_close(q.grad, q_flash.grad, **tols) + assert_close(k.grad, k_flash.grad, **tols) + assert_close(v.grad, v_flash.grad, **tols) - # attention mask of shape [B, S] with zero padding to max length S - mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda") - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding) - - assert list(y.shape) == [B, S, D] - - out_ref = attention_ref(q, k, v, mask, causal=False) - - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") -@clear_cache_before_run() -@parameterize("proj_shape", [(6, 8, 4, 16)]) -@parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_attention_no_mask(proj_shape, dtype, dropout): - (B, S, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v) - - assert list(y.shape) == [B, S, D] - - out_ref = attention_ref(q, k, v, None, causal=False) - - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize("proj_shape", [(6, 24, 8, 4, 16)]) @parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_cross_attention(proj_shape, dtype, dropout): - (B, S, T, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - 
k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v, attn_mask_type=AttnMaskType.causal) - - assert list(y.shape) == [B, T, D] - - out_ref = attention_ref(q, k, v, None, causal=True) - - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-18), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" +def test_flash_attn_func(dtype: torch.dtype): + torch.backends.cudnn.deterministic = True + set_seed(0) + # (func, name, need_postprocess) + avail_attn_funcs = [(ColoAttention.attention, "coloattn", False)] + avail_custom_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)] + avail_padding_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)] + for ext_cls in FlashAttentionLoader.REGISTRY: + ext = ext_cls() + if ext.is_available(): + ext.assert_compatible() + avail_attn_funcs.append((ext.load(), ext.name, True)) + for ext_cls in FlashAttentionWithCustomMaskLoader.REGISTRY: + ext = ext_cls() + if ext.is_available(): + ext.assert_compatible() + avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True)) + for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY: + ext = ext_cls() + if ext.is_available(): + ext.assert_compatible() + avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True)) + + test_sets = { + "none": (lambda dtype: ({}, None), avail_attn_funcs), + "padded": (gen_padded_kwargs, avail_padding_mask_attn_funcs), + "padded_causal": (gen_padded_causal_kwargs, avail_padding_mask_attn_funcs), + "causal": (gen_causal_kwargs, avail_attn_funcs), + "custom": (gen_custom_kwargs, avail_custom_mask_attn_funcs), + } + + for mask_type, (gen_kwargs_func, attn_funcs) in test_sets.items(): + attn_kwargs, padding_mask = gen_kwargs_func(dtype) + for attn_func, name, need_postprocess in attn_funcs: + print(f"{dtype}, {name}, {mask_type}") + if need_postprocess: + check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask) + else: + check_attn_func(dtype, attn_func, attn_kwargs, padding_mask) + + +if __name__ == "__main__": + test_flash_attn_func()
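
For reviewers: a minimal usage sketch of the new ColoAttention API added in colossalai/shardformer/layer/attn.py, mirroring how the updated test exercises it (prepare_attn_kwargs followed by the dispatching attention entry point). The tensor sizes and the CUDA device below are illustrative assumptions, not part of the patch.

import torch

from colossalai.shardformer.layer import ColoAttention

# Toy sizes for illustration only (assumed, not taken from the patch).
B, N, S, D = 2, 8, 256, 32
dtype, device = torch.float16, torch.device("cuda")

q = torch.rand(B, N, S, D, dtype=dtype, device=device)
k = torch.rand(B, N, S, D, dtype=dtype, device=device)
v = torch.rand(B, N, S, D, dtype=dtype, device=device)

# Padding mask: 1 = valid token, 0 = padding; here the first half of sample 0 is padded.
padding_mask = torch.ones(B, S, dtype=torch.int, device=device)
padding_mask[0, : S // 2] = 0

# Build the mask-related kwargs once (padded causal attention in this example),
# then pass them straight to ColoAttention.attention, which picks the right kernel.
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (B, 1, S, S), dtype, device, q_padding_mask=padding_mask, is_causal=True
)
out = ColoAttention.attention(q, k, v, **attn_kwargs)  # [B, N, S, D]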