From 53d83bfaa0a7af743a2d8706860f576e3706da49 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 30 Jun 2025 06:18:43 +0000
Subject: [PATCH] Done

Signed-off-by: Jee Jee Li
---
 .../model_loader/bitsandbytes_loader.py | 123 ++++++++++--------
 1 file changed, 72 insertions(+), 51 deletions(-)

diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index 09857ef297f0..0c46d170e88d 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -20,8 +20,6 @@
                               get_tensor_model_parallel_world_size)
 # yapf: enable
 from vllm.logger import init_logger
-# yapf conflicts with isort for this block
-# yapf: disable
 from vllm.model_executor.layers.linear import (LinearBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
@@ -39,6 +37,8 @@
                                         set_weight_attrs)
 from vllm.platforms import current_platform
 
+# yapf conflicts with isort for this block
+
 logger = init_logger(__name__)
 
 
@@ -54,11 +54,17 @@ def __init__(self, load_config: LoadConfig):
         self.unsharded_weights_modules: list[str] = []
         # Save the module names that are sharded by column.
         self.column_sharded_weights_modules: list[str] = []
+        # Modules whose weights might have fused on disk
+        # we need their output_sizes to make shard in flight correctly with TP
+        self.maybe_fused_weights_modules: dict[str, list[int]] = {}
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: list[str] = []
         # mapping weight names from transformers to vllm.
         self.weight_mapper: Callable = lambda name: name
+        self.pre_quant: bool = False
+        self.load_8bit: bool = False
+        self.is_pool_model: bool = False
 
     def _get_weight_files(
         self,
@@ -134,13 +140,14 @@ def _prepare_weights(self, model_name_or_path: str,
         return hf_weights_files, use_safetensors
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
-        def _maybe_pool_model(module_name:str):
+
+        def _maybe_pool_model(module_name: str):
             # For pool model, we need to add the prefix `model.`
             # for the weight name if possible.
             if self.is_pool_model and self.target_modules[0]. \
                 startswith("model.") and not module_name.startswith(
                 "model."):
-                return "model."+module_name
+                return "model." + module_name
 
             return module_name
 
@@ -159,8 +166,7 @@ def _maybe_pool_model(module_name:str):
             # mapping weight names from transformers to vllm while preserving
             # original names.
             mapped_name = self.weight_mapper(org_name)
-            mapped_name=_maybe_pool_model(mapped_name)
-
+            mapped_name = _maybe_pool_model(mapped_name)
 
             yield org_name, mapped_name, param
 
@@ -168,8 +174,6 @@ def _get_quantized_weights_iterator(
         self,
         model_name_or_path: str,
         revision: Optional[str],
-        pre_quant: bool,
-        load_8bit: bool,
     ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
                                                                      Any]]:
         """Get an iterator to the model weights with bitsandbytes quantization,
@@ -192,8 +196,8 @@
 
         quant_state_dict: dict[str, Any] = {}
 
-        if pre_quant:
-            if load_8bit:
+        if self.pre_quant:
+            if self.load_8bit:
                 return self._quantized_8bit_generator(
                     hf_weights_files, use_safetensors,
                     quant_state_dict), quant_state_dict
@@ -390,10 +394,13 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
             yield org_weight_name, processed_weight
 
     def _get_bnb_target_modules(self, model: nn.Module) -> None:
-
+        """
+        Identify and collect all modules that support BitsAndBytes
+        quantization.
+        """
         for name, module in model.named_modules():
-            if (isinstance(module, LinearBase) and
-                    hasattr(module.quant_method, "quant_config")):
+            if (isinstance(module, LinearBase)
+                    and hasattr(module.quant_method, "quant_config")):
                 if modules_info := self.modules_mapping.get_sub_modules(name):
                     # Map vllm's names to transformers's names.
                     rep_name, sub_modules = modules_info
@@ -409,29 +416,11 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
         ), "vllm currently does not support BNB quantization for"
         f" {type(model).__name__}"
 
-    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        if not hasattr(model, "load_weights"):
-            raise AttributeError(
-                "The required method 'load_weights' is not defined in class"
-                f" {type(model).__name__}.")
-
-        if not hasattr(model, "packed_modules_mapping"):
-            raise AttributeError(
-                f"Model {type(model).__name__} does not support BitsAndBytes "
-                "quantization yet. No 'packed_modules_mapping' found.")
-        self.is_pool_model=is_pooling_model(model)
-
-        self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
-
-        # For some models like Molmo, we need to use hf_to_vllm_mapper
-        # to ensure correct loading of weights.
-        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
-            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
-
-        # Modules whose weights might have fused on disk
-        # we need their output_sizes to make shard in flight correctly with TP
-        self.maybe_fused_weights_modules: dict[str, list[int]] = {}
-        self._get_bnb_target_modules(model)
+    def _classify_module_sharding(self, model: nn.Module):
+        """
+        Categorize modules based on their weight sharding requirements
+        for tensor parallelism.
+        """
         for name, module in model.named_modules():
             # Some modules like `ReplicatedLinear` should not have their weights
             # sharded. The reason for implementing it this way is to avoid new
@@ -449,19 +438,27 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
             elif isinstance(module, (RowParallelLinear, )):
                 self.column_sharded_weights_modules.append(name)
 
-        self.model_type = type(model).__name__
+    def _verify_model_compatibility(self, model: nn.Module,
+                                    model_config: ModelConfig) -> None:
+        """
+        Verify that the model is compatible with BitsAndBytes quantization.
+        """
+        if not hasattr(model, "load_weights"):
+            raise AttributeError(
+                "The required method 'load_weights' is not defined in class"
+                f" {type(model).__name__}.")
 
-        logger.info("Loading weights with BitsAndBytes quantization. "
-                    "May take a while ...")
+        if not hasattr(model, "packed_modules_mapping"):
+            raise AttributeError(
+                f"Model {type(model).__name__} does not support BitsAndBytes "
+                "quantization yet. No 'packed_modules_mapping' found.")
 
         quant_config = getattr(model_config.hf_config, "quantization_config",
                                None)
-
-        pre_quant = False
         if quant_config is not None:
             quant_method = quant_config.get("quant_method")
             if quant_method == "bitsandbytes":
-                pre_quant = True
+                self.pre_quant = True
             else:
                 raise ValueError(
                     f"BitsAndBytes loader does not support {quant_method} "
@@ -469,20 +466,43 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
 
         # The quant_states in pre_quantized models cannot work with a split
         # weight tensor. So TP does not work with pre_quantized bnb models.
-        if pre_quant and get_tensor_model_parallel_world_size() > 1:
+        if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
             raise ValueError(
                 "Prequant BitsAndBytes models with tensor parallelism is not "
                 "supported. Please try with pipeline parallelism.")
+        if self.pre_quant:
+            self.load_8bit = quant_config.get("load_in_8bit", False)
+
+    def _initialize_loader_state(self, model: nn.Module,
+                                 model_config: ModelConfig) -> None:
+        """
+        Initialize the loader's internal state based on the model and
+        configuration.
+        """
+        self.is_pool_model = is_pooling_model(model)
+        self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
 
-        load_8bit = False
-        if pre_quant:
-            load_8bit = quant_config.get("load_in_8bit", False)
+        # For some models like Molmo, we need to use hf_to_vllm_mapper
+        # to ensure correct loading of weights.
+        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
 
-        qweight_iterator, quant_state_dict = (
-            self._get_quantized_weights_iterator(model_config.model,
-                                                 model_config.revision,
-                                                 pre_quant, load_8bit))
+        self._get_bnb_target_modules(model)
+        self._classify_module_sharding(model)
 
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+
+        self._verify_model_compatibility(model, model_config)
+        self._initialize_loader_state(model, model_config)
+
+        logger.info("Loading weights with BitsAndBytes quantization. "
+                    "May take a while ...")
+        qweight_iterator, quant_state_dict = (
+            self._get_quantized_weights_iterator(
+                model_config.model,
+                model_config.revision,
+            ))
         weights_to_load = {name for name, _ in model.named_parameters()}
         loaded_weights = model.load_weights(qweight_iterator)
         # Some models may have weights loading tracker unimplemented.
@@ -562,10 +582,11 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
                 offsets = torch.tensor(offsets).cpu()
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
 
-            if load_8bit:
+            if self.load_8bit:
                 set_weight_attrs(
                     param, {"matmul_state": [None] * len(quant_states)})
 
         torch.cuda.empty_cache()
+
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
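
Note for reviewers (not part of the patch itself): the diff above moves the pre_quant / load_8bit / is_pool_model flags onto the loader instance and splits the old monolithic load_weights into _verify_model_compatibility, _initialize_loader_state, _classify_module_sharding and _get_bnb_target_modules, so _get_quantized_weights_iterator no longer needs pre_quant/load_8bit parameters. The snippet below is a minimal, self-contained sketch of that refactoring pattern only; ToyQuantLoader and its methods are illustrative stand-ins, not vLLM APIs.

from dataclasses import dataclass, field


@dataclass
class ToyQuantLoader:
    # Flags that the original code threaded through as function arguments;
    # after the refactor they live on the loader instance.
    pre_quant: bool = False
    load_8bit: bool = False
    target_modules: list[str] = field(default_factory=list)

    def _verify_compatibility(self, model_config: dict) -> None:
        # Mirrors _verify_model_compatibility: fail fast on unsupported
        # configs and record quantization flags as instance state.
        quant = model_config.get("quantization_config")
        if quant is not None:
            if quant.get("quant_method") != "bitsandbytes":
                raise ValueError("unsupported quant_method")
            self.pre_quant = True
            self.load_8bit = quant.get("load_in_8bit", False)

    def _initialize_state(self, model_config: dict) -> None:
        # Mirrors _initialize_loader_state: derive per-model bookkeeping once.
        self.target_modules = list(model_config.get("target_modules", []))

    def _weights_iterator(self):
        # Mirrors _get_quantized_weights_iterator: no pre_quant/load_8bit
        # parameters needed any more, they are read from `self`.
        mode = "8bit" if (self.pre_quant and self.load_8bit) else "4bit"
        for name in self.target_modules:
            yield name, mode

    def load_weights(self, model_config: dict) -> dict[str, str]:
        # The entry point now reads as a linear sequence of named phases.
        self._verify_compatibility(model_config)
        self._initialize_state(model_config)
        return dict(self._weights_iterator())


if __name__ == "__main__":
    cfg = {
        "quantization_config": {"quant_method": "bitsandbytes",
                                "load_in_8bit": True},
        "target_modules": ["q_proj", "k_proj"],
    }
    print(ToyQuantLoader().load_weights(cfg))

Running the sketch prints {'q_proj': '8bit', 'k_proj': '8bit'}: the flags recorded during the verification phase are later read from self by the weight iterator, which mirrors why the patched _get_quantized_weights_iterator can drop its pre_quant and load_8bit arguments.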