     run_tests,
 )
 from torch.testing._internal.optests import opcheck
-from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
 from torchao.prototype.quant_llm import from_scaled_tc_fpx
+from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
 import pytest
 
 if is_fbcode():
@@ -302,5 +303,119 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size |
         test_utils=test_utils,
     )
 
+
+MARLIN_24_K_CHUNKS = [128]
+MARLIN_24_N_CHUNKS = [512]
+MNK_FACTORS = [
+    (1, 1, 1),
+    (1, 4, 8),
+    (1, 7, 5),
+    (13, 17, 67),
+    (26, 37, 13),
+    (67, 13, 11),
+]
+MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
+MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
+
+MARLIN_TEST_PARAMS = list(itertools.product(
+    MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
+    MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
+))
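+# Each entry of MARLIN_TEST_PARAMS is (k_chunk, n_chunk, num_bits, group_size,
+# (m_factor, n_factor, k_factor)); test_marlin_24 below derives the GEMM sizes
+# as size_m = m_factor, size_k = k_chunk * k_factor, size_n = n_chunk * n_factor.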
+
+def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int):
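+    """Reference implementation of symmetric per-group quantization.
+
+    Quantizes `w` to unsigned `num_bits` integers with the zero point at
+    mid-range, using one scale per column for each group of `group_size` rows
+    (the whole K dimension when group_size == -1). Returns the dequantized fp16
+    reference weights, the quantized weights, and the scales.
+    """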
+    orig_device = w.device
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    max_q_val = 2**num_bits - 1
+    half_q_val = (max_q_val + 1) // 2
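+    # e.g. for num_bits=4: max_q_val = 15 and half_q_val = 8, so quantized
+    # values live in [0, 15] with the zero point at 8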
+
+    # Reshape to [group_size, -1] so that each column holds one quantization group
+    if group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+    # Compute a scale for each group
+    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
+    s *= 2 / max_q_val  # factor of 2 because the range is symmetric around zero
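+    # w / s then spans [-max_q_val / 2, max_q_val / 2], e.g. [-7.5, 7.5] for 4 bits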
+
+    # Quantize
+    q_w = torch.round(w / s).int()
+    q_w += half_q_val
+    q_w = torch.clamp(q_w, 0, max_q_val)
+
+    # Compute the dequantized fp16 reference
+    w_ref = (q_w - half_q_val).half() * s
+
+    # Restore original shapes
+    if group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        q_w = reshape_w(q_w)
+        w_ref = reshape_w(w_ref)
+
+    s = s.reshape((-1, size_n)).contiguous()
+
+    return (
+        w_ref.to(device=orig_device),
+        q_w.to(device=orig_device),
+        s.to(device=orig_device),
+    )
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
+def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
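+    """End-to-end test for the marlin_24_gemm op.
+
+    Builds a 2:4-sparse, symmetrically quantized weight matrix, packs it into
+    the Marlin 2:4 format, runs the fused GEMM, and compares the result with a
+    plain fp16 matmul against the dequantized reference weights, followed by an
+    opcheck of the op registration.
+    """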
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
+    b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda")
+
+    # Inject 2:4 sparsity into the weight matrix
+    w_24, _ = inject_24(b_weight, size_k, size_n)
+
+    # Quantize symmetrically, keeping a dequantized fp16 reference
+    w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size)
+
+    # Compute the reference output with the dequantized weights
+    output_ref = torch.matmul(a_input, w_24_ref)
+
+    # Pack the quantized weights into the Marlin 2:4 format
+    marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
+    workspace_24 = marlin_24_workspace(size_n)
+
+    fn_inputs = (
+        a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
+        num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1],
+    )
+    output = torchao.ops.marlin_24_gemm(*fn_inputs)
+    torch.cuda.synchronize()
+
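+    # The fused kernel output should closely match the fp16 reference matmul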
+    max_diff = compute_max_diff(output, output_ref)
+    assert max_diff < 0.04
+
+    # Run opcheck to verify the op's schema, autograd registration, and fake-tensor support
+    test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
+    opcheck(
+        torch.ops.torchao.marlin_24_gemm,
+        fn_inputs,
+        test_utils=test_utils,
+    )
+
+
 if __name__ == "__main__":
     run_tests()