From da7ba9d005c723b007a44c35ba8adb734dc4e97e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 23 Jul 2024 17:41:35 +0000 Subject: [PATCH] Add lovelace i8 kernels --- .../cutlass_w8a8/scaled_mm_c2x.cu | 27 +- ...uh => scaled_mm_c2x_sm89_fp8_dispatch.cuh} | 54 +-- .../scaled_mm_c2x_sm89_int8_dispatch.cuh | 353 ++++++++++++++++++ tests/kernels/test_cutlass.py | 4 +- 4 files changed, 395 insertions(+), 43 deletions(-) rename csrc/quantization/cutlass_w8a8/{scaled_mm_c2x_sm89_dispatch.cuh => scaled_mm_c2x_sm89_fp8_dispatch.cuh} (89%) create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index d26c43de522c..aac4900f933a 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -4,7 +4,8 @@ #include "scaled_mm_c2x.cuh" #include "scaled_mm_c2x_sm80_dispatch.cuh" -#include "scaled_mm_c2x_sm89_dispatch.cuh" +#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" +#include "scaled_mm_c2x_sm89_int8_dispatch.cuh" /* This file defines quantized GEMM operations using the CUTLASS 2.x API, for @@ -98,25 +99,17 @@ template