82 changes: 46 additions & 36 deletions tensorrt_llm/_torch/modules/fused_moe.py
@@ -1,6 +1,6 @@
 import math
 import os
-import threading
+from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from typing import Dict, List, NamedTuple, Optional, Union
 
@@ -1278,45 +1278,55 @@ def load_expert_w2_weight(w2_weight,
         # Even though CPython has global interpreter lock (GIL),
         # it's still faster to load weights in parallel because it can utilize
         # CPU memory bandwidth better.
-        threads = []
-
-        for expert_id in range(self.expert_start, self.expert_end):
-            expert_idx = expert_id - self.expert_start
-
-            if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
-                w1_weight = weights[f"{expert_id}.w1.weight"]
-                w3_weight = weights[f"{expert_id}.w3.weight"]
-                w2_weight = weights[f"{expert_id}.w2.weight"]
-            elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
-                w1_w3_weight = weights["gate_up_proj"][expert_id].transpose(
-                    0, 1)
-                w1_weight, w3_weight = w1_w3_weight.chunk(2, dim=0)
-                w2_weight = weights["down_proj"][expert_id].transpose(
-                    0, 1).contiguous()
-            else:
-                raise NotImplementedError(
-                    f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
-                )
-
-            is_trtllm_nvfp4 = self.is_trtllm(
-            ) and self.quant_config.quant_mode.has_nvfp4()
-
-            thread = threading.Thread(target=load_expert_w3_w1_weight,
-                                      args=(w1_weight, w3_weight,
-                                            self.w3_w1_weight.data[expert_idx],
-                                            is_trtllm_nvfp4))
-            thread.start()
-            threads.append(thread)
-
-            thread = threading.Thread(target=load_expert_w2_weight,
-                                      args=(w2_weight,
-                                            self.w2_weight.data[expert_idx],
-                                            is_trtllm_nvfp4))
-            thread.start()
-            threads.append(thread)
-
-        for thread in threads:
-            thread.join()
+        max_workers = min(
+            (self.expert_end - self.expert_start) * 2,
+            os.cpu_count() * 2,
+            16,
+        )
+
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = []
+
+            for expert_id in range(self.expert_start, self.expert_end):
+                expert_idx = expert_id - self.expert_start
+
+                if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+                    w1_weight = weights[f"{expert_id}.w1.weight"]
+                    w3_weight = weights[f"{expert_id}.w3.weight"]
+                    w2_weight = weights[f"{expert_id}.w2.weight"]
+                elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+                    w1_w3_weight = weights["gate_up_proj"][expert_id].transpose(
+                        0, 1)
+                    w1_weight, w3_weight = w1_w3_weight.chunk(2, dim=0)
+                    w2_weight = weights["down_proj"][expert_id].transpose(
+                        0, 1).contiguous()
+                else:
+                    raise NotImplementedError(
+                        f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
+                    )
+
+                is_trtllm_nvfp4 = self.is_trtllm(
+                ) and self.quant_config.quant_mode.has_nvfp4()
+
+                future_w3_w1 = executor.submit(
+                    load_expert_w3_w1_weight,
+                    w1_weight,
+                    w3_weight,
+                    self.w3_w1_weight.data[expert_idx],
+                    is_trtllm_nvfp4,
+                )
+                futures.append(future_w3_w1)
+
+                future_w2 = executor.submit(
+                    load_expert_w2_weight,
+                    w2_weight,
+                    self.w2_weight.data[expert_idx],
+                    is_trtllm_nvfp4,
+                )
+                futures.append(future_w2)
+
+            for future in futures:
+                future.result()
 
         if self.quant_config and self.quant_config.quant_mode.has_any_quant(
                 exclude_kv_cache=True):
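
For reference, the change swaps per-expert threading.Thread spawning for a bounded ThreadPoolExecutor with a submit-and-wait pattern. Below is a minimal, self-contained sketch of that pattern on dummy tensors; load_w3_w1, load_w2, num_experts, and the shapes are placeholders for illustration, not the real TensorRT-LLM helpers.

```python
# Sketch of the pattern the diff adopts: submit each per-expert copy to a
# bounded ThreadPoolExecutor and wait on the futures. Names and shapes below
# are made up for the example.
import os
from concurrent.futures import ThreadPoolExecutor

import torch


def load_w3_w1(src: torch.Tensor, dst: torch.Tensor) -> None:
    # Tensor copies spend most of their time in C++, so threads can overlap
    # and make better use of CPU memory bandwidth despite the GIL.
    dst.copy_(src)


def load_w2(src: torch.Tensor, dst: torch.Tensor) -> None:
    dst.copy_(src)


num_experts = 8
src_w3_w1 = [torch.randn(64, 32) for _ in range(num_experts)]
src_w2 = [torch.randn(32, 64) for _ in range(num_experts)]
dst_w3_w1 = torch.empty(num_experts, 64, 32)
dst_w2 = torch.empty(num_experts, 32, 64)

# Two submissions per expert, capped by CPU count and a hard limit of 16,
# mirroring the max_workers computation in the diff.
max_workers = min(num_experts * 2, os.cpu_count() * 2, 16)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = []
    for idx in range(num_experts):
        futures.append(executor.submit(load_w3_w1, src_w3_w1[idx], dst_w3_w1[idx]))
        futures.append(executor.submit(load_w2, src_w2[idx], dst_w2[idx]))
    # result() re-raises any worker exception; a bare threading.Thread target
    # would only print it and join() would return silently.
    for future in futures:
        future.result()
```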
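The FUSED_GATE_UP_PROJ branch above recovers separate w1 and w3 weights from each expert's fused gate/up tensor via transpose plus chunk. A toy illustration on made-up shapes (hidden_size, intermediate_size, and the layout are assumptions for the example, not taken from any specific checkpoint format):

```python
# Split a fused [hidden_size, 2 * intermediate_size] gate/up tensor into two
# per-expert halves, mirroring the transpose + chunk in the diff.
import torch

hidden_size, intermediate_size = 16, 8
gate_up_proj = torch.randn(hidden_size, 2 * intermediate_size)  # one expert's fused weight

w1_w3 = gate_up_proj.transpose(0, 1)   # -> [2 * intermediate_size, hidden_size]
w1, w3 = w1_w3.chunk(2, dim=0)         # two [intermediate_size, hidden_size] halves

assert w1.shape == w3.shape == (intermediate_size, hidden_size)
```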
5 changes: 0 additions & 5 deletions tests/integration/test_lists/waives.txt
@@ -484,11 +484,6 @@ test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-70B-BF16-llama-3.1-mode
 test_e2e.py::test_ptp_quickstart_advanced_8gpus[Mixtral-8x7B-BF16-Mixtral-8x7B-v0.1] SKIP (https://nvbugs/5136994)
 test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-70B-FP8-llama-3.1-model/Llama-3.1-70B-Instruct-FP8] SKIP (https://nvbugs/5289909)
 test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B] SKIP (https://nvbugs/5289910)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5289912)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406)
 perf/test_perf.py::test_perf[mixtral_8x22b_v0.1-bench-float16-input_output_len:512,512-quant:fp8-tp:4] SKIP (https://nvbugspro.nvidia.com/bug/5274894)
 perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:128,128-gpus:4] SKIP (https://nvbugspro.nvidia.com/bug/5274894)
 perf/test_perf.py::test_perf[llama_v3.1_70b-bench-streaming-bfloat16-input_output_len:128,128-gpus:4] SKIP (https://nvbugspro.nvidia.com/bug/5274894)