Commit 7218c53

Allocate MoE workspace only when necessary

Cache the MoE workspace in a FusedMoeRunner member and reallocate it only when the required size grows, instead of allocating a fresh workspace on every call.

Signed-off-by: Yilin Fan <[email protected]>

1 parent ad662dd

File tree

1 file changed: +13 -11 lines changed


cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 13 additions & 11 deletions
@@ -308,7 +308,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         std::vector<int64_t> output_shape = {num_rows, hidden_size};
         auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+        setWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
             static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
@@ -439,7 +439,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
         min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+        setWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
             static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
@@ -577,6 +577,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
     // e.g. 16 nvfp4 elements are packed into a single int64 element
     int64_t mInnerDimMultiplier;
     char* mProfileWorkspace = nullptr;
+    WorkspaceInfo workspace_info;
 
     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4A8GroupScaling = false;
@@ -622,7 +623,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
     }
 
-    WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
+    void setWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int num_experts, int experts_per_token, ActivationType activation_type,
         kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
     {
@@ -633,15 +634,16 @@ class FusedMoeRunner : public torch::CustomClassHolder
 
         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
 
-        size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+        int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
 
-        WorkspaceInfo info{};
-        info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
-        info.src_to_dest_map
-            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
-
-        return info;
+        if (workspace_info.workspace.numel() < total_workspace_size) {
+            TLLM_LOG_WARNING("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
+                workspace_info.workspace.numel(), total_workspace_size);
+            workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        }
+        workspace_info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);
     }
 
     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
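The diff turns the per-call workspace allocation into a grow-only cache: the tensor now lives on the runner, repeated calls reuse it, and torch::empty runs again only when a larger request arrives. Below is a minimal, torch-free sketch of the same pattern; the names CachedWorkspace and ensureCapacity are illustrative and not part of the TensorRT-LLM codebase.

    // Grow-only workspace cache: reallocate only when the requested size
    // exceeds the current capacity, mirroring setWorkspaceInfo above.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    class CachedWorkspace
    {
    public:
        std::int8_t* ensureCapacity(std::int64_t requiredBytes)
        {
            if (static_cast<std::int64_t>(mBuffer.size()) < requiredBytes)
            {
                // Log the growth, analogous to the TLLM_LOG_WARNING in the diff.
                std::printf("growing workspace from %lld to %lld bytes\n",
                    static_cast<long long>(mBuffer.size()), static_cast<long long>(requiredBytes));
                mBuffer.resize(requiredBytes);
            }
            return mBuffer.data();
        }

    private:
        std::vector<std::int8_t> mBuffer; // persists across calls, like the new member
    };

    int main()
    {
        CachedWorkspace ws;
        ws.ensureCapacity(1024); // first call: allocates
        ws.ensureCapacity(512);  // smaller request: reuses the existing buffer
        ws.ensureCapacity(2048); // only a larger request triggers reallocation
        return 0;
    }

The trade-off mirrors the commit: the buffer never shrinks, which suits serving workloads where the peak batch size is bounded and holding the high-water-mark allocation is acceptable.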
