Commit f5e6e7e

fix: resolve GPU memory imbalance in concurrent weight loading
Signed-off-by: Necofish <[email protected]>
1 parent 1f39a11 commit f5e6e7e

File tree

1 file changed: +3 −0 lines changed
tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,7 @@
 
 from tensorrt_llm.lora_manager import HfLoraLoader
 from tensorrt_llm.models.convert_utils import split_matrix_tp
+from tensorrt_llm._utils import mpi_rank
 
 from ...logger import logger
 from ...models.modeling_utils import QuantConfig
@@ -777,6 +778,7 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
     }
 
     def load_single_module(name, module):
+        torch.cuda.set_device(mpi_rank())
         if len(module._parameters) > 0:
             # skip load weights if module is in skip_modules
             if any(skip_module in name for skip_module in skip_modules):
@@ -873,6 +875,7 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
         logger.info(f"Renamed weights with params_map: {params_map}")
 
     def load_single_module(name, module):
+        torch.cuda.set_device(mpi_rank())
         if len(module._parameters) > 0:
             if weight_mapper.should_skip_module(name):
                 return
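
The added torch.cuda.set_device(mpi_rank()) call pins each weight-loading worker thread to this rank's GPU. Below is a minimal sketch of the behavior the fix addresses, assuming load_single_module runs in a thread pool: PyTorch tracks the current CUDA device per thread and new threads start on device 0, so allocations made with the bare "cuda" device land on GPU 0 unless each worker pins its device first. The rank value, module names, and shapes here are made up for illustration; in the real code the rank comes from tensorrt_llm._utils.mpi_rank().

# Minimal sketch (illustrative, not taken from the commit): requires >1 GPU.
from concurrent.futures import ThreadPoolExecutor

import torch


def load_single_module(name: str, rank: int, pin_device: bool) -> int:
    if pin_device:
        # Mirrors the fix: make this worker thread's current CUDA device
        # the GPU that belongs to this rank.
        torch.cuda.set_device(rank)
    # Bare "cuda" resolves to the thread's current device (0 in a fresh thread).
    weight = torch.empty(1024, 1024, device="cuda")
    return weight.device.index


def demo(rank: int, pin_device: bool) -> None:
    names = [f"layer.{i}.weight" for i in range(4)]
    with ThreadPoolExecutor(max_workers=4) as pool:
        devices = list(pool.map(lambda n: load_single_module(n, rank, pin_device), names))
    print(f"pin_device={pin_device}: weights landed on GPUs {sorted(set(devices))}")


if __name__ == "__main__" and torch.cuda.device_count() > 1:
    # Pretend this process is tensor-parallel rank 1.
    demo(rank=1, pin_device=False)  # allocations pile onto GPU 0
    demo(rank=1, pin_device=True)   # allocations go to this rank's GPU 1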
