@@ -70,7 +70,74 @@ quantize_(m, config)

## MX inference

-Coming soon!
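+Below is an end-to-end example of quantizing a linear layer for inference with mxfp8, mxfp4, and nvfp4:
+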
+``` python
+import copy
+
+import torch
+import torch.nn as nn
+from torchao.quantization import quantize_
+from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
+)
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
+
+m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
+x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+
+# mxfp8: fp8 e4m3 elements with shared e8m0 scales, block size 32
+
+m_mxfp8 = copy.deepcopy(m)
+config = MXFPInferenceConfig(
+    activation_dtype=torch.float8_e4m3fn,
+    weight_dtype=torch.float8_e4m3fn,
+    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+)
+quantize_(m_mxfp8, config=config)
+m_mxfp8 = torch.compile(m_mxfp8, fullgraph=True)
+y_mxfp8 = m_mxfp8(x)
+
+# mxfp4: fp4 e2m1 elements with shared e8m0 scales, block size 32
+
+m_mxfp4 = copy.deepcopy(m)
+config = MXFPInferenceConfig(
+    activation_dtype=torch.float4_e2m1fn_x2,
+    weight_dtype=torch.float4_e2m1fn_x2,
+    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+)
+quantize_(m_mxfp4, config=config)
+m_mxfp4 = torch.compile(m_mxfp4, fullgraph=True)
+y_mxfp4 = m_mxfp4(x)
+
+# nvfp4: fp4 e2m1 elements with fp8 e4m3 scales, block size 16, optional fp32 per-tensor scale
+
+m_nvfp4 = copy.deepcopy(m)
+config = NVFP4InferenceConfig(
+    mm_config=NVFP4MMConfig.DYNAMIC,
+    use_dynamic_per_tensor_scale=True,
+)
+quantize_(m_nvfp4, config=config)
+m_nvfp4 = torch.compile(m_nvfp4, fullgraph=True)
+y_nvfp4 = m_nvfp4(x)
+```
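+
+As a quick sanity check, you can compare each quantized output against the bf16 baseline with an SQNR metric (the `sqnr` helper below is illustrative, not a torchao API):
+
+``` python
+y_bf16 = m(x)
+
+def sqnr(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
+    # signal-to-quantization-noise ratio in dB; higher means closer to the reference
+    return 20 * torch.log10(ref.norm() / (ref - actual).norm())
+
+for name, y in [("mxfp8", y_mxfp8), ("mxfp4", y_mxfp4), ("nvfp4", y_nvfp4)]:
+    print(name, sqnr(y_bf16, y).item())
+```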

## MXTensor
