 # ruff: noqa: SIM117
 import collections
 import copy
+import dataclasses
 import fnmatch
 import glob
 import json
 import math
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Optional, Tuple, Type
+from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
+                    Type, cast)
 
 import gguf
 import huggingface_hub
@@ -207,6 +209,22 @@ def load_model(self, *, model_config: ModelConfig,
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
+    @dataclasses.dataclass
+    class Source:
+        """A source for weights."""
+
+        model_or_path: str
+        """The model ID or path."""
+
+        revision: Optional[str]
+        """The optional model revision."""
+
+        prefix: str = ""
+        """A prefix to prepend to all weight names."""
+
+        fall_back_to_pt: bool = True
+        """Whether .pt weights can be used."""
+
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
         if load_config.model_loader_extra_config:
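The new `Source` dataclass bundles everything a single checkpoint lookup needs. A minimal sketch of constructing one, assuming the `DefaultModelLoader` class from this diff is importable (the model ID below is purely illustrative):

```python
# Hypothetical example; "facebook/opt-125m" is just an illustrative model ID.
source = DefaultModelLoader.Source(
    model_or_path="facebook/opt-125m",
    revision=None,          # no default value, so it must be passed
    prefix="",              # no renaming for these weights
    fall_back_to_pt=True,   # allow *.pt files if safetensors are absent
)
```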
@@ -313,17 +331,16 @@ def _prepare_weights(self, model_name_or_path: str,
         return hf_folder, hf_weights_files, use_safetensors
 
     def _get_weights_iterator(
-        self, model_name_or_path: str, revision: Optional[str],
-        fall_back_to_pt: bool
+        self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
-            model_name_or_path, revision, fall_back_to_pt)
+            source.model_or_path, source.revision, source.fall_back_to_pt)
         if self.load_config.load_format == LoadFormat.NPCACHE:
             # Currently np_cache only supports *.bin checkpoints
             assert use_safetensors is False
             weights_iterator = np_cache_weights_iterator(
-                model_name_or_path, self.load_config.download_dir, hf_folder,
+                source.model_or_path, self.load_config.download_dir, hf_folder,
                 hf_weights_files)
         elif use_safetensors:
             weights_iterator = safetensors_weights_iterator(hf_weights_files)
@@ -341,7 +358,29 @@ def _xla_weights_iterator(iterator: Generator):
                     xm.mark_step()
 
             weights_iterator = _xla_weights_iterator(weights_iterator)
-        return weights_iterator
+
+        # Apply the prefix to every weight name before yielding.
+        return ((source.prefix + name, tensor)
+                for (name, tensor) in weights_iterator)
+
+    def _get_all_weights(
+        self,
+        model_config: ModelConfig,
+        model: nn.Module,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Yield the primary weights, then any secondary weight sources."""
+        primary_weights = DefaultModelLoader.Source(
+            model_config.model,
+            model_config.revision,
+            prefix="",
+            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
+                                    True))
+        yield from self._get_weights_iterator(primary_weights)
+
+        secondary_weights = cast(Iterable[DefaultModelLoader.Source],
+                                 getattr(model, "secondary_weights", ()))
+        for source in secondary_weights:
+            yield from self._get_weights_iterator(source)
 
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model,
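The generator expression at the end of `_get_weights_iterator` is the only place the new `prefix` field is consumed: each `(name, tensor)` pair is re-emitted with the prefix prepended to the name. A self-contained sketch of that renaming, with invented weight names and prefix:

```python
import torch

prefix = "draft."  # e.g. from Source(prefix="draft.")
raw = [("lm_head.weight", torch.zeros(2)),
       ("embed_tokens.weight", torch.zeros(2))]

# Same renaming the loader applies before yielding.
renamed = ((prefix + name, tensor) for name, tensor in raw)
print([name for name, _ in renamed])
# ['draft.lm_head.weight', 'draft.embed_tokens.weight']
```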
@@ -360,13 +399,8 @@ def load_model(self, *, model_config: ModelConfig,
             model = _initialize_model(model_config, self.load_config,
                                       lora_config, cache_config,
                                       scheduler_config)
-            model.load_weights(
-                self._get_weights_iterator(model_config.model,
-                                           model_config.revision,
-                                           fall_back_to_pt=getattr(
-                                               model,
-                                               "fall_back_to_pt_during_load",
-                                               True)), )
+
+            model.load_weights(self._get_all_weights(model_config, model))
 
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
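With `load_model` now delegating to `_get_all_weights`, a model can opt into extra checkpoints simply by exposing a `secondary_weights` attribute, which `_get_all_weights` reads via `getattr` with an empty default. A hedged sketch of what such a model might look like; the class name, repo ID, and prefix below are invented for illustration:

```python
import torch.nn as nn

class SpeculatorWrapper(nn.Module):
    """Hypothetical model that also loads a draft model's weights."""

    def __init__(self):
        super().__init__()
        # Read by DefaultModelLoader._get_all_weights via getattr().
        self.secondary_weights = [
            DefaultModelLoader.Source("example-org/draft-model",
                                      revision=None,
                                      prefix="draft."),
        ]
```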