1 change: 1 addition & 0 deletions README.md
@@ -41,6 +41,7 @@ vLLM is flexible and easy to use with:

vLLM seamlessly supports many Huggingface models, including the following architectures:

- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
4 changes: 3 additions & 1 deletion csrc/attention.cpp
@@ -1,4 +1,5 @@
#include <torch/extension.h>
#include <c10/util/Optional.h>

void single_query_cached_kv_attention(
torch::Tensor& out,
@@ -9,7 +10,8 @@ void single_query_cached_kv_attention(
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len);
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
25 changes: 19 additions & 6 deletions csrc/attention/attention_kernels.cu
@@ -80,6 +80,7 @@ __global__ void single_query_cached_kv_attention_kernel(
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride) {
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
@@ -91,6 +92,7 @@ __global__ void single_query_cached_kv_attention_kernel(
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int seq_idx = blockIdx.y;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
@@ -167,12 +169,14 @@ __global__ void single_query_cached_kv_attention_kernel(

// Compute dot product.
// This includes a reduction across the threads in the same thread group.
const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
const bool mask = token_idx >= context_len;

float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
Member: This condition seems unnecessary. If `alibi_slope == 0`, then `alibi_slope * (token_idx - context_len)` will be 0 as well.

Collaborator (Author): Yes. It's to avoid the redundant computation of `0 * (token_idx - context_len)`.


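A quick sanity check on the bias used here (not part of the diff): the decode query sits at position `context_len - 1`, so the textbook ALiBi bias for key `token_idx` would be `alibi_slope * (token_idx - (context_len - 1))`. The kernel instead adds `alibi_slope * (token_idx - context_len)`, which differs only by the per-row constant `-alibi_slope`, and softmax is invariant to a constant shift of the logits. A minimal PyTorch sketch with made-up numbers:

```python
import torch

slope, context_len = 0.0625, 5
scores = torch.randn(context_len)              # hypothetical q.k logits
token_idx = torch.arange(context_len)

kernel_bias = slope * (token_idx - context_len)            # as in the kernel
textbook_bias = slope * (token_idx - (context_len - 1))    # standard ALiBi

# The constant per-row offset cancels under softmax.
assert torch.allclose(torch.softmax(scores + kernel_bias, dim=-1),
                      torch.softmax(scores + textbook_bias, dim=-1))
```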
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= context_len;
logits[token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -328,6 +332,7 @@ __global__ void single_query_cached_kv_attention_kernel(
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
query_stride);

// TODO(woosuk): Tune NUM_THREADS.
@@ -343,7 +348,8 @@ void single_query_cached_kv_attention_launcher(
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len) {
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -353,6 +359,11 @@
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);

// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;

T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
@@ -411,7 +422,8 @@ void single_query_cached_kv_attention_launcher(
scale, \
block_tables, \
context_lens, \
max_context_len);
max_context_len, \
alibi_slopes);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -458,7 +470,8 @@ void single_query_cached_kv_attention(
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
int block_size,
int max_context_len) {
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
if (query.dtype() == at::ScalarType::Float) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
} else if (query.dtype() == at::ScalarType::Half) {
3 changes: 3 additions & 0 deletions docs/source/models/supported_models.rst
@@ -14,6 +14,9 @@ Alongside each architecture, we include some popular models that use it.
* - Architecture
- Models
- Example HuggingFace Models
* - :code:`BloomForCausalLM`
- BLOOM, BLOOMZ, BLOOMChat
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
* - :code:`GPT2LMHeadModel`
- GPT-2
- :code:`gpt2`, :code:`gpt2-xl`, etc.
6 changes: 4 additions & 2 deletions vllm/model_executor/input_metadata.py
@@ -1,7 +1,7 @@
from typing import Dict, List, Tuple

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from xformers.ops import AttentionBias

from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData
@@ -38,7 +38,6 @@ def __init__(
self.max_context_len = max_context_len
self.block_tables = block_tables

self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.num_generation_tokens = context_lens.shape[0]
@@ -50,6 +49,9 @@ def __init__(
assert block_tables.shape[0] == self.num_generation_tokens
assert context_lens.shape[0] == self.num_generation_tokens

# Set during the execution of the first attention op.
self.attn_bias: List[AttentionBias] = []

def __repr__(self) -> str:
# Print only useful metadata.
return (f'InputMetadata('
133 changes: 128 additions & 5 deletions vllm/model_executor/layers/attention.py
@@ -1,9 +1,11 @@
"""Multi-head attention."""
from typing import Optional
from typing import List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)

from vllm import attention_ops
from vllm import cache_ops
@@ -53,13 +55,21 @@ def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

def set_attn_bias(self, input_metadata: InputMetadata) -> None:
if input_metadata.attn_bias:
# Already set by a previous layer.
Member: Why did you choose this design, instead of explicitly initializing `attn_bias` in advance, say at the beginning of the `forward` function of BLOOM?

Collaborator (Author): Good question. It's because `alibi_slopes` is stored in the attention layer. If we created the attention bias in `BloomForCausalLM`, we would have to store `alibi_slopes` in both places, because `alibi_slopes` is also used for the decoding attention. I agree that this design is not ideal, but I couldn't find a better way to do it.
return
prompt_lens = input_metadata.prompt_lens
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
input_metadata.attn_bias.append(attn_bias)

def multi_query_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: xops.AttentionBias,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Normal attention for the prompt tokens.

@@ -68,13 +78,14 @@ def multi_query_kv_attention(
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_heads, head_size]
value: shape = [num_prompt_tokens, num_heads, head_size]
input_metadata: metadata for paged attention.
"""
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
attn_bias=input_metadata.attn_bias[0],
p=0.0,
scale=self.scale,
op=self.attn_op,
@@ -112,6 +123,7 @@ def single_query_cached_kv_attention(
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
None, # alibi_slopes
)

def forward(
@@ -154,12 +166,13 @@ def forward(
# Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens
if num_prompt_tokens > 0:
self.set_attn_bias(input_metadata)
self.multi_query_kv_attention(
output[:num_prompt_tokens],
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata.attn_bias,
input_metadata,
)

# Wait until the cache op is done.
@@ -219,7 +232,8 @@ def __init__(
cache = torch.cat((cos, sin), dim=-1)

# FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model. Make it more robust.
# initializing the model.
# TODO(woosuk): Make it more robust.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim]
@@ -271,3 +285,112 @@ def forward(
input_metadata,
cache_event,
)


class PagedAttentionWithALiBi(PagedAttention):
"""PagedAttention with ALiBi attention bias."""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
slopes: List[float],
) -> None:
super().__init__(num_heads, head_size, scale)
assert len(slopes) == num_heads

slopes = torch.tensor(slopes, dtype=torch.float32)
self.register_buffer("alibi_slopes", slopes, persistent=False)
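For reference, the `slopes` passed in here follow the geometric schedule from the ALiBi paper, which BLOOM uses: for `n` heads (with `n` a power of two) the slopes are `2^(-8/n), 2^(-16/n), ..., 2^(-8)`, and non-power-of-two head counts interleave the schedule of the next power of two. A sketch of that computation (the helper name is illustrative and not part of this diff):

```python
import math

def get_alibi_slopes(num_heads: int) -> list:
    """Per-head ALiBi slopes, matching the BLOOM/ALiBi-paper schedule."""
    n = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(n) - 3)))   # == 2 ** (-8 / n)
    slopes = [base ** (i + 1) for i in range(n)]
    if n != num_heads:
        # Extra heads take every other slope from the next power-of-two schedule.
        extra_base = 2 ** (-(2 ** -(math.log2(2 * n) - 3)))
        slopes += [extra_base ** i for i in range(1, 2 * (num_heads - n), 2)]
    return slopes

# A BLOOM-style model would then construct the layer roughly as:
#   PagedAttentionWithALiBi(num_heads, head_size, scale,
#                           get_alibi_slopes(num_heads))
```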

def set_attn_bias(self, input_metadata: InputMetadata) -> None:
if input_metadata.attn_bias:
# Already set by a previous layer.
return
# Generates ALiBi mask for each prompt.
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len)
bias = bias[None, :] - bias[:, None]
bias = bias.to(self.alibi_slopes.device)

# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (prompt_len + 7) // 8 * 8
bias = torch.empty(
self.num_heads,
padded_len,
padded_len,
device=self.alibi_slopes.device,
)[:, :prompt_len, :prompt_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None])
attn_bias = LowerTriangularMaskWithTensorBias(bias)
input_metadata.attn_bias.append(attn_bias)
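For intuition, `bias[None, :] - bias[:, None]` above builds the relative-position matrix with entry `[i, j] = j - i`; it is then scaled by each head's slope, and `LowerTriangularMaskWithTensorBias` handles masking the non-causal (upper-triangular) part. A tiny illustration (not part of the diff) for a prompt of length 4:

```python
import torch

pos = torch.arange(4)
print(pos[None, :] - pos[:, None])
# tensor([[ 0,  1,  2,  3],
#         [-1,  0,  1,  2],
#         [-2, -1,  0,  1],
#         [-3, -2, -1,  0]])
```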

def multi_query_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Attention with ALiBi bias for the prompt tokens.

Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_heads, head_size]
value: shape = [num_prompt_tokens, num_heads, head_size]
input_metadata: metadata for paged attention.
"""
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
start = 0
for i, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=input_metadata.attn_bias[i],
p=0.0,
scale=self.scale,
op=self.attn_op,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0))
start += prompt_len
return output

def single_query_cached_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
) -> None:
"""PagedAttention with ALiBi bias for the generation tokens.

Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
self.alibi_slopes,
)
5 changes: 2 additions & 3 deletions vllm/model_executor/model_loader.py
@@ -6,13 +6,12 @@
from transformers import PretrainedConfig

from vllm.config import ModelConfig
from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
GPTNeoXForCausalLM, LlamaForCausalLM,
OPTForCausalLM)
from vllm.model_executor.models import * # pylint: disable=wildcard-import
from vllm.model_executor.weight_utils import initialize_dummy_weights

# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
"BloomForCausalLM": BloomForCausalLM,
"GPT2LMHeadModel": GPT2LMHeadModel,
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
@@ -1,10 +1,12 @@
from vllm.model_executor.models.bloom import BloomForCausalLM
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM

__all__ = [
"BloomForCausalLM",
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",
"GPTNeoXForCausalLM",
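With these changes in place, a BLOOM checkpoint should run through the usual vLLM entry points. A minimal usage sketch (checkpoint and sampling settings chosen for illustration only):

```python
from vllm import LLM, SamplingParams

# Hypothetical quick check with a small BLOOM checkpoint.
llm = LLM(model="bigscience/bloom-560m")
sampling_params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["Deep learning is"], sampling_params)
print(outputs[0].outputs[0].text)
```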