@@ -11,6 +11,7 @@
 import torch.nn as nn
 
 from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
     MXLinearRecipeName,
@@ -380,7 +381,7 @@ def test_inference_print_str():
     not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.skipif(not is_sm_at_least_100(), reason="Reqs sm100")
-@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
+@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
 @torch.no_grad()
@@ -394,7 +395,16 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
 
     m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)
-    config = MXFPInferenceConfig()
+    kernel_choice = (
+        MXGemmKernelChoice.CUTLASS
+        if elem_dtype == DTYPE_FP4
+        else MXGemmKernelChoice.CUBLAS
+    )
+    config = MXFPInferenceConfig(
+        activation_dtype=elem_dtype,
+        weight_dtype=elem_dtype,
+        gemm_kernel_choice=kernel_choice,
+    )
     quantize_(m_mx, config=config)
     if compile:
         m_mx = torch.compile(m_mx, fullgraph=True)
@@ -403,4 +413,7 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     y_ref = m(x)
     y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)
-    assert sqnr >= 25.0, f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
+    SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
+    )
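
For context, a minimal standalone sketch of the usage pattern this test exercises, assuming only the APIs visible in the diff (the MXFPInferenceConfig kwargs, MXGemmKernelChoice, and quantize_); the import path for MXFPInferenceConfig and the inline SQNR computation are assumptions and may differ from torchao's actual module layout and its compute_error helper.

import copy

import torch
import torch.nn as nn
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
# Assumed import path; MXFPInferenceConfig may live elsewhere in torchao.
from torchao.prototype.mx_formats import MXFPInferenceConfig
from torchao.quantization import quantize_

# Reference bf16 linear layer and a copy to be quantized in-place.
m = nn.Linear(32, 128, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx = copy.deepcopy(m)

# fp8 activations/weights route to the cuBLAS kernel; per the diff, fp4
# (torch.float4_e2m1fn_x2 / DTYPE_FP4) would select CUTLASS instead.
config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)
quantize_(m_mx, config=config)

x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y_ref, y_mx = m(x), m_mx(x)

# SQNR in dB (a stand-in for the test's compute_error helper); the test
# asserts >= 25 dB for fp8 and >= 15 dB for fp4.
sqnr = 10 * torch.log10(
    y_ref.float().pow(2).mean() / (y_ref - y_mx).float().pow(2).mean()
)
print(f"sqnr = {sqnr.item():.1f} dB")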