|
7 | 7 | rowwise_scaled_linear_cutlass_s4s4, |
8 | 8 | rowwise_scaled_linear_cutlass_s8s4, |
9 | 9 | ) |
| 10 | +from torchao.quantization.quant_api import ( |
| 11 | + _int4_symm_cutlass_quant, |
| 12 | + _int8_symm_cutlass_quant, |
| 13 | +) |
| 14 | + |
# Shared settings for the benchmark problems built in get_problem().
# `dtype` is the dtype of the reference tensors and is also what get_problem()
# passes on as `out_dtype`.
| 15 | +dtype = torch.bfloat16 |
# NOTE(review): `dtypeq` and `dtype_scale` are not referenced anywhere in the
# visible part of this file; presumably the quantized-data and scale dtypes.
# Confirm they are still needed.
| 16 | +dtypeq = torch.int8 |
| 17 | +dtype_scale = torch.float32 |
# The reference tensors are allocated on this device.
| 18 | +device = torch.device("cuda") |
10 | 19 |
|
11 | 20 |
|
def benchmark_microseconds(f, *args):
    """Time ``f(*args)`` with ``do_bench`` and return the median scaled by 1e3.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f`` on every timed call.

    Returns:
        The median measurement reported by ``do_bench`` multiplied by 1e3.
        Assuming ``do_bench`` reports milliseconds (TODO confirm against the
        imported implementation), the result is in microseconds.
    """

    def timed_call():
        # Bind the arguments here so do_bench receives a zero-argument callable.
        return f(*args)

    median_time = do_bench(timed_call, return_mode="median")
    return median_time * 1e3
14 | 23 |
|
15 | 24 |
|
def get_problem(m: int, n: int, k: int, Xq_nbits: int):
    """Build the reference and quantized operands for one benchmark problem.

    Args:
        m: Number of rows of the activation tensor ``X``.
        n: Number of rows of the weight tensor ``W``.
        k: Shared inner dimension of ``X`` and ``W``; must be even
            (presumably so 4-bit values can be packed in pairs -- TODO confirm).
        Xq_nbits: Bit width used to quantize ``X``; either 4 or 8. ``W`` is
            always quantized with the 4-bit quantizer.

    Returns:
        A pair ``(reference_operands, kernel_operands)`` where
        ``reference_operands`` is ``(X_ref, W_ref)`` and ``kernel_operands`` is
        ``(Xq, X_scale, Wq, W_scale, bias, out_dtype)`` with ``bias`` set to
        ``None`` and ``out_dtype`` set to the module-level ``dtype``.

    Raises:
        AssertionError: If ``k`` is odd or ``Xq_nbits`` is not 4 or 8.
    """
    assert k % 2 == 0
    assert Xq_nbits in (4, 8)

    # NOTE(review): activations are drawn with randn while weights use rand
    # (uniform on [0, 1)); confirm the asymmetry is intentional.
    X_ref = torch.randn((m, k), dtype=dtype, device=device)
    W_ref = torch.rand((n, k), dtype=dtype, device=device)

    # Only the activation quantizer depends on the requested bit width.
    if Xq_nbits == 4:
        quantize_activations = _int4_symm_cutlass_quant
    else:
        quantize_activations = _int8_symm_cutlass_quant
    quantized_x = quantize_activations(X_ref)
    quantized_w = _int4_symm_cutlass_quant(W_ref)

    # The kernels take the raw integer data and scales, not the wrapper tensors.
    x_impl = quantized_x.tensor_impl
    w_impl = quantized_w.tensor_impl

    reference_operands = (X_ref, W_ref)
    kernel_operands = (
        x_impl.int_data,
        x_impl.scale,
        w_impl.int_data,
        w_impl.scale,
        None,  # bias
        dtype,  # out_dtype
    )
    return reference_operands, kernel_operands
30 | 47 |
|
31 | | -def benchmark(m: int, k: int, n: int): |
32 | | - dev = torch.device("cuda") |
33 | | - A_ref = torch.randn((m, k), dtype=torch.half, device=dev) |
34 | | - B_ref = torch.randn((n, k), dtype=torch.half, device=dev) |
35 | | - fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref) |
36 | 48 |
|
37 | | - A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4) |
38 | | - rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds( |
39 | | - rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C |
| 49 | +def benchmark(m: int, k: int, n: int): |
| 50 | + ref_args, args = get_problem(m, n, k, 4) |
| 51 | + fp16_time = benchmark_microseconds(torch.nn.functional.linear, *ref_args) |
| 52 | + rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds( |
| 53 | + rowwise_scaled_linear_cutlass_s4s4, *args |
40 | 54 | ) |
41 | 55 |
|
42 | | - A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4) |
43 | | - rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds( |
44 | | - rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C |
| 56 | + _, args = get_problem(m, n, k, 8) |
| 57 | + rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds( |
| 58 | + rowwise_scaled_linear_cutlass_s8s4, *args |
45 | 59 | ) |
46 | 60 |
|
47 | 61 | return { |
|
0 commit comments