diff --git a/tensorrt_llm/_torch/modules/fused_moe.py b/tensorrt_llm/_torch/modules/fused_moe.py
index 2f283a1da50..bc786f62274 100755
--- a/tensorrt_llm/_torch/modules/fused_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe.py
@@ -1,6 +1,6 @@
 import math
 import os
-import threading
+from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from typing import Dict, List, NamedTuple, Optional, Union
 
@@ -1278,45 +1278,55 @@ def load_expert_w2_weight(w2_weight,
         # Even though CPython has global interpreter lock (GIL),
         # it's still faster to load weights in parallel because it can utilize
         # CPU memory bandwidth better.
-        threads = []
+        max_workers = min(
+            (self.expert_end - self.expert_start) * 2,
+            os.cpu_count() * 2,
+            16,
+        )
 
-        for expert_id in range(self.expert_start, self.expert_end):
-            expert_idx = expert_id - self.expert_start
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = []
+
+            for expert_id in range(self.expert_start, self.expert_end):
+                expert_idx = expert_id - self.expert_start
+
+                if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+                    w1_weight = weights[f"{expert_id}.w1.weight"]
+                    w3_weight = weights[f"{expert_id}.w3.weight"]
+                    w2_weight = weights[f"{expert_id}.w2.weight"]
+                elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+                    w1_w3_weight = weights["gate_up_proj"][expert_id].transpose(
+                        0, 1)
+                    w1_weight, w3_weight = w1_w3_weight.chunk(2, dim=0)
+                    w2_weight = weights["down_proj"][expert_id].transpose(
+                        0, 1).contiguous()
+                else:
+                    raise NotImplementedError(
+                        f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
+                    )
 
-            if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
-                w1_weight = weights[f"{expert_id}.w1.weight"]
-                w3_weight = weights[f"{expert_id}.w3.weight"]
-                w2_weight = weights[f"{expert_id}.w2.weight"]
-            elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
-                w1_w3_weight = weights["gate_up_proj"][expert_id].transpose(
-                    0, 1)
-                w1_weight, w3_weight = w1_w3_weight.chunk(2, dim=0)
-                w2_weight = weights["down_proj"][expert_id].transpose(
-                    0, 1).contiguous()
-            else:
-                raise NotImplementedError(
-                    f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
+                is_trtllm_nvfp4 = self.is_trtllm(
+                ) and self.quant_config.quant_mode.has_nvfp4()
+
+                future_w3_w1 = executor.submit(
+                    load_expert_w3_w1_weight,
+                    w1_weight,
+                    w3_weight,
+                    self.w3_w1_weight.data[expert_idx],
+                    is_trtllm_nvfp4,
+                )
+                futures.append(future_w3_w1)
+
+                future_w2 = executor.submit(
+                    load_expert_w2_weight,
+                    w2_weight,
+                    self.w2_weight.data[expert_idx],
+                    is_trtllm_nvfp4,
                 )
+                futures.append(future_w2)
 
-            is_trtllm_nvfp4 = self.is_trtllm(
-            ) and self.quant_config.quant_mode.has_nvfp4()
-
-            thread = threading.Thread(target=load_expert_w3_w1_weight,
-                                      args=(w1_weight, w3_weight,
-                                            self.w3_w1_weight.data[expert_idx],
-                                            is_trtllm_nvfp4))
-            thread.start()
-            threads.append(thread)
-
-            thread = threading.Thread(target=load_expert_w2_weight,
-                                      args=(w2_weight,
-                                            self.w2_weight.data[expert_idx],
-                                            is_trtllm_nvfp4))
-            thread.start()
-            threads.append(thread)
-
-        for thread in threads:
-            thread.join()
+            for future in futures:
+                future.result()
 
         if self.quant_config and self.quant_config.quant_mode.has_any_quant(
                 exclude_kv_cache=True):
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index d4583e06469..22f2f276a7c 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -484,11 +484,6 @@ test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-70B-BF16-llama-3.1-mode
 test_e2e.py::test_ptp_quickstart_advanced_8gpus[Mixtral-8x7B-BF16-Mixtral-8x7B-v0.1] SKIP (https://nvbugs/5136994)
 test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-70B-FP8-llama-3.1-model/Llama-3.1-70B-Instruct-FP8] SKIP (https://nvbugs/5289909)
 test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B] SKIP (https://nvbugs/5289910)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5289912)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406)
 perf/test_perf.py::test_perf[mixtral_8x22b_v0.1-bench-float16-input_output_len:512,512-quant:fp8-tp:4] SKIP (https://nvbugspro.nvidia.com/bug/5274894)
 perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:128,128-gpus:4] SKIP (https://nvbugspro.nvidia.com/bug/5274894)
 perf/test_perf.py::test_perf[llama_v3.1_70b-bench-streaming-bfloat16-input_output_len:128,128-gpus:4] SKIP (https://nvbugspro.nvidia.com/bug/5274894)
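
Reviewer note: below is a minimal, self-contained sketch of the loading pattern the fused_moe.py hunk above adopts, written against plain torch tensors. The helper names (copy_expert_weight, load_all_experts) and the random tensors are illustrative only and are not part of TensorRT-LLM; the real code submits two copies per expert (w3_w1 and w2) through the module's own loader callbacks. The point of the change is the same, though: a bounded ThreadPoolExecutor instead of one bare threading.Thread per task, with Future.result() both synchronizing and re-raising any worker exception.

# Illustrative sketch only (not TensorRT-LLM code): bounded thread pool for
# parallel per-expert weight copies. Helper names here are hypothetical.
import os
from concurrent.futures import ThreadPoolExecutor
from typing import List

import torch


def copy_expert_weight(src: torch.Tensor, dst: torch.Tensor) -> None:
    # One "load" task: copy a single expert's weight into its preallocated slot.
    dst.copy_(src)


def load_all_experts(srcs: List[torch.Tensor], dsts: List[torch.Tensor]) -> None:
    # Bound the pool roughly as the diff does (task count, CPU count * 2, hard
    # cap of 16) instead of spawning one thread per task as the old code did.
    max_workers = min(len(srcs), (os.cpu_count() or 1) * 2, 16)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(copy_expert_weight, src, dst)
            for src, dst in zip(srcs, dsts)
        ]
        # Unlike Thread.join(), Future.result() re-raises any exception thrown
        # inside a worker, so a failed load is no longer silent.
        for future in futures:
            future.result()


if __name__ == "__main__":
    srcs = [torch.randn(128, 256) for _ in range(8)]
    dsts = [torch.empty_like(t) for t in srcs]
    load_all_experts(srcs, dsts)
    assert all(torch.equal(s, d) for s, d in zip(srcs, dsts))

The hard cap of 16 workers keeps thread creation and context-switch overhead bounded on models with many experts, while still allowing the copies to overlap and use CPU memory bandwidth better despite the GIL, per the existing comment in fused_moe.py.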