Add support for BLOOM #331
```diff
@@ -1,9 +1,11 @@
 """Multi-head attention."""
-from typing import Optional
+from typing import List, Optional

 import torch
 import torch.nn as nn
 from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
+                                         LowerTriangularMaskWithTensorBias)

 from vllm import attention_ops
 from vllm import cache_ops
```
```diff
@@ -53,13 +55,21 @@ def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        prompt_lens = input_metadata.prompt_lens
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
+        input_metadata.attn_bias.append(attn_bias)
+
     def multi_query_kv_attention(
         self,
         output: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_bias: xops.AttentionBias,
+        input_metadata: InputMetadata,
     ) -> torch.Tensor:
         """Normal attention for the prompt tokens.

```

Review discussion on the `attn_bias` check in `set_attn_bias`:

**Reviewer:** Why do you choose this design, instead of explicitly initializing [...]?

**Author:** Good question. It's because [...]. I kinda agree that this design is not ideal, but couldn't find a better way to do so.
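The hunk above builds one block-diagonal causal bias for all prompts in the batch and caches it on `input_metadata`, so the first attention layer constructs it and every later layer reuses it. The snippet below is a hand-rolled, plain-PyTorch illustration of what such a bias encodes (it does not use the xformers class); the prompt lengths `[2, 3]` are made up for the example.

```python
import torch

# Hypothetical prompt lengths, mirroring input_metadata.prompt_lens.
prompt_lens = [2, 3]
total = sum(prompt_lens)

# Dense equivalent of a block-diagonal causal mask:
# token i may attend to token j only if both belong to the same prompt
# and j comes at or before i.
mask = torch.full((total, total), float("-inf"))
start = 0
for plen in prompt_lens:
    end = start + plen
    causal = torch.tril(torch.ones(plen, plen, dtype=torch.bool))
    mask[start:end, start:end][causal] = 0.0
    start = end

print(mask)
# Rows 0-1 form one causal block, rows 2-4 another; all cross-prompt
# entries stay at -inf, so prompts never attend to each other.
```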
```diff
@@ -68,13 +78,14 @@ def multi_query_kv_attention(
             query: shape = [num_prompt_tokens, num_heads, head_size]
             key: shape = [num_prompt_tokens, num_heads, head_size]
             value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
         """
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
             key.unsqueeze(0),
             value.unsqueeze(0),
-            attn_bias=attn_bias,
+            attn_bias=input_metadata.attn_bias[0],
             p=0.0,
             scale=self.scale,
             op=self.attn_op,
```
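For the prompt phase, the batched call mirrors what the diff does: the prompts are concatenated along the token dimension, unsqueezed into a batch of one, and the block-diagonal bias tells the kernel where each prompt begins and ends. A minimal sketch, assuming a CUDA device, half precision, and an xformers build with a kernel for this head size; the tensor names and sizes are illustrative, not taken from vLLM.

```python
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

num_heads, head_size = 16, 64
prompt_lens = [2, 3]                      # two prompts, 5 tokens total
total_tokens = sum(prompt_lens)

q = torch.randn(total_tokens, num_heads, head_size,
                dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
out = xops.memory_efficient_attention_forward(
    q.unsqueeze(0),                       # [1, total_tokens, num_heads, head_size]
    k.unsqueeze(0),
    v.unsqueeze(0),
    attn_bias=attn_bias,
    p=0.0,
    scale=head_size ** -0.5,
)
print(out.squeeze(0).shape)               # [total_tokens, num_heads, head_size]
```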
```diff
@@ -112,6 +123,7 @@ def single_query_cached_kv_attention(
             input_metadata.context_lens,
             block_size,
             input_metadata.max_context_len,
+            None,  # alibi_slopes
         )

     def forward(
```
```diff
@@ -154,12 +166,13 @@ def forward(
         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
         if num_prompt_tokens > 0:
+            self.set_attn_bias(input_metadata)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.attn_bias,
+                input_metadata,
             )

         # Wait until the cache op is done.
```
```diff
@@ -219,7 +232,8 @@ def __init__(
         cache = torch.cat((cos, sin), dim=-1)

         # FIXME(woosuk): This assumes that we configure the default dtype when
-        # initializing the model. Make it more robust.
+        # initializing the model.
+        # TODO(woosuk): Make it more robust.
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
```
```diff
@@ -271,3 +285,112 @@ def forward(
             input_metadata,
             cache_event,
         )
+
+
+class PagedAttentionWithALiBi(PagedAttention):
+    """PagedAttention with ALiBi attention bias."""
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        slopes: List[float],
+    ) -> None:
+        super().__init__(num_heads, head_size, scale)
+        assert len(slopes) == num_heads
+
+        slopes = torch.tensor(slopes, dtype=torch.float32)
+        self.register_buffer("alibi_slopes", slopes, persistent=False)
```
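The constructor takes the per-head slopes as a plain `List[float]` and stores them in a non-persistent buffer so they follow the module across devices. For reference, the slope schedule from the ALiBi paper (which BLOOM also follows) is a geometric sequence starting at 2^(-8/n) for n heads. The helper below is a sketch of that schedule, not code from this PR; the BLOOM model file in this PR may compute the slopes elsewhere.

```python
import math

def alibi_slopes(num_heads: int) -> list:
    """ALiBi slope schedule: 2^(-8/n), 2^(-16/n), ... for power-of-two n."""
    def power_of_2_slopes(n: int) -> list:
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))   # == 2 ** (-8 / n)
        return [start ** (i + 1) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return power_of_2_slopes(num_heads)
    # For non-power-of-two head counts, interpolate from the nearest power of two.
    closest = 2 ** math.floor(math.log2(num_heads))
    return (power_of_2_slopes(closest) +
            power_of_2_slopes(2 * closest)[0::2][:num_heads - closest])

print(alibi_slopes(8))   # [0.5, 0.25, 0.125, ..., 0.00390625]
```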
```diff
+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        # Generates ALiBi mask for each prompt.
+        for prompt_len in input_metadata.prompt_lens:
+            bias = torch.arange(prompt_len)
+            bias = bias[None, :] - bias[:, None]
+            bias = bias.to(self.alibi_slopes.device)
+
+            # When using custom attention bias, xformers requires the bias to
+            # be sliced from a tensor whose length is a multiple of 8.
+            padded_len = (prompt_len + 7) // 8 * 8
+            bias = torch.empty(
+                self.num_heads,
+                padded_len,
+                padded_len,
+                device=self.alibi_slopes.device,
+            )[:, :prompt_len, :prompt_len].copy_(bias)
+            bias.mul_(self.alibi_slopes[:, None, None])
+            attn_bias = LowerTriangularMaskWithTensorBias(bias)
+            input_metadata.attn_bias.append(attn_bias)
```
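To make the bias construction concrete, here is the same arithmetic as the hunk above on a toy prompt, in plain PyTorch on CPU: the relative-position matrix `j - i`, scaled per head by its slope. Only the lower triangle ends up mattering, because `LowerTriangularMaskWithTensorBias` layers a causal mask on top; the padding of the backing tensor to a multiple of 8 is an xformers alignment requirement and does not change the values that are read. The slope values are illustrative.

```python
import torch

prompt_len = 4
slopes = torch.tensor([0.5, 0.25])         # two heads, illustrative slopes

pos = torch.arange(prompt_len)
rel = pos[None, :] - pos[:, None]          # rel[i, j] = j - i
print(rel)
# tensor([[ 0,  1,  2,  3],
#         [-1,  0,  1,  2],
#         [-2, -1,  0,  1],
#         [-3, -2, -1,  0]])

bias = rel.float() * slopes[:, None, None]  # [num_heads, prompt_len, prompt_len]
# Head 0: a key d positions behind the query gets bias -0.5 * d,
# while the current token gets 0. Entries above the diagonal are
# discarded by the causal mask.
print(bias[0])
```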
```diff
+    def multi_query_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """Attention with ALiBi bias for the prompt tokens.
+
+        Args:
+            output: shape = [num_prompt_tokens, num_heads, head_size]
+            query: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_heads, head_size]
+            value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
+        """
+        # FIXME(woosuk): Because xformers does not support dynamic sequence
+        # lengths with custom attention bias, we process each prompt one by
+        # one. This is inefficient, especially when we have many short prompts.
+        start = 0
+        for i, prompt_len in enumerate(input_metadata.prompt_lens):
+            end = start + prompt_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=input_metadata.attn_bias[i],
+                p=0.0,
+                scale=self.scale,
+                op=self.attn_op,
+            )
+            # TODO(woosuk): Unnecessary copy. Optimize.
+            output[start:end].copy_(out.squeeze(0))
+            start += prompt_len
+        return output
```
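Each iteration of the loop above is ordinary causal attention with an additive ALiBi bias applied to a single prompt. As a sanity reference, and only as a plain-PyTorch sketch that is not part of the PR, this is what one iteration computes for one prompt, without the paging machinery or the memory-efficient kernel:

```python
import torch

def naive_alibi_attention(q, k, v, slopes, scale):
    """q, k, v: [prompt_len, num_heads, head_size]; slopes: [num_heads]."""
    prompt_len = q.shape[0]
    pos = torch.arange(prompt_len)
    rel = (pos[None, :] - pos[:, None]).float()             # j - i
    alibi = slopes[:, None, None] * rel                     # [H, L, L]
    causal = torch.triu(torch.ones(prompt_len, prompt_len, dtype=torch.bool), 1)

    # Attention logits: scaled dot product plus ALiBi bias, causal-masked.
    logits = torch.einsum("qhd,khd->hqk", q, k) * scale + alibi
    logits = logits.masked_fill(causal, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v)           # [L, H, D]

q = torch.randn(5, 2, 8)
out = naive_alibi_attention(q, q, q, torch.tensor([0.5, 0.25]), 8 ** -0.5)
print(out.shape)  # torch.Size([5, 2, 8])
```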
```diff
+    def single_query_cached_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> None:
+        """PagedAttention with ALiBi bias for the generation tokens.
+
+        Args:
+            output: shape = [num_generation_tokens, num_heads, head_size]
+            query: shape = [num_generation_tokens, num_heads, head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+        """
+        block_size = value_cache.shape[3]
+        attention_ops.single_query_cached_kv_attention(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            self.scale,
+            input_metadata.block_tables,
+            input_metadata.context_lens,
+            block_size,
+            input_metadata.max_context_len,
+            self.alibi_slopes,
+        )
```
Review discussion:

**Reviewer:** This condition seems unnecessary. If `alibi_slope == 0`, then `alibi_slope * (token_idx - context_len)` will be 0 as well.

**Author:** Yes. It's to avoid the redundant computation of `0 * (token_idx - context_len)`.
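For context on the exchange above: during generation the query sits at the end of its context, so the per-key ALiBi term reduces to the one-dimensional expression discussed, `alibi_slope * (token_idx - context_len)` (possibly offset by a constant, which softmax ignores because adding a constant to every logit in a row leaves the probabilities unchanged). A small illustrative sketch, not code from the PR or the CUDA kernel:

```python
import torch

context_len = 6
slope = 0.25
token_idx = torch.arange(context_len)

# Per-key ALiBi bias for the single query token at the end of the context,
# using the expression quoted in the review.
bias = slope * (token_idx - context_len)
print(bias)   # tensor([-1.5000, -1.2500, -1.0000, -0.7500, -0.5000, -0.2500])

# Shifting every logit in a row by a constant does not change softmax, so this
# is equivalent to a variant that anchors the current token's bias at 0.
print(slope * (token_idx - context_len + 1))

# If slope == 0, the bias is identically zero, so skipping the computation
# (as the author explains) is purely an optimization.
```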