
Commit c93b8aa

arekay authored and nvzhihanj committed

[ModelLoad] Concurrent load model (NVIDIA#5291)

Signed-off-by: Rashid K <[email protected]>
Co-authored-by: Zhihan Jiang <[email protected]>

1 parent 11e777b commit c93b8aa

File tree

3 files changed: +87, -15 lines

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 54 additions & 7 deletions
@@ -1,5 +1,6 @@
 import contextlib
 import math
+import os
 import time
 from dataclasses import dataclass
 from typing import Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union
@@ -635,6 +636,42 @@ def filter_weights(prefix, weights: Dict):
     return result


+def run_concurrently(func,
+                     args_list,
+                     reduce_func=None,
+                     pbar=None,
+                     num_workers=None):
+    """
+    Run a function concurrently with a list of arguments.
+    func: the function to run concurrently.
+    args_list: a list of tuples of arguments for the function.
+    reduce_func: an optional function to reduce the results.
+    pbar: an optional tqdm progress bar.
+    """
+    from concurrent import futures
+    with futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+        # Submit all tasks
+        future_to_result = {
+            executor.submit(func, *arg): arg
+            for arg in args_list
+        }
+
+        # Process completed tasks as they finish
+        for result in futures.as_completed(future_to_result):
+            arg = future_to_result[result]
+            try:
+                part_weights = result.result()
+                if reduce_func:
+                    reduce_func(part_weights)
+                if pbar:
+                    pbar.update(1)
+            except Exception as e:
+                logger.error(
+                    f"Error executing {func.__name__} with args {arg}: {str(e)}"
+                )
+                raise
+
+
 def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
                        weights: Dict,
                        skip_modules: List[str] = [],
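
For context, a minimal usage sketch of the run_concurrently helper introduced above; the toy square function and results dict are illustrative only and appear nowhere in the commit. Each entry of args_list is a tuple that is unpacked into the function's positional arguments, and reduce_func folds each result into an accumulator as futures complete:

    # Sketch only: `square` and `results` are hypothetical names.
    from tqdm import tqdm

    from tensorrt_llm._torch.models.modeling_utils import run_concurrently

    def square(x):
        return {x: x * x}

    results = {}
    pbar = tqdm(total=4, desc="Squaring in parallel")
    # Each argument set is a tuple; it is unpacked as func(*arg).
    run_concurrently(square, [(1, ), (2, ), (3, ), (4, )],
                     reduce_func=results.update,
                     pbar=pbar)
    assert results == {1: 1, 2: 4, 3: 9, 4: 16}

Because ThreadPoolExecutor uses threads rather than processes, this mainly speeds up I/O-bound work such as reading checkpoint files from disk; a CPU-bound func would still serialize on the GIL.
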
@@ -659,30 +696,29 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
         'gate_up_proj': ['gate_proj', 'up_proj']
     }

-    for name, module in tqdm(list(model.named_modules()),
-                             desc="Loading weights"):
+    def load_single_module(name, module):
         if len(module._parameters) > 0:
             # skip load weights if module is in skip_modules
             if any(skip_module in name for skip_module in skip_modules):
-                continue
+                return

             # skip load weights if tie word embeddings is enabled and layer is lm_head
             if model.config.tie_word_embeddings and name.startswith("lm_head"):
-                continue
+                return

             # Skip loading weights for embedding and lm_head if LoRA is enabled and has custom values
             if hasattr(model, "model") and hasattr(
                     model.model, 'has_custom_embed_tokens'
             ) and model.model.has_custom_embed_tokens and name == "model.embed_tokens":
-                continue
+                return
             if hasattr(model, 'has_custom_lm_head'
                        ) and model.has_custom_lm_head and name == "lm_head":
-                continue
+                return

             names = name.split('.')
             # WAR: better solution is that llama has its own load_weights function.
             if names[-1] == 'next_layer_layernorm':
-                continue
+                return
             if names[-1] in params_map:
                 module_weights = []
                 for new_name in params_map[names[-1]]:
@@ -713,3 +749,14 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
             for n, p in module._parameters.items():
                 if p is not None:
                     p.data.copy_(module_weights[n][:])
+
+    if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
+                      False) in ["True", "true", "1", "yes", "y"]:
+        for name, module in tqdm(list(model.named_modules()),
+                                 desc="Loading weights"):
+            load_single_module(name, module)
+    else:
+        pbar = tqdm(list(model.named_modules()),
+                    desc="Loading weights concurrently")
+        args_list = [(name, module) for name, module in model.named_modules()]
+        run_concurrently(load_single_module, args_list, pbar=pbar)
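
The concurrent path is the default and opt-out: when the variable is unset, os.environ.get returns False, which is not in the truthy-string list, so the else branch runs. A minimal sketch of forcing the sequential fallback (model construction elided):

    import os

    # Must be set before _load_weights_impl runs, e.g. at process start.
    os.environ["TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL"] = "1"
    # ... build the model and load weights as usual; the tqdm loop labeled
    # "Loading weights" now executes module by module, which can help when
    # isolating a weight-loading failure.
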

tensorrt_llm/_torch/modules/linear.py

Lines changed: 6 additions & 0 deletions
@@ -364,6 +364,9 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         v_weight = v_weight.to(module.dtype) * weight_scale[2]

         fused_weight = torch.cat((q_weight, k_weight, v_weight))
+        if module.weight_scale.device != fused_weight.device:
+            module.weight_scale = Parameter(
+                module.weight_scale.data.to(fused_weight.device))
         fused_weight = (fused_weight / module.weight_scale).to(
             torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
@@ -385,6 +388,9 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         gate_weight = gate_weight.to(module.dtype) * weight_scale[0]
         up_weight = up_weight.to(module.dtype) * weight_scale[1]
         fused_weight = torch.cat((gate_weight, up_weight))
+        if module.weight_scale.device != fused_weight.device:
+            module.weight_scale = Parameter(
+                module.weight_scale.data.to(fused_weight.device))
         fused_weight = (fused_weight / module.weight_scale).to(
             torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
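
The two guards added above appear to exist because, with weights now loaded by worker threads, fused_weight can be materialized on a different device than the module's weight_scale, and PyTorch rejects elementwise ops across devices (a 0-dim CPU tensor is the one exception). A minimal reproduction, assuming a CUDA device is available; the tensor shapes are illustrative:

    import torch
    from torch.nn import Parameter

    weight_scale = Parameter(torch.tensor([2.0]))    # 1-element tensor on CPU
    fused_weight = torch.ones(4, 4, device="cuda")   # weights landed on GPU

    try:
        _ = fused_weight / weight_scale              # RuntimeError: tensors on different devices
    except RuntimeError as e:
        print(e)

    # The fix mirrors the committed guard: rewrap the scale on the weight's device.
    if weight_scale.device != fused_weight.device:
        weight_scale = Parameter(weight_scale.data.to(fused_weight.device))
    _ = fused_weight / weight_scale                  # now succeeds
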

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 27 additions & 8 deletions
@@ -18,6 +18,7 @@
 import safetensors
 import torch
 import torch._dynamo.config
+import tqdm

 import tensorrt_llm.bindings.internal.userbuffers as ub
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
@@ -48,7 +49,7 @@
 from ..model_config import ModelConfig, MoeLoadBalancerConfig
 from ..models import AutoModelForCausalLM
 from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode,
-                                     timing)
+                                     run_concurrently, timing)
 from ..modules.fused_moe.moe_load_balancer import (
     MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer)
 from ..speculative import SpecConfig, SpecMetadata, get_spec_metadata
@@ -180,19 +181,28 @@ def load_weights(checkpoint_dir: str):
                 f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files."
             )
             prefetch_files(weight_files)
-            for file in weight_files:
-                logger.info(f"Loading {file}")
-                part_weights = safetensors.torch.load_file(file)
-                weights.update(part_weights)
+
+            def load_safetensors_file(file):
+                return safetensors.torch.load_file(file)
+
+            pbar = tqdm.tqdm(total=len(weight_files),
+                             desc="Loading safetensors weights in parallel")
+
+            # Note that the function is called with a tuple of arguments, hence we need to wrap the arguments in a tuple via [(w,) for w in weight_files]
+            # specifically the comma right after the w is important to make it a tuple.
+            run_concurrently(load_safetensors_file, [(w, ) for w in weight_files],
+                             reduce_func=weights.update,
+                             pbar=pbar)
+
             return weights

         weight_files = glob.glob(f"{checkpoint_dir}/*.bin")
         if not weight_files:
             weight_files = glob.glob(f"{checkpoint_dir}/*.pth")

         if weight_files:
-            for file in weight_files:
-                # try mmap first, if failed, turn off mmap
+
+            def load_bin_or_path_file(file):
                 try:
                     part_weights = torch.load(file,
                                               weights_only=True,
@@ -206,7 +216,16 @@ def load_weights(checkpoint_dir: str):
                                               weights_only=True,
                                               map_location='cpu',
                                               mmap=False)
-                weights.update(part_weights)
+                finally:
+                    return part_weights
+
+            pbar = tqdm.tqdm(total=len(weight_files),
+                             desc="Loading bin weights in parallel")
+            # Note that the function is called with a tuple of arguments, hence we need to wrap the arguments in a tuple via [(w,) for w in weight_files]
+            # specifically the comma right after the w is important to make it a tuple.
+            run_concurrently(load_bin_or_path_file, [(w, ) for w in weight_files],
+                             reduce_func=weights.update,
+                             pbar=pbar)
             return weights

         raise RuntimeError(f"No weight files found in {checkpoint_dir}.")
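
The [(w, ) for w in weight_files] wrapping that both comments call out generalizes: run_concurrently unpacks every entry with func(*arg), so a one-parameter function needs one-element tuples, while a multi-parameter function takes ordinary tuples. A toy illustration; load_one and load_shard are hypothetical names, not functions in this codebase:

    # Each entry in args_list is unpacked as func(*arg).
    def load_one(path):
        print(f"loading {path}")

    def load_shard(path, device):
        print(f"loading {path} onto {device}")

    run_concurrently(load_one, [(w, ) for w in ["a.bin", "b.bin"]])
    run_concurrently(load_shard, [("a.bin", "cpu"), ("b.bin", "cuda:0")])
    # A bare (w) is just w, not a tuple: submit(func, *"a.bin") would splat
    # the string into single characters and fail with a TypeError.
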
