 
 from .interfaces import SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 
 class InternLM2MLP(nn.Module):
@@ -41,16 +42,23 @@ def __init__(
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.w2 = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
             bias=False,
-            quant_config=quant_config)
-        self.w2 = RowParallelLinear(intermediate_size,
-                                    hidden_size,
-                                    bias=False,
-                                    quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.w2",
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -75,6 +83,7 @@ def __init__(
         max_position_embeddings: int = 8192,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -108,12 +117,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.wqkv",
         )
         self.wo = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.wo",
         )
 
         self.rotary_emb = get_rope(
@@ -123,12 +134,15 @@ def __init__(
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config)
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+        )
 
     def split_qkv(self, qkv: torch.Tensor):
         seq_len = qkv.shape[0]
@@ -176,6 +190,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -192,12 +207,14 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             cache_config=cache_config,
             quant_config=quant_config,
+            prefix=f"{prefix}.attention",
         )
         self.feed_forward = InternLM2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward",
         )
         self.attention_norm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
@@ -251,8 +268,8 @@ def __init__(
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: InternLMDecoderLayer(config, cache_config,
-                                                quant_config),
+            lambda prefix: InternLMDecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.make_empty_intermediate_tensors = (
@@ -306,14 +323,19 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = InternLM2Model(config, cache_config, quant_config)
+        self.model = InternLM2Model(config,
+                                    cache_config,
+                                    quant_config,
+                                    prefix=maybe_prefix(prefix, "model"))
         self.output = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=maybe_prefix(prefix, "output"))
         if self.config.tie_word_embeddings:
             self.output.weight = self.model.tok_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
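
Net effect of the change: every quantizable submodule is now constructed with a dotted prefix, so each linear and attention layer carries a checkpoint-style name (for example "model.layers.0.attention.wqkv") that per-module quantization or KV-cache configuration can key on. The sketch below is illustrative only and rests on two assumptions not shown in this diff: that maybe_prefix simply joins a non-empty prefix and a name with a dot, and that make_layers hands each decoder layer a prefix of the form "model.layers.<i>".

# Hypothetical sketch of how the prefixes compose end to end.
# This maybe_prefix is a stand-in with an assumed implementation,
# not the helper actually imported from .utils.
def maybe_prefix(prefix: str, name: str) -> str:
    """Join prefix and name with a dot, unless the prefix is empty."""
    return f"{prefix}.{name}" if prefix else name

# Top-level model: InternLM2ForCausalLM is built with prefix="".
model_prefix = maybe_prefix("", "model")        # "model"

# Assumed per-layer prefix produced by make_layers for layer 0.
layer_prefix = f"{model_prefix}.layers.0"       # "model.layers.0"

# Names the constructors in this diff would then produce for layer 0.
attn_prefix = f"{layer_prefix}.attention"       # "model.layers.0.attention"
print(f"{attn_prefix}.wqkv")                    # model.layers.0.attention.wqkv
print(f"{layer_prefix}.feed_forward.w2")        # model.layers.0.feed_forward.w2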