@@ -77,31 +77,45 @@ struct enable_sm89_to_sm90 : Kernel {
 };
 
 /*
-   This epilogue function defines a quantized GEMM operation similar to
-   torch._scaled_mm.
-
-   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
-   per-row. B can be quantized per-tensor or per-column.
-   Any combination of per-tensor and per-row or column is supported.
-   A and B must have symmetric quantization (zero point == 0).
-
-   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
-   scales are applied elementwise with numpy-style broadcasting.
-
-   ScaleA and ScaleB define the epilogue functions that apply the scales for
-   the A and B operands respectively. These scales may be either per-tensor or
-   per row or column.
- */
+ * This class provides the common ScaleA and ScaleB descriptors for the
+ * ScaledEpilogue and ScaledEpilogueBias classes.
+ */
 template <typename ElementD, typename OutputTileThreadMap>
-struct ScaledEpilogue {
- private:
+struct ScaledEpilogueBase {
+ protected:
   using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
 
   using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
       OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
 
   using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
       OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
+};
+
+/*
+ This epilogue function defines a quantized GEMM operation similar to
+ torch._scaled_mm.
+
+ A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
+ per-row. B can be quantized per-tensor or per-column.
+ Any combination of per-tensor and per-row or column is supported.
+ A and B must have symmetric quantization (zero point == 0).
+
+ So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+ scales are applied elementwise with numpy-style broadcasting.
+
+ ScaleA and ScaleB define the epilogue functions that apply the scales for
+ the A and B operands respectively. These scales may be either per-tensor or
+ per row or column.
+*/
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogue
+    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::ScaleA;
+  using ScaleB = typename SUPER::ScaleB;
 
   using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
       cutlass::multiplies, float, float,
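
Aside (not part of the diff): the comment block above describes the epilogue's math. A minimal plain-C++ reference of those semantics, assuming row-major int8 A (M x K) and B (K x N), fp32 scales with either one entry (per-tensor) or one entry per row/column, and int32 accumulation, might look like:

#include <cstdint>
#include <vector>

// Reference loop nest for D = (a_scales * A) (b_scales * B).
// a_scales has 1 entry (per-tensor) or M entries (per-row);
// b_scales has 1 entry (per-tensor) or N entries (per-column).
template <typename OutT>
void scaled_mm_reference(const int8_t* A, const int8_t* B, OutT* D,
                         const std::vector<float>& a_scales,
                         const std::vector<float>& b_scales,
                         int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    float sa = a_scales[a_scales.size() == 1 ? 0 : m];
    for (int n = 0; n < N; ++n) {
      float sb = b_scales[b_scales.size() == 1 ? 0 : n];
      int32_t acc = 0;
      for (int k = 0; k < K; ++k)
        acc += int32_t(A[m * K + k]) * int32_t(B[k * N + n]);
      // Scales applied to the int32 accumulator in fp32, then converted.
      D[m * N + n] = static_cast<OutT>(sa * sb * static_cast<float>(acc));
    }
  }
}

The CUTLASS epilogue fuses exactly this scale application into the GEMM output stage instead of running it as a separate pass.
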
@@ -134,6 +148,53 @@ struct ScaledEpilogue {
   }
 };
 
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBias
+    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::ScaleA;
+  using ScaleB = typename SUPER::ScaleB;
+
+  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTCompute0 =
+      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
+
+  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
+
+ public:
+  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
+                                                             EVTCompute0, Bias>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& bias) {
+    using ScaleAArgs = typename ScaleA::Arguments;
+    using ScaleBArgs = typename ScaleB::Arguments;
+    using BiasArgs = typename Bias::Arguments;
+
+    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
+    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
+    BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
+
+    typename EVTCompute0::Arguments evt0_compute_args{b_args};
+
+    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
+                                                    bias_args};
+    return evt_compute_args;
+  }
+};
+
 template <typename Arch, template <typename> typename ArchGuard,
           typename ElementAB_, typename ElementD_,
           template <typename, typename> typename Epilogue_, typename TileShape,
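
Aside (not part of the diff): ScaledEpilogueBias builds an EVT tree in which Compute0 applies the per-column b_scale to the accumulator and Compute1 is a multiply_add that applies the per-row a_scale and adds a bias vector with one entry per output column (stored in the output dtype, broadcast across rows). Extending the reference sketch above (reusing its includes), the per-element result would be:

// Hedged reference for the bias variant; bias has N entries in the output dtype.
template <typename OutT>
void scaled_mm_bias_reference(const int8_t* A, const int8_t* B,
                              const OutT* bias, OutT* D,
                              const std::vector<float>& a_scales,
                              const std::vector<float>& b_scales,
                              int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    float sa = a_scales[a_scales.size() == 1 ? 0 : m];
    for (int n = 0; n < N; ++n) {
      float sb = b_scales[b_scales.size() == 1 ? 0 : n];
      int32_t acc = 0;
      for (int k = 0; k < K; ++k)
        acc += int32_t(A[m * K + k]) * int32_t(B[k * N + n]);
      // multiply_add node: a_scale * (b_scale * acc) + bias, then convert to OutT.
      D[m * N + n] = static_cast<OutT>(
          sa * (sb * static_cast<float>(acc)) + static_cast<float>(bias[n]));
    }
  }
}
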
@@ -168,13 +229,13 @@ struct cutlass_2x_gemm {
   // clang-format off
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
-  using KernelType = 
+  using KernelType =
     ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16, 
-      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16, 
+      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
+      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
       float, cutlass::layout::RowMajor, 4,
-      ElementAcc, float, cutlass::arch::OpClassTensorOp, 
-      Arch, 
+      ElementAcc, float, cutlass::arch::OpClassTensorOp,
+      Arch,
       TileShape, WarpShape, InstructionShape,
       EVTD,
       cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
@@ -404,14 +465,13 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
   }
 }
 
-void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
@@ -420,78 +480,130 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
   if (out.dtype() == torch::kBFloat16) {
     return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
     return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
 
-void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
+void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
-  TORCH_CHECK(a.dtype() == torch::kInt8);
-  TORCH_CHECK(b.dtype() == torch::kInt8);
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
+
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
-                                      ScaledEpilogue>(out, a, b, a_scales,
-                                                      b_scales);
+    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, ScaledEpilogue>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
 
-void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
+void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
+
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
 
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
   if (a.dtype() == torch::kInt8) {
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_gemm_caller<cutlass_2x_gemm<
           cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
-          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       assert(out.dtype() == torch::kFloat16);
       return cutlass_gemm_caller<cutlass_2x_gemm<
           cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
-          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_gemm_caller<cutlass_2x_gemm<
-          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
-          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_caller<
+          cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
+                          cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue,
+                          TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return cutlass_gemm_caller<cutlass_2x_gemm<
-          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
-          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_caller<
+          cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
+                          cutlass::float_e4m3_t, cutlass::half_t, Epilogue,
+                          TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     }
   }
 }
+
+void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
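
Aside (not part of the diff): all three cutlass_scaled_mm_smXX entry points now share the same shape: check the scale dtypes, then pick ScaledEpilogue or ScaledEpilogueBias and forward the remaining arguments to an *_epilogue helper that takes the epilogue as a template-template parameter and a parameter pack. A standalone sketch of that dispatch pattern (illustrative names, not vLLM code):

#include <iostream>
#include <utility>

// Two stand-in epilogues with different argument counts.
template <typename ElementD, typename ThreadMap>
struct DemoScaledEpilogue {
  static void prepare(float a_scale, float b_scale) {
    std::cout << "scale-only epilogue: " << a_scale * b_scale << "\n";
  }
};

template <typename ElementD, typename ThreadMap>
struct DemoScaledEpilogueBias {
  static void prepare(float a_scale, float b_scale, float bias) {
    std::cout << "scale+bias epilogue: " << a_scale * b_scale + bias << "\n";
  }
};

// Epilogue is a template-template parameter; its extra arguments are
// perfectly forwarded, so the dispatcher never names them explicitly.
template <template <typename, typename> typename Epilogue, typename... Args>
void run_gemm_like(Args&&... args) {
  Epilogue<float, int>::prepare(std::forward<Args>(args)...);
}

int main() {
  bool has_bias = true;
  if (has_bias)
    run_gemm_like<DemoScaledEpilogueBias>(0.5f, 2.0f, 1.0f);  // bias path
  else
    run_gemm_like<DemoScaledEpilogue>(0.5f, 2.0f);            // no-bias path
}

The payoff is that adding another epilogue variant only needs a new epilogue struct and a new branch in the thin wrapper; the per-architecture kernel selection code is untouched.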