diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp
new file mode 100644
index 000000000000..8444272940b4
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp
@@ -0,0 +1,8 @@
+#include <torch/extension.h>
+
+#include "linear.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
+        "Linear SiLU (INT8)");
+}
diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu
new file mode 100644
index 000000000000..a30d02a4cf42
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu
@@ -0,0 +1,162 @@
+// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu
+
+#include "linear.h"
+#include <cutlass/core_io.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/half.h>
+
+#include <cutlass/epilogue/thread/linear_combination.h>
+#include <cutlass/epilogue/thread/linear_combination_silu.h>
+#include <cutlass/gemm/device/default_gemm_configuration.h>
+#include <cutlass/gemm/device/gemm.h>
+#include <cutlass/layout/matrix.h>
+#include <cutlass/numeric_types.h>
+#include <cutlass/util/device_memory.h>
+#include <cutlass/util/host_tensor.h>
+#include <torch/torch.h>
+#include <stdexcept>
+torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,  // INT8
+                                            torch::Tensor weight, // INT8
+                                            torch::Tensor bias,   // FP32
+                                            float alpha,          // FP32
+                                            float beta            // FP32
+) {
+  auto M = input.size(0);
+  auto N = weight.size(0);
+  auto K = input.size(1);
+
+  using ElementOutput = float;
+  using ElementAccumulator = int32_t;
+  using ElementComputeEpilogue = float;
+  using ElementInputA = int8_t;  // <- data type of elements in input matrix A
+  using ElementInputB = int8_t;  // <- data type of elements in input matrix B
+
+  // The code section below describes matrix layout of input and output
+  // matrices. Row Major for Matrix A, Column Major for Matrix B and Row Major
+  // for Matrix C
+  using LayoutInputA = cutlass::layout::RowMajor;
+  using LayoutInputB = cutlass::layout::ColumnMajor;
+  using LayoutOutput = cutlass::layout::RowMajor;
+
+#if CUDA_ARCH >= 800
+  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
+      ElementOutput,  // <- data type of output matrix
+      128 / cutlass::sizeof_bits<
+                ElementOutput>::value,  // <- this is the number of elements per
+                                        // vectorized memory access. For half
+                                        // precision, it's 8 elements. This
+                                        // becomes the vector width of math
+                                        // instructions in epilogue too
+      ElementAccumulator,      // <- data type of accumulator
+      ElementComputeEpilogue   // <- data type for alpha in linear combination
+                               // function
+      >;
+  using Gemm = cutlass::gemm::device::Gemm<
+      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
+      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
+      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+      cutlass::gemm::GemmShape<256, 128, 64>,
+      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
+      EpilogueOp,
+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
+#elif CUDA_ARCH >= 750
+  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
+      ElementOutput,  // <- data type of output matrix
+      128 / cutlass::sizeof_bits<
+                ElementOutput>::value,  // <- this is the number of elements per
+                                        // vectorized memory access. For half
+                                        // precision, it's 8 elements. This
+                                        // becomes the vector width of math
+                                        // instructions in epilogue too
+      ElementAccumulator,      // <- data type of accumulator
+      ElementComputeEpilogue   // <- data type for alpha in linear combination
+                               // function
+      >;
+
+  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
+      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
+      ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
+  using Gemm = cutlass::gemm::device::Gemm<
+      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
+      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
+      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
+      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
+      DefaultGemmCfg::InstructionShape,
+      EpilogueOp>;
+#elif CUDA_ARCH >= 700
+  #define USE_TORCH_SILU
+  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
+      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
+      ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
+  using Gemm = cutlass::gemm::device::Gemm<
+      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
+      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
+      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
+      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
+      DefaultGemmCfg::InstructionShape,
+      cutlass::epilogue::thread::LinearCombination<
+          ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
+#else
+  #error "Unsupported cuda arch"
+#endif
+
+  auto input_size = cutlass::MatrixCoord(M, K);
+  auto weight_size = cutlass::MatrixCoord(K, N);
+  auto output_size = cutlass::MatrixCoord(M, N);
+
+  auto device = input.device();
+  // use the broadcasted bias as the output
+  auto out = bias.to(device).view({1, -1}).repeat({M, 1});
+
+  // constexpr int kSparse = Gemm::kSparse;
+  // How many elements of A are covered per ElementE
+  // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
+  // The size of individual meta data
+  // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
+  cutlass::gemm::GemmCoord problem_size(M, N, K);
+
+  cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
+      input.data_ptr<int8_t>(), LayoutInputA::packed(input_size));
+  cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
+      weight.data_ptr<int8_t>(), LayoutInputB::packed(weight_size));
+  cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
+      out.data_ptr<float>(), LayoutOutput::packed(output_size));
+
+  typename Gemm::Arguments arguments{
+      problem_size,  // <- problem size of matrix multiplication
+      input_ref,     // <- reference to matrix A on device
+      weight_ref,    // <- reference to matrix B on device
+      out_ref,       // <- reference to matrix C on device
+      out_ref,       // <- reference to matrix D on device
+      {alpha, beta}, 1};
+  Gemm gemm_op;
+
+  // Using the arguments, query for extra workspace required for matrix
+  // multiplication computation
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+
+  // Allocate workspace memory
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+  // Check whether the problem size is supported or not
+  cutlass::Status status = gemm_op.can_implement(arguments);
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error("cutlass cannot implement");
+  }
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  status = gemm_op.initialize(arguments, workspace.get());
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error("cutlass cannot initialize");
+  }
+
+  status = gemm_op();
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error("cutlass cannot run");
+  }
+#ifdef USE_TORCH_SILU
+#undef USE_TORCH_SILU
+  out = torch::silu(out);
+#endif
+  return out;
+}
diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h
new file mode 100644
index 000000000000..b62a27f3f8f3
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h
@@ -0,0 +1,12 @@
+#include <torch/extension.h>
+#include <torch/types.h>
+
+#include <cstdint>
+#include <iostream>
+
+torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,  // INT8
+                                            torch::Tensor weight, // INT8
+                                            torch::Tensor bias,   // FP32
+                                            float alpha,          // FP32
+                                            float beta            // FP32
+);
diff --git a/op_builder/smoothquant.py b/op_builder/smoothquant.py
new file mode 100644
index 000000000000..d562a4c4f626
--- /dev/null
+++ b/op_builder/smoothquant.py
@@ -0,0 +1,54 @@
+import warnings
+
+import torch
+
+from .builder import Builder
+from .utils import append_nvcc_threads
+
+
+class SmoothquantBuilder(Builder):
+    NAME = "cu_smoothquant"
+    PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant"
+
+    def __init__(self):
+        super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH)
+
+    def include_dirs(self):
+        ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()]
+        return ret
+
+    def sources_files(self):
+        ret = [
+            self.csrc_abs_path(fname)
+            for fname in [
+                "smoothquant/binding.cpp",
+                "smoothquant/linear.cu",
+            ]
+        ]
+        return ret
+
+    def cxx_flags(self):
+        return ["-O3"] + self.version_dependent_macros
+
+    def nvcc_flags(self):
+        compute_capability = torch.cuda.get_device_capability()
+        cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10
+
+        extra_cuda_flags = [
+            "-v",
+            f"-DCUDA_ARCH={cuda_arch}",
+            "-std=c++17",
+            "-U__CUDA_NO_HALF_OPERATORS__",
+            "-U__CUDA_NO_HALF_CONVERSIONS__",
+            "-U__CUDA_NO_HALF2_OPERATORS__",
+            "-DTHRUST_IGNORE_CUB_VERSION_CHECK",
+        ]
+
+        ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
+        return append_nvcc_threads(ret)
+
+    def builder(self):
+        try:
+            super().builder()
+        except Exception:
+            warnings.warn("build smoothquant lib not successful")
diff --git a/tests/test_smoothquant/test_smoothquant_linear.py b/tests/test_smoothquant/test_smoothquant_linear.py
new file mode 100644
index 000000000000..58a0b82f6759
--- /dev/null
+++ b/tests/test_smoothquant/test_smoothquant_linear.py
@@ -0,0 +1,39 @@
+import warnings
+
+import pytest
+import torch
+
+try:
+    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
+
+    smoothquant_cuda = SmoothquantBuilder().load()
+    HAS_SMOOTHQUANT_CUDA = True
+except Exception:
+    warnings.warn("CUDA smoothquant linear is not installed")
+    HAS_SMOOTHQUANT_CUDA = False
+
+
+@pytest.mark.skipif(
+    not HAS_SMOOTHQUANT_CUDA,
+    reason="smoothquant linear not installed properly",
+)
+def test_linear():
+    a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda")
+    b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda")
+    c = torch.rand(256, dtype=torch.float, device="cuda")
+
+    alpha = 1 / 127
+    beta = 1.0
+    torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c
+
+    silu = torch.nn.SiLU()
+    torch_out = silu(torch_out)
+
+    b = b.transpose(0, 1).contiguous()
+    cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta)
+
+    assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02)
+
+
+if __name__ == "__main__":
+    test_linear()
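Note (not part of the patch): the kernel's contract is `out = SiLU(alpha * (input @ weight^T) + beta * bias)`, with the INT8 GEMM accumulated in INT32 and scaled into FP32 by the epilogue. Below is a minimal PyTorch reference sketch of that contract, mirroring the test above; the helper name `ref_linear_silu_a8_w8_bfp32_ofp32` is ours, for illustration only.

```python
import torch

def ref_linear_silu_a8_w8_bfp32_ofp32(
    input: torch.Tensor,   # [M, K] int8
    weight: torch.Tensor,  # [N, K] int8 (row-major [N, K] acts as column-major B)
    bias: torch.Tensor,    # [N] fp32
    alpha: float,
    beta: float,
) -> torch.Tensor:
    # Emulate the INT32 accumulation in fp32; this is exact here because every
    # partial sum stays below 2**24 (K * 127 * 127 < 2**24 for K = 512).
    acc = torch.mm(input.to(torch.float), weight.to(torch.float).t())
    # Epilogue: scale the accumulator by alpha and the fp32 bias by beta.
    out = alpha * acc + beta * bias.view(1, -1)
    # SiLU activation: fused into the epilogue on SM75/SM80, applied
    # separately via torch::silu on SM70.
    return torch.nn.functional.silu(out)
```

With the tensors from `test_linear` above, `ref_linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta)` (where `b` is the transposed `[256, 512]` weight) should agree with the CUDA output within the test's tolerances.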