@@ -6,6 +6,8 @@
 
 import pytest
 import torch
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
@@ -284,3 +286,25 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         use_fp4_custom_triton_dequant_kernel,
     )
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+def test_to_mx_inductor_single_kernel():
+    """
+    Verify that inductor can fuse the cast of a high precision tensor to mx
+    into a single kernel
+    """
+    # TODO(future PR): add fp4 and fp6 here
+    # TODO(#1773): add swizzled scale format here
+    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
+    block_size = 32
+    to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
+    out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
+    FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])