@@ -94,15 +94,40 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
     };
 
+    template <typename TypeAct>
+    std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> create_weight_quant_runner()
+    {
+        if (isInt8Quant())
+        {
+            return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, uint8_t>>();
+        }
+        else if (isInt4Quant())
+        {
+#ifdef ENABLE_FP8
+            if (mUseW4GroupScaling)
+            {
+                return std::make_unique<
+                    kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, TypeAct, TypeAct>>();
+            }
+#endif
+            return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, cutlass::uint4b_t>>();
+        }
+        else
+        {
+            C10_THROW_ERROR_FORMATTED(Error, "Unsupported weight quantization type");
+        }
+    }
+
     FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
-        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_mxfp8_act_scaling,
-        bool use_fused_finalize)
+        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_int8_woq_per_channel,
+        bool use_mxfp8_act_scaling, bool use_fused_finalize)
     {
         mActivationDtype = activation_dtype;
         mWeightDtype = weight_dtype;
         mOutputDtype = output_dtype;
         mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale;
         mUseW4GroupScaling = use_w4_group_scaling;
+        mUseINT8WoqPerChannel = use_int8_woq_per_channel;
         mUseMxfp8ActScaling = use_mxfp8_act_scaling;
         mUseFusedFinalize = use_fused_finalize;
         mInnerDimMultiplier = 1;
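A minimal standalone sketch of the dispatch pattern the new helper uses, with mock types (RunnerIface, MockRunner, WeightKind, and make_runner are illustrative stand-ins, not the TRT-LLM classes): the template parameter fixes the activation type at compile time, while a runtime weight-quant flag selects which runner specialization to instantiate, so each activation dtype needs only one call site.

// Sketch only: runtime flag -> template instantiation, assuming mock types.
#include <cstdint>
#include <memory>
#include <stdexcept>

struct RunnerIface
{
    virtual ~RunnerIface() = default;
};

// Stand-in for kernels::CutlassMoeFCRunner<TypeAct, TypeWeight>.
template <typename TypeAct, typename TypeWeight>
struct MockRunner : RunnerIface
{
};

enum class WeightKind
{
    Int8,
    Int4Packed // two INT4 values stored per byte
};

template <typename TypeAct>
std::unique_ptr<RunnerIface> make_runner(WeightKind kind)
{
    switch (kind)
    {
    case WeightKind::Int8: return std::make_unique<MockRunner<TypeAct, uint8_t>>();
    case WeightKind::Int4Packed: return std::make_unique<MockRunner<TypeAct, uint8_t>>();
    default: throw std::runtime_error("Unsupported weight quantization type");
    }
}

// Usage mirrors the switch over mActivationDtype later in this diff:
//   auto runner = make_runner<float>(WeightKind::Int8);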
@@ -137,7 +162,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
             mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype);
         }
-
         if (isNvfp4Quant())
         {
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
@@ -152,7 +176,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
            default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, false>(mOutputDtype);
            }
        }
-
        if (isWFP4A16Quant())
        {
            mInnerDimMultiplier = 2;
@@ -167,45 +190,19 @@ class FusedMoeRunner : public torch::CustomClassHolder
            }
 #endif
        }
-
 #endif
-        if (isInt4Quant())
+        if (isIntWeightOnlyQuant())
        {
-            mInnerDimMultiplier = 2; // 2 INT4 -> 1 INT8
-            if (mActivationDtype == c10::ScalarType::Half)
+            if (isInt4Quant())
            {
-#ifdef ENABLE_FP8
-                if (mUseW4GroupScaling)
-                {
-                    mKernelRunner
-                        = std::make_unique<kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>>();
-                }
-                else
-                {
-                    mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
-                }
-#else
-                mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
-#endif
+                mInnerDimMultiplier = 2; // 2 INT4 -> 1 INT8
            }
-#ifdef ENABLE_BF16
-            else if (mActivationDtype == c10::ScalarType::BFloat16)
+            switch (mActivationDtype)
            {
-#ifdef ENABLE_FP8
-                if (mUseW4GroupScaling)
-                {
-                    mKernelRunner = std::make_unique<
-                        kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>>();
-                }
-                else
-                {
-                    mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
-                }
-#else
-                mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
-#endif
+            case c10::ScalarType::Half: mKernelRunner = create_weight_quant_runner<half>(); break;
+            case c10::ScalarType::BFloat16: mKernelRunner = create_weight_quant_runner<__nv_bfloat16>(); break;
+            default: C10_THROW_ERROR_FORMATTED(Error, "Unsupported activation type for int-type weight");
            }
-#endif
        }
        if (!mKernelRunner)
        {
@@ -310,13 +307,31 @@ class FusedMoeRunner : public torch::CustomClassHolder
        }
        TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
            "fc1_expert_weights and fc2_expert_weights must have the same number of experts.");
-        TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
-            "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+
+        if (mUseINT8WoqPerChannel)
+        {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
+            TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2,
+                "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+        }
+        else
+        {
+            TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
+                "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+        }
 
        int experts_per_token = token_selected_experts.sizes()[1];
        int64_t num_rows = input.sizes()[0];
        int64_t hidden_size = fc2_expert_weights.sizes()[1];
        int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+        if (mUseINT8WoqPerChannel)
+        {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
+            hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+            inter_size = fc2_expert_weights.sizes()[1];
+        }
 
        if (isWMxfp4AMxfp8Quant() || isWMxfp4AFp8Quant())
        {
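To make the two fc2 weight layouts concrete, a small self-contained sketch (the shape values are hypothetical examples, not taken from this PR): the default layout stores fc2 weights as [num_experts, hidden_size, packed inter_size], while the INT8 per-channel path stores [num_experts, inter_size, hidden_size], so the size extraction swaps dimensions 1 and 2.

// Sketch: recovering hidden_size / inter_size under the two layouts.
// Shape numbers are made-up examples.
#include <array>
#include <cstdint>
#include <cstdio>

int main()
{
    int64_t const innerDimMultiplier = 1; // INT8 weights: one value per byte
    bool const useInt8WoqPerChannel = true;

    // Default layout:       [num_experts, hidden_size, inter_size / multiplier]
    // INT8 per-channel WoQ: [num_experts, inter_size,  hidden_size]
    std::array<int64_t, 3> const fc2_shape = {8, 14336, 4096};

    int64_t hidden_size = fc2_shape[1];
    int64_t inter_size = fc2_shape[2] * innerDimMultiplier;
    if (useInt8WoqPerChannel)
    {
        hidden_size = fc2_shape[2] * innerDimMultiplier; // last dim is hidden_size here
        inter_size = fc2_shape[1];
    }
    std::printf("hidden_size=%lld inter_size=%lld\n", (long long) hidden_size, (long long) inter_size);
    return 0;
}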
@@ -593,8 +608,15 @@ class FusedMoeRunner : public torch::CustomClassHolder
        }
 
        int64_t const num_rows = input.sizes()[0];
-        int64_t const hidden_size = fc2_expert_weights.sizes()[1];
-        int64_t const inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+        int64_t hidden_size = fc2_expert_weights.sizes()[1];
+        int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+        if (mUseINT8WoqPerChannel)
+        {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
+            hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+            inter_size = fc2_expert_weights.sizes()[1];
+        }
        int64_t const group_size_
            = isInt4Quant() ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size : -1;
        int64_t const group_size = isWFP4A16Quant()
@@ -677,6 +699,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
 
    bool mUseDeepSeekFP8BlockScaling = false;
    bool mUseW4GroupScaling = false;
+    bool mUseINT8WoqPerChannel = false;
    bool mUseMxfp8ActScaling = false;
    bool mUseFusedFinalize = true;
 
@@ -891,7 +914,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
            TORCH_CHECK(false, "MXFP8 x MXFP4 quantization is not supported in OSS Cutlass Moe Gemm");
 #endif
        }
-
        else if (isNvfp4Quant())
        {
            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for nvfp4 quantization");
@@ -966,8 +988,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
        }
        else if (isWFP4A16Quant())
        {
-            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4 quantization");
-            TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 8 quant scales for W4A16 quantization");
+            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
+            TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for W4A16 quantization");
 
            auto& fc1_weight_scales = quant_scales.value()[0];
            auto& fc2_weight_scales = quant_scales.value()[1];
@@ -976,28 +998,45 @@ class FusedMoeRunner : public torch::CustomClassHolder
                static_cast<void const*>(fc2_weight_scales.data_ptr()), nullptr, nullptr, nullptr, nullptr, nullptr,
                nullptr);
        }
-        else if (isInt4Quant())
+        else if (isIntWeightOnlyQuant())
        {
-            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4 quantization");
-            TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");
-
-            auto& fc1_weight_scales = quant_scales.value()[0];
-            auto& fc2_weight_scales = quant_scales.value()[1];
-            auto& fc1_act_scales = quant_scales.value()[2];
-            auto& fc2_act_scales = quant_scales.value()[3];
-            auto& fc1_weight_zeros = quant_scales.value()[4];
-            auto& fc2_weight_zeros = quant_scales.value()[5];
-            auto& fc1_alpha = quant_scales.value()[6];
-            auto& fc2_alpha = quant_scales.value()[7];
-            int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
-            return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
-                static_cast<void const*>(fc2_weight_scales.data_ptr()),
-                static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
-                static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
-                static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
-                static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
-                static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
-                static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
+            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
+            if (mUseINT8WoqPerChannel)
+            {
+                TORCH_CHECK(
+                    quant_scales.value().size() == 2, "Expecting 2 quant scales for INT8 weight only quantization");
+                auto& fc1_weight_scales = quant_scales.value()[0];
+                auto& fc2_weight_scales = quant_scales.value()[1];
+                return kernels::QuantParams::Int(static_cast<float const*>(fc1_weight_scales.data_ptr()),
+                    static_cast<float const*>(fc2_weight_scales.data_ptr()));
+            }
+            else if (isInt4Quant() && mUseW4GroupScaling)
+            {
+                TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");
+
+                auto& fc1_weight_scales = quant_scales.value()[0];
+                auto& fc2_weight_scales = quant_scales.value()[1];
+                auto& fc1_act_scales = quant_scales.value()[2];
+                auto& fc2_act_scales = quant_scales.value()[3];
+                auto& fc1_weight_zeros = quant_scales.value()[4];
+                auto& fc2_weight_zeros = quant_scales.value()[5];
+                auto& fc1_alpha = quant_scales.value()[6];
+                auto& fc2_alpha = quant_scales.value()[7];
+                int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
+                return kernels::QuantParams::GroupWise(group_size,
+                    static_cast<void const*>(fc1_weight_scales.data_ptr()),
+                    static_cast<void const*>(fc2_weight_scales.data_ptr()),
+                    static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
+                    static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
+                    static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
+                    static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
+                    static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
+                    static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
+            }
+            else
+            {
+                TORCH_CHECK(false, "Unsupported weight only quantization");
+            }
        }
        else
        {
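For reference, the quant_scales layouts this branch consumes, written out as a sketch. The slot names mirror the locals above; only the counts and ordering come from the checks in the diff (2 entries for INT8 per-channel, 8 for group-wise W4A8, where empty tensors map to nullptr).

// Expected quant_scales slot ordering (taken from the indices used above).
//
// INT8 weight-only, per-channel (2 entries):
//   { fc1_weight_scales, fc2_weight_scales }   // float, per output channel
//
// INT4 group-wise W4A8 (8 entries):
enum QuantScaleSlot
{
    kFc1WeightScales = 0,
    kFc2WeightScales = 1,
    kFc1ActScales = 2,
    kFc2ActScales = 3,
    kFc1WeightZeros = 4,
    kFc2WeightZeros = 5,
    kFc1Alpha = 6,
    kFc2Alpha = 7,
};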
@@ -1022,6 +1061,11 @@ class FusedMoeRunner : public torch::CustomClassHolder
        return mUseW4GroupScaling && mWeightDtype == c10::ScalarType::Byte;
    }
 
+    bool isInt8Quant() const
+    {
+        return mWeightDtype == c10::ScalarType::Char;
+    }
+
    bool isInt4Quant() const
    {
        return mWeightDtype == c10::ScalarType::QUInt4x2;
@@ -1032,6 +1076,11 @@ class FusedMoeRunner : public torch::CustomClassHolder
        return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
    }
 
+    bool isIntWeightOnlyQuant() const
+    {
+        return isInt8Quant() || isInt4Quant();
+    }
+
    bool isWMxfp4AFp8Quant() const
    {
        return mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long
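The new predicates key entirely off c10 dtype tags: Char marks INT8 weights and QUInt4x2 marks packed INT4 weights. A small sketch of the same mapping, assuming only c10 headers (describe_weight_quant is a made-up helper, not part of the PR):

// Sketch: map the c10 weight dtype tags used by the predicates above to a
// readable quant-mode label.
#include <c10/core/ScalarType.h>

char const* describe_weight_quant(c10::ScalarType weight_dtype)
{
    switch (weight_dtype)
    {
    case c10::ScalarType::Char: return "INT8 weight-only";     // isInt8Quant()
    case c10::ScalarType::QUInt4x2: return "INT4 weight-only"; // isInt4Quant()
    default: return "not an int weight-only mode";             // isIntWeightOnlyQuant() == false
    }
}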
@@ -1050,7 +1099,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
 TORCH_LIBRARY(trtllm, m)
 {
     m.class_<torch_ext::FusedMoeRunner>("FusedMoeRunner")
-        .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool>())
+        .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool, bool>())
        .def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile)
        .def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum)
        .def("run_moe", &torch_ext::FusedMoeRunner::runMoe)
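The torch::init<> type list must mirror the constructor parameter-for-parameter, which is why the extra bool (use_int8_woq_per_channel, positioned between use_w4_group_scaling and use_mxfp8_act_scaling) appears here as well, and why every construction site must pass it. A toy sketch of the registration pattern (ToyRunner and the "toy" namespace are made up, not the real runner):

// Toy custom-class registration: torch::init<...> must match the ctor 1:1.
#include <torch/custom_class.h>

struct ToyRunner : torch::CustomClassHolder
{
    bool mInt8WoqPerChannel;
    bool mFusedFinalize;

    ToyRunner(bool use_int8_woq_per_channel, bool use_fused_finalize)
        : mInt8WoqPerChannel(use_int8_woq_per_channel)
        , mFusedFinalize(use_fused_finalize)
    {
    }
};

TORCH_LIBRARY(toy, m)
{
    m.class_<ToyRunner>("ToyRunner").def(torch::init<bool, bool>());
}

// Python side (hypothetical): torch.classes.toy.ToyRunner(True, False)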