From 2005fc0b5cda5962a91307698248a0aae3b5b0d5 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 25 Dec 2024 15:55:16 +0000 Subject: [PATCH 1/5] Init Signed-off-by: Jee Jee Li --- vllm/lora/models.py | 3 ++- vllm/lora/utils.py | 40 +++++++++++----------------- vllm/lora/worker_manager.py | 2 ++ vllm/model_executor/models/molmo.py | 41 +++++++++++++++++++++++++++-- 4 files changed, 59 insertions(+), 27 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index f50db8e3b8e1..5c0e4e5cbc63 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -231,7 +231,8 @@ def from_local_checkpoint( with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore for lora_module in f.keys(): # noqa - module_name, _, _ = parse_fine_tuned_lora_name(lora_module) + module_name, _, _ = parse_fine_tuned_lora_name( + lora_module, weights_mapper) part_name = module_name.split(".")[-1] if part_name not in expected_lora_modules: unexpected_modules.append(module_name) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 3a84a6ae1c02..f2a86a4f6980 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -1,4 +1,3 @@ -import copy import os import re from typing import List, Optional, Set, Tuple, Type, Union @@ -32,7 +31,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.utils import WeightsMapper -from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -111,37 +109,31 @@ def parse_fine_tuned_lora_name( is_lora_a whether the tensor is lora_a or lora_b. is_bias whether the tensor is lora bias. """ + # LoRA weights qualified name always starts with `base_model.model.`, + # so we remove the prefix `base_model.model.` to make the following + # mapping correctly. + parts = name.split(".") + if parts[0] == "base_model" and parts[1] == "model": + name = ".".join(parts[2:]) + else: + raise ValueError(f"Invalid LoRA weight name: {name}") - w_mapper = None if weights_mapper: - w_mapper = copy.deepcopy(weights_mapper) - # TODO: Currently only supports mapping for prefix, mapping for - # substr and subfix will be supported in the future. 
- for attr, mapping in [ - ("orig_to_new_substr", w_mapper.orig_to_new_substr), - ("orig_to_new_suffix", w_mapper.orig_to_new_suffix), - ]: - if mapping: - print_warning_once( - f"vLLM currently does not support mapping of LoRA weights " - f"for {mapping}.") - setattr(w_mapper, attr, {}) - - mapper = (lambda name: w_mapper._map_name(name) - if w_mapper is not None else name) + name = weights_mapper._map_name(name) + parts = name.split(".") if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"): - new_name = ".".join(parts[2:-2]) - return mapper(new_name), parts[-2] == "lora_A", False + new_name = ".".join(parts[:-2]) + return new_name, parts[-2] == "lora_A", False if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": - new_name = ".".join(parts[2:-1]) - return mapper(new_name), parts[-1] == "lora_embedding_A", False + new_name = ".".join(parts[:-1]) + return new_name, parts[-1] == "lora_embedding_A", False if parts[-1] == "bias": - new_name = ".".join(parts[2:-2]) - return mapper(new_name), False, True + new_name = ".".join(parts[:-2]) + return new_name, False, True raise ValueError(f"{name} is unsupported LoRA weight") diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index ef8cc5886103..10976fac2302 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -91,6 +91,8 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: packed_modules_mapping[module]) else: expected_lora_modules.append(module) + + expected_lora_modules = list(set(expected_lora_modules)) lora_path = get_adapter_absolute_path(lora_request.lora_path) # For some models like Qwen2VL, we need to use hf_to_vllm_mapper diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 63a25137f8aa..806e2a542d06 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer @@ -43,7 +44,7 @@ SequenceData) from vllm.transformers_utils.processor import get_processor -from .interfaces import SupportsMultiModal, SupportsPP +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) @@ -1121,7 +1122,33 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) -class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): +class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, + SupportsLoRA): + + packed_modules_mapping = { + "qkv_proj": ["qkv_proj"], + "mlp.gate_up_proj": ["gate_up_proj"], + "image_projector.gate_up_proj": ["gate_proj", "up_proj"] + } + + # LoRA specific attributes + supported_lora_modules = [ + # language model + "qkv_proj", + "o_proj", + "mlp.gate_up_proj", + "image_projector.gate_up_proj", + 
"down_proj", + # vision tower + "wq", + "wk", + "wv", + "wo", + "w1", + "w2", + ] + embedding_modules = {} + embedding_padding_modules = [] hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ @@ -1331,6 +1358,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weights = _get_weights_with_merged_embedding(weights) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="model", + connector="vision_backbone.image_projector", + tower_model="vision_backbone", + ) + def _get_weights_with_merged_embedding( weights: Iterable[Tuple[str, torch.Tensor]] From a452bc77db9f21378f4c98e4a50673ba8ccc9012 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 26 Dec 2024 01:10:54 +0000 Subject: [PATCH 2/5] Add unit testt Signed-off-by: Jee Jee Li --- tests/lora/test_lora_checkpoints.py | 15 ++++++++++----- vllm/lora/utils.py | 4 ++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index 9842203eb15e..537d95b025a9 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -74,7 +74,7 @@ def test_load_checkpoints( embedding_padding_modules=embed_padding_modules) -def test_lora_weights_mapping(baichuan_lora_files, ): +def test_lora_weights_mapping(baichuan_lora_files): supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules @@ -86,10 +86,14 @@ def test_lora_weights_mapping(baichuan_lora_files, ): else: expected_lora_modules.append(module) - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "model.": "language_model.model.", - }, ) - + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.": "language_model.model.", + }, + orig_to_new_substr={ + ".layers.": ".baichuan_layers.", + }, + ) lora_model = LoRAModel.from_local_checkpoint( baichuan_lora_files, expected_lora_modules, @@ -101,3 +105,4 @@ def test_lora_weights_mapping(baichuan_lora_files, ): ) for name in lora_model.loras: assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."]) + assert ".baichuan_layers." in name diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index f2a86a4f6980..efbb49e39aa1 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -109,14 +109,14 @@ def parse_fine_tuned_lora_name( is_lora_a whether the tensor is lora_a or lora_b. is_bias whether the tensor is lora bias. """ - # LoRA weights qualified name always starts with `base_model.model.`, + # LoRA weight qualified name always starts with `base_model.model.`, # so we remove the prefix `base_model.model.` to make the following # mapping correctly. 
parts = name.split(".") if parts[0] == "base_model" and parts[1] == "model": name = ".".join(parts[2:]) else: - raise ValueError(f"Invalid LoRA weight name: {name}") + raise ValueError(f"{name} is unsupported LoRA weight") if weights_mapper: name = weights_mapper._map_name(name) From fbadff4154d9b9c3814c2781f5f5da0df881d722 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 26 Dec 2024 01:16:39 +0000 Subject: [PATCH 3/5] revert molmo Signed-off-by: Jee Jee Li --- vllm/model_executor/models/molmo.py | 41 ++--------------------------- 1 file changed, 2 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 806e2a542d06..63a25137f8aa 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -36,7 +36,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer @@ -44,7 +43,7 @@ SequenceData) from vllm.transformers_utils.processor import get_processor -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) @@ -1122,33 +1121,7 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) -class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, - SupportsLoRA): - - packed_modules_mapping = { - "qkv_proj": ["qkv_proj"], - "mlp.gate_up_proj": ["gate_up_proj"], - "image_projector.gate_up_proj": ["gate_proj", "up_proj"] - } - - # LoRA specific attributes - supported_lora_modules = [ - # language model - "qkv_proj", - "o_proj", - "mlp.gate_up_proj", - "image_projector.gate_up_proj", - "down_proj", - # vision tower - "wq", - "wk", - "wv", - "wo", - "w1", - "w2", - ] - embedding_modules = {} - embedding_padding_modules = [] +class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ @@ -1358,16 +1331,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weights = _get_weights_with_merged_embedding(weights) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - def get_mm_mapping(self) -> MultiModelKeys: - """ - Get the module prefix in multimodal models - """ - return MultiModelKeys.from_string_field( - language_model="model", - connector="vision_backbone.image_projector", - tower_model="vision_backbone", - ) - def _get_weights_with_merged_embedding( weights: Iterable[Tuple[str, torch.Tensor]] From 5decbaab9e2227b86c6bf358f0d25136b214275c Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 26 Dec 2024 05:30:26 +0000 Subject: [PATCH 4/5] Fix bug Signed-off-by: Jee Jee Li --- tests/lora/test_qwen2vl.py | 9 +++++++-- vllm/lora/utils.py | 21 ++++++++++----------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/lora/test_qwen2vl.py 
b/tests/lora/test_qwen2vl.py index c8c720ff0c77..a2b1dd1ad590 100644 --- a/tests/lora/test_qwen2vl.py +++ b/tests/lora/test_qwen2vl.py @@ -7,7 +7,7 @@ from vllm.lora.request import LoRARequest from vllm.platforms import current_platform -MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct" +MODEL_PATH = "/home/sobey/Models/llm_models/BaseModel/Qwen/Qwn2-VL/Qwen2-VL-7B-Instruct" PROMPT_TEMPLATE = ( "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" @@ -22,7 +22,7 @@ # After fine-tuning with LoRA, all generated content should start begin `A`. EXPECTED_OUTPUT = [ - "A stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.", # noqa: E501 + "A red stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.", # noqa: E501 "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501 ] @@ -76,3 +76,8 @@ def test_qwen2vl_lora(qwen2vl_lora_files): output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): assert EXPECTED_OUTPUT[i].startswith(output1[i]) + + output2 = do_sample(llm, qwen2vl_lora_files, lora_id=2) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output2[i]) + diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index efbb49e39aa1..91052cec4768 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -109,30 +109,29 @@ def parse_fine_tuned_lora_name( is_lora_a whether the tensor is lora_a or lora_b. is_bias whether the tensor is lora bias. """ + + # LoRA weight qualified name always starts with `base_model.model.`, # so we remove the prefix `base_model.model.` to make the following # mapping correctly. - parts = name.split(".") - if parts[0] == "base_model" and parts[1] == "model": - name = ".".join(parts[2:]) - else: - raise ValueError(f"{name} is unsupported LoRA weight") - - if weights_mapper: - name = weights_mapper._map_name(name) + if "base_model.model." in name: + name = name.replace("base_model.model.", "") + name = weights_mapper._map_name(name) if weights_mapper else name + # recover the prefix `base_model.model.` + name = "base_model.model." 
+ name parts = name.split(".") if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"): - new_name = ".".join(parts[:-2]) + new_name = ".".join(parts[2:-2]) return new_name, parts[-2] == "lora_A", False if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": - new_name = ".".join(parts[:-1]) + new_name = ".".join(parts[2:-1]) return new_name, parts[-1] == "lora_embedding_A", False if parts[-1] == "bias": - new_name = ".".join(parts[:-2]) + new_name = ".".join(parts[2:-2]) return new_name, False, True raise ValueError(f"{name} is unsupported LoRA weight") From 68270c2aecf2160c5f914b7bf9ea003ed5aa79aa Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 26 Dec 2024 05:35:33 +0000 Subject: [PATCH 5/5] format Signed-off-by: Jee Jee Li --- tests/lora/test_qwen2vl.py | 3 +-- vllm/lora/utils.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py index a2b1dd1ad590..c9f48402b026 100644 --- a/tests/lora/test_qwen2vl.py +++ b/tests/lora/test_qwen2vl.py @@ -7,7 +7,7 @@ from vllm.lora.request import LoRARequest from vllm.platforms import current_platform -MODEL_PATH = "/home/sobey/Models/llm_models/BaseModel/Qwen/Qwn2-VL/Qwen2-VL-7B-Instruct" +MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct" PROMPT_TEMPLATE = ( "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" @@ -80,4 +80,3 @@ def test_qwen2vl_lora(qwen2vl_lora_files): output2 = do_sample(llm, qwen2vl_lora_files, lora_id=2) for i in range(len(EXPECTED_OUTPUT)): assert EXPECTED_OUTPUT[i].startswith(output2[i]) - diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 91052cec4768..d72b7638d84a 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -110,7 +110,6 @@ def parse_fine_tuned_lora_name( is_bias whether the tensor is lora bias. """ - # LoRA weight qualified name always starts with `base_model.model.`, # so we remove the prefix `base_model.model.` to make the following # mapping correctly.
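
---

Illustrative note (not part of the patches above): the net effect of this series is that `parse_fine_tuned_lora_name` strips the `base_model.model.` prefix, optionally remaps the remaining module path through a `WeightsMapper`, restores the prefix, and then extracts the module name together with the lora_A/lora_B/bias flags. Below is a minimal, self-contained sketch of that behaviour for reference only; `map_prefix` is a hypothetical stand-in for `WeightsMapper._map_name` (prefix remapping only), and the example mapping mirrors the `"model." -> "language_model.model."` style used by the Baichuan/Qwen2-VL tests in this series — it is not vLLM's actual implementation.

```python
# Sketch of the name-parsing behaviour introduced by this series (assumption:
# prefix-only remapping; vLLM's WeightsMapper also supports substr/suffix).
from typing import Dict, Optional, Tuple


def map_prefix(name: str, prefix_map: Optional[Dict[str, str]]) -> str:
    """Toy stand-in for WeightsMapper._map_name, remapping a leading prefix."""
    if not prefix_map:
        return name
    for old, new in prefix_map.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name


def parse_lora_name(
    name: str,
    prefix_map: Optional[Dict[str, str]] = None,
) -> Tuple[str, bool, bool]:
    """Return (module_name, is_lora_a, is_bias) for a LoRA tensor name."""
    if "base_model.model." in name:
        name = name.replace("base_model.model.", "")
        name = map_prefix(name, prefix_map)
        # Restore the prefix so the parts[2:...] slicing below stays valid.
        name = "base_model.model." + name

    parts = name.split(".")
    if parts[-1] == "weight" and parts[-2] in ("lora_A", "lora_B"):
        return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
    if parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
        return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False
    if parts[-1] == "bias":
        return ".".join(parts[2:-2]), False, True
    raise ValueError(f"{name} is unsupported LoRA weight")


if __name__ == "__main__":
    # Hypothetical Qwen2-VL-style remapping of the language-model prefix.
    mapping = {"model.": "language_model.model."}
    print(parse_lora_name(
        "base_model.model.model.layers.0.self_attn.qkv_proj.lora_A.weight",
        mapping))
    # -> ('language_model.model.layers.0.self_attn.qkv_proj', True, False)
```

The design choice shown here matches the final patch: the mapping is applied to the stripped name and the `base_model.model.` prefix is then re-attached, so the existing `parts[2:-2]` / `parts[2:-1]` slicing continues to work unchanged while remapped names (e.g. Qwen2-VL's `language_model.model.` prefix) are still recognised against `expected_lora_modules`.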