diff --git a/vllm/config.py b/vllm/config.py
index f163665e2c06..f647318e9d6a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1678,14 +1678,10 @@ class LoRAConfig:
     bias_enabled: bool = False
 
     def __post_init__(self):
-        # Setting the maximum rank to 256 should be able to satisfy the vast
-        # majority of applications.
-        possible_max_ranks = (8, 16, 32, 64, 128, 256)
         possible_lora_extra_vocab_size = (0, 256, 512)
-        if self.max_lora_rank not in possible_max_ranks:
+        if self.max_lora_rank < 1:
             raise ValueError(
-                f"max_lora_rank ({self.max_lora_rank}) must be one of "
-                f"{possible_max_ranks}.")
+                f"max_lora_rank ({self.max_lora_rank}) must be >= 1.")
         if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
             raise ValueError(
                 f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 3701988ff692..66827902995f 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -1,4 +1,5 @@
 # pylint: disable=unused-argument
+import copy
 import math
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -130,6 +131,13 @@ class LoRAMapping(AdapterMapping):
 
 
 class BaseLayerWithLoRA(nn.Module):
+    # Initialized following static typing.
+    _create_lora_weights_args: Tuple[int, LoRAConfig,
+                                     Optional[PretrainedConfig]] = (
+                                         0,
+                                         LoRAConfig(1, 1),
+                                         None,
+                                     )
 
     def slice_lora_a(
         self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
@@ -156,11 +164,18 @@ def reset_lora(self, index: int):
         """Resets the lora weights at index back to 0."""
         ...
 
+    def update_max_lora_rank(
+        self,
+        lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]],
+    ):
+        """Updates max lora rank if larger lora matrices are given."""
+        ...
+
     def set_lora(
         self,
         index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
+        lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]],
+        lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]],
         embeddings_tensor: Optional[torch.Tensor],
         bias: Optional[torch.Tensor] = None,
     ):
@@ -194,11 +209,14 @@ def __init__(self, base_layer: VocabParallelEmbedding) -> None:
         self.embeddings_weights: Optional[torch.Tensor]
 
     def create_lora_weights(
-            self,
-            max_loras: int,
-            lora_config: LoRAConfig,
-            model_config: Optional[PretrainedConfig] = None) -> None:
-
+        self,
+        max_loras: int,
+        lora_config: LoRAConfig,
+        model_config: Optional[PretrainedConfig] = None,
+    ) -> None:
+        self._create_lora_weights_args = (max_loras,
+                                          copy.deepcopy(lora_config),
+                                          copy.deepcopy(model_config))
         if self.base_layer.num_added_embeddings_per_partition > 0:
             # We can start adding lora weights
             self.embeddings_weights = self.base_layer.weight.data[
@@ -255,6 +273,14 @@ def reset_lora(self, index: int):
         self.lora_b_stacked[index] = 0
         self.embeddings_tensors[index] = 0
 
+    def update_max_lora_rank(
+        self,
+        lora_a: torch.Tensor,
+    ):
+        if lora_a.shape[1] > self._create_lora_weights_args[1].max_lora_rank:
+            self._create_lora_weights_args[1].max_lora_rank = lora_a.shape[1]
+            self.create_lora_weights(*self._create_lora_weights_args)
+
     def set_lora(
         self,
         index: int,
@@ -264,6 +290,8 @@ def set_lora(
         bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
+        self.update_max_lora_rank(lora_a)
+
         self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
             lora_a, non_blocking=True)
         self.lora_b_stacked[index,
@@ -340,6 +368,9 @@ def create_lora_weights(
         lora_config: LoRAConfig,
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
+        self._create_lora_weights_args = (max_loras,
+                                          copy.deepcopy(lora_config),
+                                          copy.deepcopy(model_config))
         self.lora_config = lora_config
         lora_a_output_size = lora_config.max_lora_rank
         self.lora_a_stacked = torch.zeros(
@@ -375,6 +406,14 @@ def reset_lora(self, index: int):
         if self.lora_config.bias_enabled:
             self.bias_stacked[index] = 0
 
+    def update_max_lora_rank(
+        self,
+        lora_a: torch.Tensor,
+    ):
+        if lora_a.shape[1] > self._create_lora_weights_args[1].max_lora_rank:
+            self._create_lora_weights_args[1].max_lora_rank = lora_a.shape[1]
+            self.create_lora_weights(*self._create_lora_weights_args)
+
     def set_lora(
         self,
         index: int,
@@ -384,6 +423,7 @@ def set_lora(
         bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
+        self.update_max_lora_rank(lora_a)
 
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -469,6 +509,9 @@ def create_lora_weights(
         lora_config: LoRAConfig,
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
+        self._create_lora_weights_args = (max_loras,
+                                          copy.deepcopy(lora_config),
+                                          copy.deepcopy(model_config))
        self.lora_config = lora_config
         self.tp_size = get_tensor_model_parallel_world_size()
         lora_a_output_size_per_partition = (
@@ -547,6 +590,21 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         bias = bias[start_idx:end_idx]
         return bias
 
+    def update_max_lora_rank(
+        self,
+        lora_a: torch.Tensor,
+    ) -> None:
+        if (self.lora_config.fully_sharded_loras
+                and lora_a.shape[1] * self.tp_size >
+                self.lora_config.max_lora_rank):
+            self._create_lora_weights_args[1].max_lora_rank = (
+                lora_a.shape[1] * self.tp_size)
+            self.create_lora_weights(*self._create_lora_weights_args)
+        elif (not self.lora_config.fully_sharded_loras
+              and lora_a.shape[1] > self.lora_config.max_lora_rank):
+            self._create_lora_weights_args[1].max_lora_rank = lora_a.shape[1]
+            self.create_lora_weights(*self._create_lora_weights_args)
+
     def set_lora(
         self,
         index: int,
@@ -556,6 +614,7 @@ def set_lora(
         bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
+        self.update_max_lora_rank(lora_a)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
@@ -643,6 +702,9 @@ def create_lora_weights(
         lora_config: LoRAConfig,
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
+        self._create_lora_weights_args = (max_loras,
+                                          copy.deepcopy(lora_config),
+                                          copy.deepcopy(model_config))
         self.lora_config = lora_config
         n_slices = 2
         if not (len(self.base_layer.output_sizes) == n_slices
@@ -730,6 +792,23 @@ def slice_bias(
         ]
         return bias
 
+    def update_max_lora_rank(self, lora_a: List[Union[torch.Tensor,
+                                                      None]]) -> None:
+        for tensor in lora_a:
+            if tensor is None:
+                continue
+            if (self.lora_config.fully_sharded_loras
+                    and tensor.shape[1] * self.tp_size >
+                    self.lora_config.max_lora_rank):
+                self._create_lora_weights_args[1].max_lora_rank = (
+                    tensor.shape[1] * self.tp_size)
+                self.create_lora_weights(*self._create_lora_weights_args)
+            elif (not self.lora_config.fully_sharded_loras
+                  and tensor.shape[1] > self.lora_config.max_lora_rank):
+                self._create_lora_weights_args[1].max_lora_rank = (
+                    tensor.shape[1])
+                self.create_lora_weights(*self._create_lora_weights_args)
+
     def set_lora(
         self,
         index: int,
@@ -739,6 +818,7 @@ def set_lora(
         bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
+        self.update_max_lora_rank(lora_a)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
@@ -865,6 +945,8 @@ def set_lora(
         bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
+        self.update_max_lora_rank(lora_a)
+
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
@@ -911,6 +993,9 @@ def create_lora_weights(
         lora_config: LoRAConfig,
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
+        self._create_lora_weights_args = (max_loras,
+                                          copy.deepcopy(lora_config),
+                                          copy.deepcopy(model_config))
         self.lora_config = lora_config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
@@ -1070,15 +1155,33 @@ def slice_bias(
         bias = [bias_q, bias_k, bias_v]
         return bias
 
+    def update_max_lora_rank(self, lora_a: List[Union[torch.Tensor,
+                                                      None]]) -> None:
+        for tensor in lora_a:
+            if tensor is None:
+                continue
+            if (self.lora_config.fully_sharded_loras
+                    and tensor.shape[1] * self.tp_size >
+                    self.lora_config.max_lora_rank):
+                self._create_lora_weights_args[1].max_lora_rank = (
+                    tensor.shape[1] * self.tp_size)
+                self.create_lora_weights(*self._create_lora_weights_args)
+            elif (not self.lora_config.fully_sharded_loras
+                  and tensor.shape[1] > self.lora_config.max_lora_rank):
+                self._create_lora_weights_args[1].max_lora_rank = (
+                    tensor.shape[1])
+                self.create_lora_weights(*self._create_lora_weights_args)
+
     def set_lora(
         self,
         index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
+        lora_a: List[Union[torch.Tensor, None]],
+        lora_b: List[Union[torch.Tensor, None]],
         embeddings_tensor: Optional[torch.Tensor],
         bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
+        self.update_max_lora_rank(lora_a)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
@@ -1171,6 +1274,9 @@ def create_lora_weights(
         lora_config: LoRAConfig,
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
+        self._create_lora_weights_args = (max_loras,
+                                          copy.deepcopy(lora_config),
+                                          copy.deepcopy(model_config))
         self.lora_config = lora_config
         self.tp_rank = get_tensor_model_parallel_rank()
         self.lora_a_stacked = torch.zeros(
@@ -1235,6 +1341,14 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         return bias
 
+    def update_max_lora_rank(
+        self,
+        lora_a: torch.Tensor,
+    ):
+        if lora_a.shape[1] > self._create_lora_weights_args[1].max_lora_rank:
+            self._create_lora_weights_args[1].max_lora_rank = lora_a.shape[1]
+            self.create_lora_weights(*self._create_lora_weights_args)
+
     def set_lora(
         self,
         index: int,
@@ -1244,6 +1358,7 @@ def set_lora(
         bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
+        self.update_max_lora_rank(lora_a)
 
         if self.base_layer.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
@@ -1399,6 +1514,9 @@ def create_lora_weights(
         if 32000 < self.base_layer.vocab_size > 257024:
             raise ValueError("When using LoRA, vocab size must be "
                              "32000 >= vocab_size <= 257024")
+        self._create_lora_weights_args = (max_loras,
+                                          copy.deepcopy(lora_config),
+                                          copy.deepcopy(model_config))
         self.lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -1441,6 +1559,14 @@ def reset_lora(self, index: int):
         self.lora_b_stacked[index] = 0
         self.embeddings_tensors[index] = float("-inf")
 
+    def update_max_lora_rank(
+        self,
+        lora_a: torch.Tensor,
+    ):
+        if lora_a.shape[1] > self._create_lora_weights_args[1].max_lora_rank:
+            self._create_lora_weights_args[1].max_lora_rank = lora_a.shape[1]
+            self.create_lora_weights(*self._create_lora_weights_args)
+
     def set_lora(
         self,
         index: int,
@@ -1450,6 +1576,8 @@ def set_lora(
         bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
+        self.update_max_lora_rank(lora_a)
+
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                 lora_a.T, non_blocking=True)
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 93a5e2762191..463cce5fd393 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -106,10 +106,6 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
             )
         except Exception as e:
             raise RuntimeError(f"Loading lora {lora_path} failed") from e
-        if lora.rank > self.lora_config.max_lora_rank:
-            raise ValueError(
-                f"LoRA rank {lora.rank} is greater than max_lora_rank "
-                f"{self.lora_config.max_lora_rank}.")
         if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
             raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                              f"is greater than lora_extra_vocab_size "
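
For context on how the new path behaves end to end, here is a minimal standalone sketch of the resize pattern these hunks introduce. The ToyLoRAConfig / ToyLayerWithLoRA names are made up for illustration and are not vLLM's real classes; the real layers additionally handle tensor-parallel sharding, bias and embedding buffers, and packed (list-valued) lora_a inputs.

import copy
from dataclasses import dataclass

import torch


@dataclass
class ToyLoRAConfig:
    max_lora_rank: int
    lora_dtype: torch.dtype = torch.float16


class ToyLayerWithLoRA:

    def __init__(self, input_size: int, output_size: int) -> None:
        self.input_size = input_size
        self.output_size = output_size
        # Placeholder, mirroring the class-level default in the diff.
        self._create_lora_weights_args = (1, ToyLoRAConfig(1))

    def create_lora_weights(self, max_loras: int,
                            lora_config: ToyLoRAConfig) -> None:
        # Cache the arguments so a later resize can re-run this method; the
        # deepcopy keeps the cached config independent of the caller's object
        # (the diff stores copy.deepcopy(lora_config), presumably for the
        # same reason).
        self._create_lora_weights_args = (max_loras,
                                          copy.deepcopy(lora_config))
        rank = lora_config.max_lora_rank
        self.lora_a_stacked = torch.zeros(max_loras, rank, self.input_size,
                                          dtype=lora_config.lora_dtype)
        self.lora_b_stacked = torch.zeros(max_loras, self.output_size, rank,
                                          dtype=lora_config.lora_dtype)

    def update_max_lora_rank(self, lora_a: torch.Tensor) -> None:
        # Same check as the diff: if the incoming adapter's rank exceeds the
        # configured maximum, bump the cached config and rebuild the buffers.
        rank = lora_a.shape[1]
        if rank > self._create_lora_weights_args[1].max_lora_rank:
            self._create_lora_weights_args[1].max_lora_rank = rank
            self.create_lora_weights(*self._create_lora_weights_args)

    def set_lora(self, index: int, lora_a: torch.Tensor,
                 lora_b: torch.Tensor) -> None:
        self.update_max_lora_rank(lora_a)
        self.lora_a_stacked[index, :lora_a.shape[1], :].copy_(lora_a.T)
        self.lora_b_stacked[index, :, :lora_b.shape[1]].copy_(lora_b)


if __name__ == "__main__":
    layer = ToyLayerWithLoRA(input_size=16, output_size=32)
    layer.create_lora_weights(max_loras=2, lora_config=ToyLoRAConfig(8))
    # A rank-64 adapter no longer has to fit the preconfigured maximum of 8.
    layer.set_lora(0, torch.randn(16, 64), torch.randn(32, 64))
    print(layer.lora_a_stacked.shape)  # torch.Size([2, 64, 16])

The sketch only shows the growth path for a single unsharded layer; it deliberately leaves out the fully_sharded_loras scaling by tp_size and the per-slice handling that the ColumnParallel/Merged/QKV variants in the diff perform.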