@@ -308,7 +308,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         std::vector<int64_t> output_shape = {num_rows, hidden_size};
         auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+        setWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
             static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
@@ -439,7 +439,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
         min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+        setWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
             static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
@@ -577,6 +577,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
     // e.g. 16 nvfp4 elements are packed into a single int64 element
     int64_t mInnerDimMultiplier;
     char* mProfileWorkspace = nullptr;
+    WorkspaceInfo workspace_info;
 
     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4A8GroupScaling = false;
@@ -622,7 +623,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
     }
 
-    WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
+    void setWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int num_experts, int experts_per_token, ActivationType activation_type,
         kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
     {
@@ -633,15 +634,16 @@ class FusedMoeRunner : public torch::CustomClassHolder
 
         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
 
-        size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+        int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
 
-        WorkspaceInfo info{};
-        info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
-        info.src_to_dest_map
-            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
-
-        return info;
+        if (workspace_info.workspace.numel() < total_workspace_size) {
+            TLLM_LOG_WARNING("MoE workspace is too small, increasing the size from %ld bytes to %ld bytes",
+                workspace_info.workspace.numel(), total_workspace_size);
+            workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        }
+        workspace_info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);
     }
 
     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
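
For reference, below is a minimal standalone sketch of the grow-only workspace caching pattern this patch adopts: keep one persistent device buffer as a member, reallocate only when a request exceeds the current capacity, and hand out pointers into it. The `CachedWorkspace` type and its `get` helper are illustrative names for this sketch, not TensorRT-LLM APIs.

```cpp
#include <torch/torch.h>
#include <cstdint>

// Illustrative only: a persistent, grow-only scratch buffer in the spirit of
// the cached WorkspaceInfo member introduced by this patch.
struct CachedWorkspace
{
    torch::Tensor buffer; // int8 CUDA tensor reused across calls

    // Return a device pointer to at least `bytes` bytes, reallocating only
    // when the cached buffer is too small; steady-state calls allocate nothing.
    int8_t* get(int64_t bytes)
    {
        if (!buffer.defined() || buffer.numel() < bytes)
        {
            buffer = torch::empty({bytes}, torch::dtype(torch::kInt8).device(torch::kCUDA));
        }
        return static_cast<int8_t*>(buffer.data_ptr());
    }
};
```

A caller needing two regions, as `setWorkspaceInfo` does for the MoE workspace and the source-to-destination map, would request the combined size once and take the second region at an offset past the first (with any alignment padding, which `common::nextWorkspacePtr` handles in the real code).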