@@ -9,15 +9,16 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.gptq_marlin import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
-    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
+    marlin_permute_scales)
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_perms import (
     marlin_perm)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
-    marlin_quantize, marlin_weights)
+    marlin_quantize, marlin_weights, pack_fp8_to_int32)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     gptq_pack, quantize_weights, sort_weights)

@@ -43,9 +44,11 @@
     (67, 13, 11),
 ]

+DTYPES = [torch.float16, torch.bfloat16]

-def rand_data(shape):
-    return torch.randn(shape, dtype=torch.half, device="cuda")
+
+def rand_data(shape, dtype=torch.float16):
+    return torch.randn(shape, dtype=dtype, device="cuda")


 @pytest.mark.skipif(not is_marlin_supported(),
@@ -222,3 +225,80 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
     print("max_diff = {}".format(max_diff))

     assert max_diff < 0.04
+
+
+@pytest.mark.skipif(not is_marlin_supported(),
+                    reason="Marlin is not supported on this GPU type.")
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
+@pytest.mark.parametrize("num_bits", [8])
+@pytest.mark.parametrize("group_size", [-1])
+@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_fp8_marlin_gemm(
+    k_chunk,
+    n_chunk,
+    num_bits,
+    group_size,
+    mnk_factors,
+    dtype,
+):
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    print(f"MNK = {size_m} {size_n} {size_k}")
+    print(f"groupsize = {group_size}")
+
+    a_input = rand_data((size_m, size_k), dtype=dtype)
+    b_weight = rand_data((size_k, size_n), dtype=dtype)
+
+    # WEIGHTS
+    fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
+    # Repack weights to gptq format (packed int32 elements)
+    packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
+    # Repack weights to marlin format
+    marlin_qweight = ops.gptq_marlin_repack(
+        b_q_weight=packed_gptq_qweight,
+        perm=torch.empty(0, dtype=torch.int, device="cuda"),
+        size_k=size_k,
+        size_n=size_n,
+        num_bits=8,
+    )
+
+    # WEIGHT SCALES
+    # Currently Marlin doesn't support per-tensor scales, so we
+    # expand it to channelwise
+    scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
+    # Permute scales
+    marlin_scales = marlin_permute_scales(
+        s=scales,
+        size_k=size_k,
+        size_n=size_n,
+        group_size=-1,
+        num_bits=8,
+    )
+
+    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                GPTQ_MARLIN_MAX_PARALLEL)
+
+    output = ops.fp8_marlin_gemm(
+        a=a_input,
+        b_q_weight=marlin_qweight,
+        b_scales=marlin_scales,
+        workspace=workspace.scratch,
+        num_bits=num_bits,
+        size_m=a_input.shape[0],
+        size_n=b_weight.shape[1],
+        size_k=a_input.shape[1],
+    )
+    output_ref = torch.matmul(a_input, b_weight)
+
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+    print("max_diff = {}".format(max_diff))
+
+    assert max_diff < 0.04
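For context on the packing step the new test relies on: pack_fp8_to_int32 turns the FP8 weight matrix into a GPTQ-style qweight layout (four 8-bit values per int32 along the K dimension), which is why the test can reuse the existing ops.gptq_marlin_repack kernel with an empty permutation. The sketch below illustrates that packing only; it is a minimal illustration assuming little-endian byte order, not vLLM's actual implementation, and the helper name pack_fp8_to_int32_sketch is made up for this example.

import torch


def pack_fp8_to_int32_sketch(fp8_tensor: torch.Tensor) -> torch.Tensor:
    # Conceptual sketch only (not the vLLM implementation): pack four FP8
    # values (one byte each) along dim 0 into a single int32 word, the same
    # grouping GPTQ uses for 8-bit weights.
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Reinterpret each FP8 element as its raw byte, then widen to int32 so
    # the shifted bytes can be OR-ed into one word.
    raw = fp8_tensor.view(torch.uint8).to(torch.int32)

    # Little-endian packing assumed: row 4*i lands in bits 0-7, row 4*i+1 in
    # bits 8-15, and so on. Overflow into the sign bit is intentional; only
    # the bit pattern matters.
    packed = (raw[0::4]
              | (raw[1::4] << 8)
              | (raw[2::4] << 16)
              | (raw[3::4] << 24))
    return packed.contiguous()

In this sketch a (size_k, size_n) FP8 tensor packs into a (size_k // 4, size_n) int32 tensor, matching the 8-bit GPTQ qweight shape consumed by the repack step in the test above.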