
Commit c83919c

[Model] Add Internlm2 LoRA support (#5064)

Authored by Isotr0py
Signed-off-by: Isotr0py <[email protected]>

1 parent 98f47f2

File tree: 2 files changed (+21, -3 lines)


docs/source/models/supported_models.rst

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ Text Generation
   * - :code:`InternLM2ForCausalLM`
     - InternLM2
     - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
-    -
+    - ✅︎
     - ✅︎
   * - :code:`JAISLMHeadModel`
     - Jais

vllm/model_executor/models/internlm2.py

Lines changed: 20 additions & 2 deletions
@@ -27,7 +27,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -319,7 +319,21 @@ def forward(
         return hidden_states


-class InternLM2ForCausalLM(nn.Module, SupportsPP):
+class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+    packed_modules_mapping = {
+        "wqkv": ["wqkv"],
+        "gate_up_proj": ["w1", "w3"],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "wqkv",
+        "wo",
+        "gate_up_proj",
+        "w2",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []

     def __init__(self,
                  *,
@@ -329,8 +343,12 @@ def __init__(self,
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
         self.config = config
         self.quant_config = quant_config
+        self.lora_config = lora_config
+
         self.model = model_type(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
         self.output = ParallelLMHead(config.vocab_size,
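The packed_modules_mapping above tells vLLM's LoRA machinery that the fused gate_up_proj projection is assembled from the checkpoint's separate w1 and w3 weights, so adapter deltas trained against w1/w3 can be stacked onto the fused module (wqkv is already a single fused tensor in InternLM2 checkpoints, hence it maps to itself). As a usage illustration, here is a minimal offline-inference sketch exercising the new support; the adapter path and prompt are hypothetical, while the LLM and LoRARequest calls follow vLLM's standard LoRA API:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Enable LoRA at engine construction; InternLM2 ships custom modeling
# code on the HF Hub, so trust_remote_code is required.
llm = LLM(
    model="internlm/internlm2-chat-7b",
    trust_remote_code=True,
    enable_lora=True,
)

# Hypothetical local path to a LoRA adapter trained for InternLM2.
lora_path = "/path/to/internlm2-lora-adapter"

outputs = llm.generate(
    ["What is the capital of France?"],
    SamplingParams(temperature=0.0, max_tokens=32),
    # LoRARequest(adapter name, unique integer id, adapter path)
    lora_request=LoRARequest("internlm2-adapter", 1, lora_path),
)
print(outputs[0].outputs[0].text)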
