diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py
index 95cc9cac6be..4a07b1af899 100755
--- a/tensorrt_llm/_torch/models/modeling_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_utils.py
@@ -11,6 +11,7 @@
 from torch.utils._pytree import tree_any_only
 from tqdm import tqdm
 
+from tensorrt_llm._utils import mpi_rank
 from tensorrt_llm.lora_manager import HfLoraLoader
 from tensorrt_llm.models.convert_utils import split_matrix_tp
 
@@ -844,6 +845,7 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
     }
 
     def load_single_module(name, module):
+        torch.cuda.set_device(mpi_rank())
         if len(module._parameters) > 0:
             # skip load weights if module is in skip_modules
             if any(skip_module in name for skip_module in skip_modules):
@@ -940,6 +942,7 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
     logger.info(f"Renamed weights with params_map: {params_map}")
 
     def load_single_module(name, module):
+        torch.cuda.set_device(mpi_rank())
         if len(module._parameters) > 0:
             if weight_mapper.should_skip_module(name):
                 return