@@ -1,5 +1,6 @@
 import contextlib
 import math
+import os
 import time
 from dataclasses import dataclass
 from typing import Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union
@@ -635,6 +636,42 @@ def filter_weights(prefix, weights: Dict):
     return result


+def run_concurrently(func,
+                     args_list,
+                     reduce_func=None,
+                     pbar=None,
+                     num_workers=None):
+    """
+    Run a function concurrently over a list of argument tuples.
+    func: the function to run concurrently.
+    args_list: a list of tuples of arguments for the function.
+    reduce_func: an optional function applied to each result.
+    pbar: an optional tqdm progress bar, updated once per completed task.
+    num_workers: an optional cap on the number of worker threads.
+    """
+    from concurrent import futures
+    with futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+        # Submit all tasks
+        future_to_arg = {
+            executor.submit(func, *arg): arg
+            for arg in args_list
+        }
+
+        # Process completed tasks as they finish
+        for future in futures.as_completed(future_to_arg):
+            arg = future_to_arg[future]
+            try:
+                result = future.result()
+                if reduce_func:
+                    reduce_func(result)
+                if pbar:
+                    pbar.update(1)
+            except Exception as e:
+                logger.error(
+                    f"Error executing {func.__name__} with args {arg}: {str(e)}")
+                raise
+
+
 def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
                        weights: Dict,
                        skip_modules: List[str] = [],
@@ -659,30 +696,29 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
         'gate_up_proj': ['gate_proj', 'up_proj']
     }

-    for name, module in tqdm(list(model.named_modules()),
-                             desc="Loading weights"):
+    def load_single_module(name, module):
         if len(module._parameters) > 0:
             # skip load weights if module is in skip_modules
             if any(skip_module in name for skip_module in skip_modules):
-                continue
+                return

             # skip load weights if tie word embeddings is enabled and layer is lm_head
             if model.config.tie_word_embeddings and name.startswith("lm_head"):
-                continue
+                return

             # Skip loading weights for embedding and lm_head if LoRA is enabled and has custom values
             if hasattr(model, "model") and hasattr(
                     model.model, 'has_custom_embed_tokens'
             ) and model.model.has_custom_embed_tokens and name == "model.embed_tokens":
-                continue
+                return
             if hasattr(model, 'has_custom_lm_head'
                        ) and model.has_custom_lm_head and name == "lm_head":
-                continue
+                return

             names = name.split('.')
             # WAR: better solution is that llama has its own load_weights function.
             if names[-1] == 'next_layer_layernorm':
-                continue
+                return
             if names[-1] in params_map:
                 module_weights = []
                 for new_name in params_map[names[-1]]:
@@ -713,3 +749,14 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
                 for n, p in module._parameters.items():
                     if p is not None:
                         p.data.copy_(module_weights[n][:])
+
+    if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
+                      False) in ["True", "true", "1", "yes", "y"]:
+        for name, module in tqdm(list(model.named_modules()),
+                                 desc="Loading weights"):
+            load_single_module(name, module)
+    else:
+        pbar = tqdm(list(model.named_modules()),
+                    desc="Loading weights concurrently")
+        args_list = [(name, module) for name, module in model.named_modules()]
+        run_concurrently(load_single_module, args_list, pbar=pbar)
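
For reference, a minimal sketch of how the new run_concurrently helper is meant to be driven, outside the weight-loading path it was written for. The load_part worker and its argument tuples below are hypothetical placeholders, not part of this change; only run_concurrently and its keyword arguments come from the diff above.

    from tqdm import tqdm

    def load_part(shard_path, device):
        # Hypothetical stand-in for the real per-module worker
        # (the commit passes load_single_module here).
        print(f"loading {shard_path} onto {device}")

    args_list = [("shard-00001.bin", "cuda:0"), ("shard-00002.bin", "cuda:0")]
    pbar = tqdm(total=len(args_list), desc="Loading shards")
    run_concurrently(load_part, args_list, pbar=pbar, num_workers=4)

A thread pool (rather than a process pool) looks like a deliberate fit here: the workers must mutate the same in-memory model, and the per-module work is dominated by tensor copies in native code, so threads can overlap usefully without the cost of inter-process transfer.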
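The concurrent path is on by default; the environment variable acts as an escape hatch. A quick way to force the previous serial behaviour, assuming the accepted truthy strings shown in the diff:

    import os

    # Any of "True", "true", "1", "yes", "y" re-enables the serial tqdm loop.
    os.environ["TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL"] = "1"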