45 changes: 45 additions & 0 deletions tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -17,6 +17,7 @@

from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from ....quantization.utils import is_quant_method_supported
from ....utils import large_gpu_test
from ...utils import check_logprobs_close

@@ -397,6 +398,50 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
)


@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
def test_bnb_regression(
image_assets: _ImageAssets,
model: str,
dtype: str,
max_tokens: int,
):
stop_sign = image_assets[0].pil_image
prompts = [
{
"prompt": "<|begin_of_text|>The content of the image <|image|> is",
"multi_modal_data": {
"image": stop_sign
},
},
{
"prompt":
"The color of the sky is blue but sometimes it can also be",
},
]
# Regression test for QKVCrossParallelLinear with bitsandbytes quantization.
llm = LLM(
model=model,
dtype=dtype,
max_model_len=4096,
max_num_seqs=2,
enforce_eager=True,
quantization="bitsandbytes",
load_format="bitsandbytes",
)
sampling_params = SamplingParams(
temperature=0,
max_tokens=max_tokens,
)
outputs = llm.generate(prompts, sampling_params)
assert outputs
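
# --- Illustrative sketch, not part of this diff: running only the new regression
# test locally. Assumes a GPU with at least 48 GB of memory and bitsandbytes
# installed. ---
import pytest

pytest.main([
    "tests/models/encoder_decoder/vision_language/test_mllama.py::test_bnb_regression",
    "-v",
])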


@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
213 changes: 178 additions & 35 deletions vllm/model_executor/layers/linear.py
@@ -2,9 +2,10 @@

import itertools
from abc import abstractmethod
from typing import Optional, Union
from typing import Any, Literal, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter

@@ -84,6 +85,43 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
return param[shard_id], loaded_weight


# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
"""
Split the leftmost BitsAndBytes 4-bit shard off from the remaining shards.

For example, given bnb weight attributes as below:
{
'bnb_shard_offsets': array([0, 4, 8, 16]),
'bnb_quant_state': {0: ..., 1: ..., 2: ...},
}

The function will return:
{
'bnb_shard_offsets': array([0, 4]),
'bnb_quant_state': {0: ...},
}
and
{
'bnb_shard_offsets': array([0, 4, 12]),
'bnb_quant_state': {0: ..., 1: ...},
}
"""
shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
offset_l = shard_offsets[:2]
offset_r = shard_offsets[1:] - shard_offsets[1]
quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
quant_state_r = {
i - 1: bnb_weight_attrs["bnb_quant_state"][i]
for i in range(1,
len(shard_offsets) - 1)
}
left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
return left, right
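
# Illustrative usage sketch, not part of this diff: splitting hypothetical fused
# bnb weight attributes into a Q part and a KV part. The import path follows this
# file's location; the string quant states stand in for real bitsandbytes
# QuantState objects.
import numpy as np

from vllm.model_executor.layers.linear import left_shift_bitsandbytes_4bit_shard

attrs = {
    "bnb_shard_offsets": np.array([0, 4, 8, 16]),
    "bnb_quant_state": {0: "q_state", 1: "k_state", 2: "v_state"},
}
q_attrs, kv_attrs = left_shift_bitsandbytes_4bit_shard(attrs)
# q_attrs  -> offsets [0, 4],     quant_state {0: "q_state"}
# kv_attrs -> offsets [0, 4, 12], quant_state {0: "k_state", 1: "v_state"}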


class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""

@@ -1229,7 +1267,24 @@ def extra_repr(self) -> str:
return s


class QKVCrossParallelLinear(torch.nn.Module):
class QKVCrossParallelLinear(LinearBase):
"""Linear layers for efficient cross-attention's QKV transformation.

Args:
hidden_size: input hidden state size of the transformer.
head_size: size of each attention head.
total_num_heads: total number of attention query heads.
total_num_kv_heads: total number of attention key/value heads. If
None, assume total_num_kv_heads = total_num_heads.
bias: If true, add bias.
skip_bias_add: If True, skip adding the bias and return it separately
instead. This enables performance optimizations where the bias can
be fused with other element-wise operations.
params_dtype: Data type for the parameters.
quant_config: Quantization config.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""

def __init__(self,
hidden_size: int,
@@ -1241,12 +1296,28 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# input_size and output_size are not used; they only align with the
# LinearBase interface.
input_size = hidden_size
output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
super().__init__(input_size=input_size,
output_size=output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)

self.quant_config = quant_config

# Empty placeholders for loading as a single module.
self.weight = torch.nn.Parameter()
set_weight_attrs(self.weight, {
"weight_loader": self.weight_loader_weight,
})
placeholder_size = 0
assert self.quant_method is not None
self.quant_method.create_weights(self,
placeholder_size, [placeholder_size],
placeholder_size,
placeholder_size,
self.params_dtype,
weight_loader=self.weight_loader)

# Use a dictionary to avoid submodules parameters auto-registration:
# drop-in replacement for a `QKVParallelLinear` module.
self.proj = dict()
@@ -1276,18 +1347,94 @@ def __init__(self,
if bias:
self.bias = torch.nn.Parameter()
set_weight_attrs(self.bias, {
"weight_loader": self.weight_loader_bias,
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.bias = None

@property
def q_proj_decoder(self):
return self.proj["q_proj_decoder"]
def q_proj_decoder(self) -> ColumnParallelLinear:
layer = self.proj["q_proj_decoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
return layer

@property
def kv_proj_encoder(self):
return self.proj["kv_proj_encoder"]
def kv_proj_encoder(self) -> QKVParallelLinear:
layer = self.proj["kv_proj_encoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
return layer

def sync_weight_attrs(
self,
src_param: nn.Parameter,
tgt_param: nn.Parameter,
mode: Literal["q_proj_decoder", "kv_proj_encoder"],
):
missing_attrs_dict = {
k: getattr(src_param, k)
for k in (set(src_param.__dict__.keys()) -
set(tgt_param.__dict__.keys()))
}
# TODO(Isotr0py): handle bitsandbytes 8bit
use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
False)
if (missing_attrs_dict and use_bitsandbytes_4bit):
q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
missing_attrs_dict)
if mode == "q_proj_decoder":
set_weight_attrs(tgt_param, q_proj_attrs)
elif mode == "kv_proj_encoder":
set_weight_attrs(tgt_param, kv_proj_attrs)
else:
set_weight_attrs(tgt_param, missing_attrs_dict)

def forward(self, decoder_hidden_states, encoder_hidden_states):
def _is_same_param(
self,
src_param: torch.nn.Parameter,
map_param: torch.nn.Parameter,
) -> bool:
"""Check if two parameters are exactly pointing to same things."""
# ignore weight_loader because it's always different
key_to_ignore = ["weight_loader", "_weight_loader"]
has_same_type_name = type(src_param) is type(map_param)
src_param_attrs = {
k: v
for k, v in src_param.__dict__.items() if k not in key_to_ignore
}
map_param_attrs = {
k: v
for k, v in map_param.__dict__.items() if k not in key_to_ignore
}
has_same_attrs = src_param_attrs == map_param_attrs
return has_same_type_name and has_same_attrs

def select_proj_params(
self,
layer: nn.Module,
param: nn.Parameter,
) -> nn.Parameter:
"""
Given the placeholder param,
return the corresponding param in the proj layers.
"""
target_param_list = [
v for _, v in layer.named_parameters()
if self._is_same_param(param, v)
]
assert len(target_param_list) == 1
target_param = target_param_list[0]
return target_param

def forward( # type: ignore[override]
self,
decoder_hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
q, _ = self.q_proj_decoder(decoder_hidden_states)
if encoder_hidden_states is None:
# Encoder KV already cached.
@@ -1300,25 +1447,21 @@ def forward(self, decoder_hidden_states, encoder_hidden_states):
k, v = kv_enc.split(self.kv_size, dim=-1)
return q, k, v

def weight_loader_weight(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
# NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
else self.kv_proj_encoder.weight
param.weight_loader(
param,
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
param, loaded_weight, loaded_shard_id)

def weight_loader_bias(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
else self.kv_proj_encoder.bias
param.weight_loader(
param,
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
param, loaded_weight, loaded_shard_id)
def weight_loader(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
layer = (self.q_proj_decoder
if loaded_shard_id == "q" else self.kv_proj_encoder)
target_param = self.select_proj_params(layer, param)
shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
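
# Illustrative sketch, not part of this diff: how a model's load_weights
# typically routes checkpoint tensors into this fused layer via shard ids,
# mirroring the common vLLM stacked_params_mapping pattern. The entries here
# are hypothetical.
stacked_params_mapping = [
    # (fused param name, checkpoint shard name, shard_id)
    (".qkv_proj", ".q_proj", "q"),  # routed to q_proj_decoder, no shard id passed
    (".qkv_proj", ".k_proj", "k"),  # routed to kv_proj_encoder with shard id "k"
    (".qkv_proj", ".v_proj", "v"),  # routed to kv_proj_encoder with shard id "v"
]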

def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
s += f", kv_size={self.kv_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += ", gather_output=False"
return s
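
# Illustrative end-to-end sketch, not part of this diff: constructing the layer
# and running the cross-attention QKV projection. The sizes are hypothetical
# mllama-like values, and vLLM's tensor-parallel group must already be
# initialized for the underlying ColumnParallelLinear/QKVParallelLinear layers.
qkv = QKVCrossParallelLinear(
    hidden_size=4096,
    head_size=128,
    total_num_heads=32,
    total_num_kv_heads=8,
    bias=False,
    prefix="model.layers.0.cross_attn.qkv_proj",
)
decoder_states = torch.randn(2, 16, 4096)   # [batch, decoder_len, hidden_size]
encoder_states = torch.randn(2, 100, 4096)  # [batch, encoder_len, hidden_size]
q, k, v = qkv(decoder_states, encoder_states)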