3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/models/modeling_utils.py
@@ -11,6 +11,7 @@
 from torch.utils._pytree import tree_any_only
 from tqdm import tqdm

+from tensorrt_llm._utils import mpi_rank
 from tensorrt_llm.lora_manager import HfLoraLoader
 from tensorrt_llm.models.convert_utils import split_matrix_tp
@@ -844,6 +845,7 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
     }

     def load_single_module(name, module):
+        torch.cuda.set_device(mpi_rank())
         if len(module._parameters) > 0:
             # skip load weights if module is in skip_modules
             if any(skip_module in name for skip_module in skip_modules):
@@ -940,6 +942,7 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
     logger.info(f"Renamed weights with params_map: {params_map}")

     def load_single_module(name, module):
+        torch.cuda.set_device(mpi_rank())
🛠️ Refactor suggestion

Consistent fix applied to the v2 implementation - good practice.

This change mirrors the fix in _load_weights_impl and ensures both weight loading implementations handle GPU device context correctly during concurrent execution.
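
For background, the current CUDA device in PyTorch is per-thread state, so worker threads spawned for concurrent weight loading default to cuda:0 unless each worker re-selects the rank's device. A minimal sketch of that pattern (the executor, tensor shape, and rank stand-in below are illustrative, not taken from the PR):

import torch
from concurrent.futures import ThreadPoolExecutor

def load_one(name, device_index):
    # torch.cuda.set_device is thread-local: a fresh pool thread would
    # otherwise allocate on cuda:0 regardless of the process's MPI rank.
    torch.cuda.set_device(device_index)
    return torch.empty(16, device="cuda")

if torch.cuda.is_available():
    rank = torch.cuda.device_count() - 1  # stand-in for mpi_rank()
    with ThreadPoolExecutor(max_workers=4) as pool:
        tensors = list(pool.map(lambda i: load_one(f"layer.{i}", rank), range(4)))
    assert all(t.device.index == rank for t in tensors)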

Consider consolidating the device-setting logic into a helper function that also adds error handling:

+def _set_cuda_device_for_worker():
+    """Set CUDA device to current MPI rank with error handling."""
+    try:
+        torch.cuda.set_device(mpi_rank())
+    except (RuntimeError, ValueError) as e:
+        logger.warning(f"Failed to set CUDA device to MPI rank {mpi_rank()}: {e}")
+        # Handle appropriately based on requirements

Then use this helper in both load_single_module functions to maintain consistency and reduce code duplication.
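
A sketch of how the consolidated wiring might look (the helper name follows the suggestion above; the logger import path and the trimmed function bodies are assumptions, not the actual TensorRT-LLM code):

import torch
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.logger import logger  # assumed import path for the module logger

def _set_cuda_device_for_worker():
    """Set the CUDA device to the current MPI rank, warning instead of raising."""
    try:
        torch.cuda.set_device(mpi_rank())
    except (RuntimeError, ValueError) as e:
        logger.warning(f"Failed to set CUDA device to MPI rank {mpi_rank()}: {e}")

def _load_weights_impl(model, weights):
    def load_single_module(name, module):
        _set_cuda_device_for_worker()  # replaces the inline torch.cuda.set_device call
        ...  # existing per-module weight loading

def _load_weights_impl_v2(model, weights, weight_mapper):
    def load_single_module(name, module):
        _set_cuda_device_for_worker()  # same entry point keeps both paths consistent
        ...  # existing v2 weight loading via weight_mapper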

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_utils.py at line 878, extract the
torch.cuda.set_device(mpi_rank()) call into a shared helper function. Refactor
the load_single_module functions in both _load_weights_impl and
_load_weights_impl_v2 to call this helper instead of duplicating the
device-setting logic, ensuring consistent GPU device handling and reducing
code duplication.

         if len(module._parameters) > 0:
             if weight_mapper.should_skip_module(name):
                 return