tests/lora/test_baichuan.py (9 additions, 5 deletions)

@@ -64,7 +64,8 @@ def test_baichuan_lora(baichuan_lora_files):
 
 
 @pytest.mark.skip("Requires multiple GPUs")
-def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
+@pytest.mark.parametrize("fully_sharded", [True, False])
+def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
     # Cannot use as it will initialize torch.cuda too early...
     # if torch.cuda.device_count() < 4:
     #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
@@ -75,7 +76,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
                        max_loras=4,
                        max_lora_rank=64,
                        tensor_parallel_size=1,
-                       trust_remote_code=True)
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
     output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)
 
     del llm_tp1
@@ -87,7 +89,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
                        max_loras=4,
                        max_lora_rank=64,
                        tensor_parallel_size=2,
-                       trust_remote_code=True)
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
     output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)
 
     del llm_tp2
@@ -101,10 +104,11 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
                        max_loras=4,
                        max_lora_rank=64,
                        tensor_parallel_size=4,
-                       trust_remote_code=True)
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
     output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)
 
     del llm_tp4
     cleanup()
 
-    assert output_tp1 == output_tp4
\ No newline at end of file
+    assert output_tp1 == output_tp4
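
The engine-level switch this test now exercises is fully_sharded_loras. As a usage sketch (not taken from the test file: the model name, adapter path, and prompt are placeholders), enabling it together with tensor parallelism looks roughly like this:

    import vllm
    from vllm.lora.request import LoRARequest

    # Sketch only: enable LoRA with fully sharded LoRA layers under TP=2.
    llm = vllm.LLM(
        "baichuan-inc/Baichuan-7B",  # placeholder model
        enable_lora=True,
        max_loras=4,
        max_lora_rank=64,
        tensor_parallel_size=2,
        trust_remote_code=True,
        fully_sharded_loras=True,  # flag parametrized by this test
    )
    outputs = llm.generate(
        ["Hello, my name is"],  # placeholder prompt
        lora_request=LoRARequest("adapter", 1, "/path/to/baichuan-lora"),
    )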

tests/lora/test_layers.py (5 additions, 2 deletions)

@@ -12,7 +12,8 @@
 from vllm.lora.fully_sharded_layers import (
     ColumnParallelLinearWithShardedLoRA,
     MergedColumnParallelLinearWithShardedLoRA,
-    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
+    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
+    RowParallelLinearWithShardedLoRA)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
@@ -684,7 +685,9 @@ def create_column_parallel_packed_layer():
                bias=False,
                params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
-           lora_linear = QKVParallelLinearWithLora(linear)
+           lora_linear = QKVParallelLinearWithLora(
+               linear
+           ) if not fully_shard else QKVParallelLinearWithShardedLora(linear)
 
         @dataclass
         class FakeConfig:

vllm/lora/fully_sharded_layers.py (55 additions, 3 deletions)

@@ -12,6 +12,7 @@
 from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLora,
+                              QKVParallelLinearWithLora,
                               RowParallelLinearWithLoRA)
 from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
 
@@ -90,11 +91,11 @@ def can_replace_layer(cls, source_layer: nn.Module,
 def _mcp_apply(x, bias, layer):
     """
     MergedColumnParallelLinearWithShardedLoRA and
-    QKVParallelLinearWithShardedLora share the same
+    MergedQKVParallelLinearWithShardedLora share the same
     LoRa weight application method.
 
     The main difference is the step by shard_size for lora_b which can
-    vary for QKVParallelLinearWithShardedLora but is constant for
+    vary for MergedQKVParallelLinearWithShardedLora but is constant for
     MergedColumnParallelLinearWithShardedLoRA.
     """
     # expecting 2 for column parallel and 3 for qkv
@@ -167,14 +168,65 @@ def can_replace_layer(cls, source_layer: nn.Module,
         )
 
 
-class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
+class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
     """
     Differs from QKVParallelLinearWithLora by slicing the
     LoRA A's also.
 
     Based on S-LoRA, slicing happens along the rank dim.
     """
 
+    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_size = self.lora_a_stacked.shape[2]
+        start_idx = tp_rank * shard_size
+        lora_a = lora_a[:, start_idx:start_idx + shard_size]
+        return lora_a
+
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+
+        x = x.view(-1, x.shape[-1])
+        output, out_orig_shape = output.view(-1,
+                                             output.shape[-1]), output.shape
+        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
+                             dtype=torch.float32,
+                             device=x.device)
+
+        bgmv(buffer, x, self.lora_a_stacked,
+             self.indices[:self.indices_len[0]], 0, 1.0)
+        buffer = tensor_model_parallel_all_gather(buffer)
+        bgmv(output, buffer, self.lora_b_stacked,
+             self.indices[:self.indices_len[0]], 0, 1.0)
+        # now have column partitioned output
+
+        output = output.view(*out_orig_shape)
+        return output
+
+    @classmethod
+    @_fully_sharded_can_replace
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        # specifying kwargs so they can be easily accessed in decorator
+        return super().can_replace_layer(
+            source_layer=source_layer,
+            lora_config=lora_config,
+            packed_modules_list=packed_modules_list,
+            model_config=model_config,
+            decorate=False,
+        )
+
+
+class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
+    """
+    Differs from MergedQKVParallelLinearWithLora by slicing the
+    LoRA A's also.
+
+    Based on S-LoRA, slicing happens along the rank dim.
+    """
+
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
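
The new QKVParallelLinearWithShardedLora.apply() computes the per-rank partial x @ A_shard with bgmv, all-gathers those partials along the rank dimension, and only then applies the output-sharded LoRA B. The following self-contained PyTorch check (made-up shapes; bgmv and the distributed all-gather replaced by plain matmul and torch.cat) shows why slicing A along the rank dimension and gathering is equivalent to using the full A:

    import torch

    torch.manual_seed(0)
    tp_size, rank_dim = 2, 8                    # illustrative sizes, not from the PR
    tokens, hidden, out_per_rank = 4, 16, 6

    x = torch.randn(tokens, hidden)
    lora_a = torch.randn(hidden, rank_dim)                   # full LoRA A
    lora_b = torch.randn(rank_dim, tp_size * out_per_rank)   # full LoRA B
    shard = rank_dim // tp_size

    # Each rank applies only its rank-dim slice of A (what slice_lora_a keeps)...
    partials = [x @ lora_a[:, r * shard:(r + 1) * shard] for r in range(tp_size)]
    # ...and the all-gather in apply() concatenates the partials along the rank dim.
    gathered = torch.cat(partials, dim=1)

    for tp_rank in range(tp_size):
        # B stays sharded along the output dim, as in the unsharded layer.
        b_shard = lora_b[:, tp_rank * out_per_rank:(tp_rank + 1) * out_per_rank]
        assert torch.allclose(gathered @ b_shard, (x @ lora_a) @ b_shard, atol=1e-5)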

vllm/lora/layers.py (21 additions, 15 deletions)

@@ -641,6 +641,24 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
         self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                    self.base_layer.head_size)
 
+    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+        tp_rank = get_tensor_model_parallel_rank()
+        self.q_shard_id = tp_rank
+        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+        lora_b_q = lora_b[:, self.q_proj_shard_size *
+                          self.q_shard_id:self.q_proj_shard_size *
+                          (self.q_shard_id + 1)]
+        k_offset = self.q_proj_total_size
+        lora_b_k = lora_b[:, k_offset +
+                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
+                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+        v_offset = k_offset + self.kv_proj_total_size
+        lora_b_v = lora_b[:, v_offset +
+                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
+                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
+        return lora_b
+
     def set_lora(
         self,
         index: int,
@@ -650,21 +668,8 @@ def set_lora(
     ):
         self.reset_lora(index)
         if self.tp_size > 1:
-            tp_rank = get_tensor_model_parallel_rank()
-            self.q_shard_id = tp_rank
-            self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
-            lora_b_q = lora_b[:, self.q_proj_shard_size *
-                              self.q_shard_id:self.q_proj_shard_size *
-                              (self.q_shard_id + 1)]
-            k_offset = self.q_proj_total_size
-            lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
-                              self.kv_shard_id:k_offset +
-                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
-            v_offset = k_offset + self.kv_proj_total_size
-            lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
-                              self.kv_shard_id:v_offset +
-                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
-            lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
+            lora_a = self.slice_lora_a(lora_a)
+            lora_b = self.slice_lora_b(lora_b)
 
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -674,6 +679,7 @@
                                 lora_b.T, non_blocking=True)
 
     @classmethod
+    @_not_fully_sharded_can_replace
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
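
The slice_lora_b hoisted out of set_lora above picks, per tensor-parallel rank, the q/k/v column blocks of LoRA B that match the base layer's QKV output sharding (kv_shard_id folds in num_kv_head_replicas so replicated KV heads map to the same shard). A standalone sketch of the same column arithmetic with made-up head counts (none of these sizes come from the PR):

    import torch

    # Illustrative geometry: 8 query heads, 4 KV heads, head_size 16, tp_size 4.
    tp_size, num_heads, num_kv_heads, head_size, rank = 4, 8, 4, 16, 8
    q_proj_shard_size = num_heads // tp_size * head_size
    kv_proj_shard_size = num_kv_heads // tp_size * head_size
    q_proj_total_size = num_heads * head_size
    kv_proj_total_size = num_kv_heads * head_size
    num_kv_head_replicas = max(1, tp_size // num_kv_heads)  # 1 for these sizes

    # Full LoRA B columns are laid out as [q | k | v].
    lora_b = torch.randn(rank, q_proj_total_size + 2 * kv_proj_total_size)

    def slice_lora_b(lora_b: torch.Tensor, tp_rank: int) -> torch.Tensor:
        """Same column arithmetic as the method added in this diff."""
        q_shard_id = tp_rank
        kv_shard_id = tp_rank // num_kv_head_replicas
        lora_b_q = lora_b[:, q_proj_shard_size * q_shard_id:
                          q_proj_shard_size * (q_shard_id + 1)]
        k_offset = q_proj_total_size
        lora_b_k = lora_b[:, k_offset + kv_proj_shard_size * kv_shard_id:
                          k_offset + kv_proj_shard_size * (kv_shard_id + 1)]
        v_offset = k_offset + kv_proj_total_size
        lora_b_v = lora_b[:, v_offset + kv_proj_shard_size * kv_shard_id:
                          v_offset + kv_proj_shard_size * (kv_shard_id + 1)]
        return torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)

    # Each rank keeps one q shard plus its (possibly replicated) k and v shards.
    assert slice_lora_b(lora_b, 0).shape == (
        rank, q_proj_shard_size + 2 * kv_proj_shard_size)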

vllm/lora/utils.py (3 additions, 1 deletion)

@@ -8,7 +8,8 @@
 from vllm.lora.fully_sharded_layers import (
     ColumnParallelLinearWithShardedLoRA,
     MergedColumnParallelLinearWithShardedLoRA,
-    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
+    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
+    RowParallelLinearWithShardedLoRA)
 # being imported for _all_lora_classes below
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -35,6 +36,7 @@
     RowParallelLinearWithLoRA,
     LogitsProcessorWithLoRA,
     ColumnParallelLinearWithShardedLoRA,
+    QKVParallelLinearWithShardedLora,
     MergedColumnParallelLinearWithShardedLoRA,
     MergedQKVParallelLinearWithShardedLora,
     RowParallelLinearWithShardedLoRA,
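
Adding QKVParallelLinearWithShardedLora to _all_lora_classes is what makes the new layer discoverable when a model's modules are wrapped for LoRA. The sketch below is a simplified, hypothetical version of that selection step (wrap_with_lora is not vLLM's actual API); it only illustrates why both the registration here and the can_replace_layer decorators matter, on the assumption that those decorators gate on lora_config.fully_sharded_loras so exactly one QKV variant matches:

    from typing import List

    import torch.nn as nn

    from vllm.lora.utils import _all_lora_classes  # the registry extended above


    def wrap_with_lora(layer: nn.Module, lora_config, packed_modules_list: List,
                       model_config=None) -> nn.Module:
        # Hypothetical dispatch: ask each registered wrapper whether it can
        # replace this base layer; the first match wins. With
        # fully_sharded_loras=True the *Sharded* classes claim the layer,
        # otherwise the plain ones do.
        for lora_cls in _all_lora_classes:
            if lora_cls.can_replace_layer(source_layer=layer,
                                          lora_config=lora_config,
                                          packed_modules_list=packed_modules_list,
                                          model_config=model_config):
                return lora_cls(layer)
        return layer  # layers without a matching wrapper are left as-is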