
Commit cea8b38

[inference] add silu linear fusion for smoothquant llama mlp (hpcaitech#4853)

* add silu linear
* update skip condition
* catch smoothquant cuda lib exception
* process exception for tests
1 parent 30f23cd commit cea8b38
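
In effect, the new op fuses the SiLU activation into the epilogue of an INT8 GEMM. A minimal reference sketch of its semantics, assuming the conventions used by the unit test in this commit (this mirrors the test's check, not the CUTLASS implementation itself):

import torch

def reference_linear_silu_a8_w8_bfp32_ofp32(x_int8, w_int8, bias_fp32, alpha, beta):
    # x_int8: (M, K) INT8 activations; w_int8: (N, K) INT8 weights;
    # bias_fp32: (N,) FP32 bias. Accumulation and output are FP32.
    acc = torch.mm(x_int8.float(), w_int8.float().t())
    return torch.nn.functional.silu(alpha * acc + beta * bias_fp32)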

File tree: 5 files changed, +273 −0 lines changed
smoothquant/binding.cpp

Lines changed: 8 additions & 0 deletions

#include <torch/extension.h>

#include "linear.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
        "Linear SiLU (INT8)");
}
smoothquant/linear.cu

Lines changed: 162 additions & 0 deletions

// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu

#include "linear.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>

#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/epilogue/thread/linear_combination_silu.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>

torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,  // INT8
                                            torch::Tensor weight, // INT8
                                            torch::Tensor bias,   // FP32
                                            float alpha,          // FP32
                                            float beta            // FP32
) {
  auto M = input.size(0);
  auto N = weight.size(0);
  auto K = input.size(1);

  using ElementOutput = float;
  using ElementAccumulator = int32_t;
  using ElementComputeEpilogue = float;
  using ElementInputA = int8_t;  // <- data type of elements in input matrix A
  using ElementInputB = int8_t;  // <- data type of elements in input matrix B

  // The code section below describes the matrix layout of input and output
  // matrices: Row Major for Matrix A, Column Major for Matrix B, and Row Major
  // for Matrix C
  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

#if CUDA_ARCH >= 800
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput,  // <- data type of output matrix
      128 / cutlass::sizeof_bits<
                ElementOutput>::value,  // <- this is the number of elements per
                                        // vectorized memory access. For half
                                        // precision, it's 8 elements. This
                                        // becomes the vector width of math
                                        // instructions in the epilogue too
      ElementAccumulator,               // <- data type of accumulator
      ElementComputeEpilogue  // <- data type for alpha in linear combination
                              // function
      >;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
#elif CUDA_ARCH >= 750
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput,                                     // <- data type of output matrix
      128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- vector width of
                                                         // epilogue memory access
      ElementAccumulator,      // <- data type of accumulator
      ElementComputeEpilogue   // <- data type for alpha in linear combination
      >;

  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape,
      EpilogueOp>;
#elif CUDA_ARCH >= 700
#define USE_TORCH_SILU
  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
      ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape,
      cutlass::epilogue::thread::LinearCombination<
          ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
#else
#error "Unsupported cuda arch"
#endif

  auto input_size = cutlass::MatrixCoord(M, K);
  auto weight_size = cutlass::MatrixCoord(K, N);
  auto output_size = cutlass::MatrixCoord(M, N);

  auto device = input.device();
  // use the broadcasted bias as the output
  auto out = bias.to(device).view({1, -1}).repeat({M, 1});

  // constexpr int kSparse = Gemm::kSparse;
  // How many elements of A are covered per ElementE
  // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
  // The size of individual metadata
  // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
  cutlass::gemm::GemmCoord problem_size(M, N, K);

  cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
      input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
  cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
      weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
  cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
      out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));

  typename Gemm::Arguments arguments{
      problem_size,  // <- problem size of matrix multiplication
      input_ref,     // <- reference to matrix A on device
      weight_ref,    // <- reference to matrix B on device
      out_ref,       // <- reference to matrix C on device
      out_ref,       // <- reference to matrix D on device
      {alpha, beta}, 1};
  Gemm gemm_op;

  // Using the arguments, query for the extra workspace required for the matrix
  // multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check that the problem size is supported
  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize the CUTLASS kernel with the arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  status = gemm_op();
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot run");
  }
#ifdef USE_TORCH_SILU
#undef USE_TORCH_SILU
  // The SIMT (Sm70) path has no fused SiLU epilogue, so apply SiLU here
  out = torch::silu(out);
#endif
  return out;
}
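
One calling-convention detail follows from LayoutInputB being column-major: the Python caller must pass the weight as a contiguous (N, K) tensor. A hedged sketch of the pattern (variable names are illustrative; it mirrors the unit test at the end of this commit):

import torch

M, K, N = 128, 512, 256
x_int8 = torch.randint(-127, 127, (M, K), dtype=torch.int8, device="cuda")
w_kn = torch.randint(-127, 127, (K, N), dtype=torch.int8, device="cuda")
bias = torch.rand(N, dtype=torch.float, device="cuda")

# An (N, K) row-major tensor is exactly a (K, N) column-major B, which is
# what the kernel expects:
w_int8 = w_kn.transpose(0, 1).contiguous()
# out = SiLU(alpha * x @ w_kn + beta * bias), shape (M, N), dtype FP32:
# out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x_int8, w_int8, bias, 1 / 127, 1.0)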
smoothquant/linear.h

Lines changed: 12 additions & 0 deletions

#include <torch/torch.h>
#include <torch/types.h>

#include <cstdint>
#include <iostream>

torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,  // INT8
                                            torch::Tensor weight, // INT8
                                            torch::Tensor bias,   // FP32
                                            float alpha,          // FP32
                                            float beta            // FP32
);

op_builder/smoothquant.py

Lines changed: 52 additions & 0 deletions

import warnings

import torch

from .builder import Builder
from .utils import append_nvcc_threads


class SmoothquantBuilder(Builder):
    NAME = "cu_smoothquant"
    PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant"

    def __init__(self):
        super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH)

    def include_dirs(self):
        ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()]
        return ret

    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname)
            for fname in [
                "smoothquant/binding.cpp",
                "smoothquant/linear.cu",
            ]
        ]
        return ret

    def cxx_flags(self):
        return ["-O3"] + self.version_dependent_macros

    def nvcc_flags(self):
        compute_capability = torch.cuda.get_device_capability()
        cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10

        extra_cuda_flags = [
            "-v",
            f"-DCUDA_ARCH={cuda_arch}",
            "-std=c++17",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-DTHRUST_IGNORE_CUB_VERSION_CHECK",
        ]

        ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)

    def builder(self):
        try:
            super().builder()
        except Exception:
            warnings.warn("build smoothquant lib not successful")
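
Note that the -DCUDA_ARCH value computed in nvcc_flags is what drives the #if CUDA_ARCH branches in linear.cu: compute capability (major, minor) maps to major * 100 + minor * 10, so sm_80 builds the fused Tensor Core path while sm_70 falls back to SIMT plus torch::silu. A minimal usage sketch, assuming the builder is importable from colossalai.kernel.op_builder.smoothquant (the same import path the test below uses):

from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

smoothquant_cuda = SmoothquantBuilder().load()  # JIT-builds the extension on first use
# e.g. on an A100, torch.cuda.get_device_capability() == (8, 0),
# so CUDA_ARCH = 8 * 100 + 0 * 10 = 800 and linear.cu takes the Sm80 branch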
Lines changed: 39 additions & 0 deletions

import warnings

import pytest
import torch

try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except Exception:
    warnings.warn("CUDA smoothquant linear is not installed")
    HAS_SMOOTHQUANT_CUDA = False


@pytest.mark.skipif(
    not HAS_SMOOTHQUANT_CUDA,
    reason="smoothquant linear not installed properly",
)
def test_linear():
    a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda")
    b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda")
    c = torch.rand(256, dtype=torch.float, device="cuda")

    alpha = 1 / 127
    beta = 1.0
    torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c

    silu = torch.nn.SiLU()
    torch_out = silu(torch_out)

    # the kernel reads the weight column-major, so pass it as (N, K) contiguous
    b = b.transpose(0, 1).contiguous()
    cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta)

    assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02)


if __name__ == "__main__":
    test_linear()
