 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsQuant, SupportsV0Only
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

 logger = logging.get_logger(__name__)

@@ -700,7 +700,8 @@ def forward(

 class BartModel(nn.Module, SupportsQuant):
     _tied_weights_keys = [
-        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
     ]

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -763,10 +764,54 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,

         return decoder_outputs

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        other_weights = []
+        loaded_stacked_params = []
+        model_params_dict = dict(self.named_parameters())
+
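+        # First pass: q/k/v shards map onto the fused qkv_proj parameter and
+        # must go through that parameter's weight_loader with a shard id;
+        # every other weight is collected for AutoWeightsLoader below.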
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if name not in model_params_dict:
+                    continue
+                param = model_params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                loaded_stacked_params.append(name)
+                break
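+            # for/else: the else branch runs only when no stacked mapping
+            # matched this checkpoint name.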
+            else:
+                if name in model_params_dict:
+                    other_weights.append((name, loaded_weight))
+
+        loader = AutoWeightsLoader(self)
+        loaded_params = loader.load_weights(other_weights)
+        loaded_params.update(loaded_stacked_params)
+        return loaded_params
+

 class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
-    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
-    base_model_prefix = "model"
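+    # Remap Hugging Face checkpoint names onto the vLLM module tree (encoder,
+    # decoder and shared embeddings live under `self.model`) and map the
+    # legacy beta/gamma/LayerNorm parameter names to bias/weight/layernorm.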
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "decoder.": "model.decoder.",
+            "encoder.": "model.encoder.",
+            "shared.": "model.shared."
+        },
+        orig_to_new_substr={
+            "beta": "bias",
+            "gamma": "weight",
+            "LayerNorm": "layernorm",
+        },
+    )

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

@@ -789,7 +834,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.lm_head = BartParallelLMHead(config.vocab_size,
                                           config.d_model,
                                           embed_scale=embed_scale)
-
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)

@@ -828,111 +872,37 @@ def compute_logits(
                                        sampling_metadata)
         return logits

-    stacked_params_mapping = {
-        "q_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "q",
-        },
-        "k_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "k",
-        },
-        "v_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "v",
-        },
-    }
-
-    params_mapping = {
-        "beta": "bias",
-        "gamma": "weight",
-        "LayerNorm": "layernorm",
-    }
-
-    def _rename_key(self, key: str):
-        prefix = f"{self.base_model_prefix}."
-        key = key[len(prefix):] if key.startswith(prefix) else key
-
-        for src, dst in self.params_mapping.items():
-            key = key.replace(src, dst)
-
-        return key
-
-    def _rename_stacked_param(
-        self,
-        name: str,
-    ) -> tuple[str, Optional[str]]:
-        for key, mapping in self.stacked_params_mapping.items():
-            if key in name:
-                name = name.replace(key, mapping["param_name"])
-                return name, mapping["shard_id"]
-        return name, None
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-
-        model_params_dict = dict(self.model.named_parameters())
-        top_params_dict = dict(self.named_parameters())
-
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         weights_tuple_list = list(weights)

         shared_embedding_weight = None
-        shared_embedding_shard_id = None
-
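+        # BART ties encoder/decoder embed_tokens and lm_head to one shared
+        # embedding; intercept that weight here and apply it after the bulk
+        # of the checkpoint has been loaded.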
         for name, loaded_weight in weights_tuple_list:
-
-            name = self._rename_key(name)
-            name, shard_id = self._rename_stacked_param(name)
-
             if ('shared.weight' in name
                     or 'encoder.embed_tokens.weight' in name
                     or 'decoder.embed_tokens.weight' in name
                     or 'lm_head.weight' in name):
                 assert shared_embedding_weight is None, (
                     "Conflicting embedding weights.")
                 shared_embedding_weight = loaded_weight
-                shared_embedding_shard_id = shard_id
-            else:
-                # Skip the specific downstream task weight.
-                if name.startswith('cls.'):
-                    continue
-                # use Pooler instead.
-                if name.startswith('pooler.'):
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in model_params_dict:
-                    continue

-                param = model_params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                if shard_id:
-                    weight_loader(param, loaded_weight, shard_id)
-                else:
-                    weight_loader(param, loaded_weight)
-
-        # Assign shared weight values
-        encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
-        encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
-                                           default_weight_loader)
-
-        decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
-        decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
-                                           default_weight_loader)
-
-        lm_head_in_param = top_params_dict['lm_head.weight']
-        lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
-                                           default_weight_loader)
-
-        assert shared_embedding_weight is not None
-
-        if shared_embedding_shard_id:
-            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-        else:
-            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
-            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
-            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
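+        # Everything except the intercepted embedding weights goes through
+        # AutoWeightsLoader; cls.* and pooler.* are downstream-task heads with
+        # no counterpart in this model.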
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["cls.", "pooler."]),
+        )
+        loaded_params = loader.load_weights(weights_tuple_list,
+                                            mapper=self.hf_to_vllm_mapper)
+
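+        # Tie the shared embedding: load it into lm_head once, then point the
+        # encoder/decoder embed_tokens at the same parameter.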
+        if shared_embedding_weight is not None:
+            weight_loader = getattr(self.lm_head.weight, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(self.lm_head.weight, shared_embedding_weight)
+
+            self.model.encoder.embed_tokens.weight = self.lm_head.weight
+            self.model.decoder.embed_tokens.weight = self.lm_head.weight
+            loaded_params.update({
+                'model.encoder.embed_tokens.weight', 'lm_head.weight',
+                'model.decoder.embed_tokens.weight'
+            })
+
+        return loaded_params