11"""Multi-head attention."""
2- from typing import Optional
2+ from typing import List , Optional
33
44import torch
55import torch .nn as nn
66from xformers import ops as xops
7+ from xformers .ops .fmha .attn_bias import (BlockDiagonalCausalMask ,
8+ LowerTriangularMaskWithTensorBias )
79
810from vllm import attention_ops
911from vllm import cache_ops
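The two new imports supply the attention biases used below: BlockDiagonalCausalMask packs several variable-length prompts into a single causal attention call, and LowerTriangularMaskWithTensorBias attaches a per-head tensor bias (used for ALiBi further down). A minimal sketch of the first, assuming an xformers build that exposes this API and a CUDA device; the shapes and values are illustrative only, not part of this change:

    # Pack two prompts (lengths 3 and 5) into one batch of 8 tokens; the
    # block-diagonal causal mask keeps each prompt from attending to the other.
    import torch
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    num_heads, head_size = 8, 64
    q = torch.randn(1, 3 + 5, num_heads, head_size,
                    device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    attn_bias = BlockDiagonalCausalMask.from_seqlens([3, 5])
    out = xops.memory_efficient_attention_forward(
        q, k, v, attn_bias=attn_bias, p=0.0)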
@@ -53,13 +55,21 @@ def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        prompt_lens = input_metadata.prompt_lens
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
+        input_metadata.attn_bias.append(attn_bias)
+
     def multi_query_kv_attention(
         self,
         output: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_bias: xops.AttentionBias,
+        input_metadata: InputMetadata,
     ) -> torch.Tensor:
         """Normal attention for the prompt tokens.
 
@@ -68,13 +78,14 @@ def multi_query_kv_attention(
             query: shape = [num_prompt_tokens, num_heads, head_size]
             key: shape = [num_prompt_tokens, num_heads, head_size]
             value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
         """
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
             key.unsqueeze(0),
             value.unsqueeze(0),
-            attn_bias=attn_bias,
+            attn_bias=input_metadata.attn_bias[0],
             p=0.0,
             scale=self.scale,
             op=self.attn_op,
@@ -112,6 +123,7 @@ def single_query_cached_kv_attention(
             input_metadata.context_lens,
             block_size,
             input_metadata.max_context_len,
+            None,  # alibi_slopes
         )
 
     def forward(
@@ -154,12 +166,13 @@ def forward(
         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
         if num_prompt_tokens > 0:
+            self.set_attn_bias(input_metadata)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.attn_bias,
+                input_metadata,
             )
 
         # Wait until the cache op is done.
@@ -219,7 +232,8 @@ def __init__(
         cache = torch.cat((cos, sin), dim=-1)
 
         # FIXME(woosuk): This assumes that we configure the default dtype when
-        # initializing the model. Make it more robust.
+        # initializing the model.
+        # TODO(woosuk): Make it more robust.
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
@@ -271,3 +285,112 @@ def forward(
             input_metadata,
             cache_event,
         )
+
+
+class PagedAttentionWithALiBi(PagedAttention):
+    """PagedAttention with ALiBi attention bias."""
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        slopes: List[float],
+    ) -> None:
+        super().__init__(num_heads, head_size, scale)
+        assert len(slopes) == num_heads
+
+        slopes = torch.tensor(slopes, dtype=torch.float32)
+        self.register_buffer("alibi_slopes", slopes, persistent=False)
+
+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        # Generates ALiBi mask for each prompt.
+        for prompt_len in input_metadata.prompt_lens:
+            bias = torch.arange(prompt_len)
+            bias = bias[None, :] - bias[:, None]
+            bias = bias.to(self.alibi_slopes.device)
+
+            # When using custom attention bias, xformers requires the bias to
+            # be sliced from a tensor whose length is a multiple of 8.
+            padded_len = (prompt_len + 7) // 8 * 8
+            bias = torch.empty(
+                self.num_heads,
+                padded_len,
+                padded_len,
+                device=self.alibi_slopes.device,
+            )[:, :prompt_len, :prompt_len].copy_(bias)
+            bias.mul_(self.alibi_slopes[:, None, None])
+            attn_bias = LowerTriangularMaskWithTensorBias(bias)
+            input_metadata.attn_bias.append(attn_bias)
+
+    def multi_query_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """Attention with ALiBi bias for the prompt tokens.
+
+        Args:
+            output: shape = [num_prompt_tokens, num_heads, head_size]
+            query: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_heads, head_size]
+            value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
+        """
+        # FIXME(woosuk): Because xformers does not support dynamic sequence
+        # lengths with custom attention bias, we process each prompt one by
+        # one. This is inefficient, especially when we have many short prompts.
+        start = 0
+        for i, prompt_len in enumerate(input_metadata.prompt_lens):
+            end = start + prompt_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=input_metadata.attn_bias[i],
+                p=0.0,
+                scale=self.scale,
+                op=self.attn_op,
+            )
+            # TODO(woosuk): Unnecessary copy. Optimize.
+            output[start:end].copy_(out.squeeze(0))
+            start += prompt_len
+        return output
+
+    def single_query_cached_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> None:
+        """PagedAttention with ALiBi bias for the generation tokens.
+
+        Args:
+            output: shape = [num_generation_tokens, num_heads, head_size]
+            query: shape = [num_generation_tokens, num_heads, head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+        """
+        block_size = value_cache.shape[3]
+        attention_ops.single_query_cached_kv_attention(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            self.scale,
+            input_metadata.block_tables,
+            input_metadata.context_lens,
+            block_size,
+            input_metadata.max_context_len,
+            self.alibi_slopes,
+        )