13 | 13 | from torchao.core.config import AOBaseConfig |
14 | 14 | from torchao.prototype.mx_formats.constants import ( |
15 | 15 | DTYPE_FP4, |
| 16 | + DTYPE_FP6_E2M3, |
| 17 | + DTYPE_FP6_E3M2, |
16 | 18 | DTYPE_TO_SHORT_STR, |
17 | 19 | SUPPORTED_ELEM_DTYPES, |
18 | 20 | ) |
@@ -41,6 +43,31 @@ class MXLinearRecipeName(Enum): |
41 | 43 | MXFP4_CUTLASS = "mxfp4_cutlass" |
42 | 44 |
43 | 45 |
| 46 | +def _validate_elem_dtype(elem_dtype): |
| 47 | + assert ( |
| 48 | + elem_dtype in SUPPORTED_ELEM_DTYPES |
| 49 | + ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {elem_dtype}" |
| 50 | + |
| 51 | + |
| 52 | +def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype): |
| 53 | + if gemm_kernel_choice == MXGemmKernelChoice.CUTLASS: |
| 54 | + assert ( |
| 55 | + block_size == 32 |
| 56 | + ), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {block_size}" |
| 57 | + valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4] |
| 58 | + assert ( |
| 59 | + elem_dtype in valid_dtypes |
| 60 | + ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" |
| 61 | + elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS: |
| 62 | + assert ( |
| 63 | + block_size == 32 |
| 64 | + ), f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}" |
| 65 | + valid_dtypes = [torch.float8_e4m3fn] |
| 66 | + assert ( |
| 67 | + elem_dtype in valid_dtypes |
| 68 | + ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" |
| 69 | + |
| 70 | + |
44 | 71 | @dataclass |
45 | 72 | class MXLinearConfig(AOBaseConfig): |
46 | 73 | # block size for scaling, default is 32 to match |
@@ -68,53 +95,17 @@ class MXLinearConfig(AOBaseConfig): |
68 | 95 | # If True, uses a custom triton kernel for fp4 dequantize |
69 | 96 | use_fp4_custom_triton_dequant_kernel: bool = False |
70 | 97 |
71 | | - # If True, packs 4xFP6 into 3xuint8 containers for inference, using custom triton |
72 | | - # kernels (fused unpack/dequantize). Training not currently supported. |
73 | | - pack_fp6 = True if hasattr(torch.library, "custom_op") else False |
74 | | - |
75 | 98 | def __post_init__(self): |
76 | | - # validate elem_dtype and its overrides |
77 | | - assert ( |
78 | | - self.elem_dtype in SUPPORTED_ELEM_DTYPES |
79 | | - ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" |
| 99 | + _validate_elem_dtype(self.elem_dtype) |
| 100 | + _validate_gemm_kernel_choice( |
| 101 | + self.gemm_kernel_choice, self.block_size, self.elem_dtype |
| 102 | + ) |
80 | 103 | if self.elem_dtype_weight_override is not None: |
81 | | - assert ( |
82 | | - self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES |
83 | | - ), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" |
| 104 | + _validate_elem_dtype(self.elem_dtype_weight_override) |
| 105 | + assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "elem_dtype_weight_override is only supported with MXGemmKernelChoice.EMULATED" |
84 | 106 | if self.elem_dtype_grad_output_override is not None: |
85 | | - assert ( |
86 | | - self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES |
87 | | - ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" |
88 | | - |
89 | | - # validate that block size and elem_dtype matches kernel choice |
90 | | - if self.gemm_kernel_choice == MXGemmKernelChoice.CUTLASS: |
91 | | - assert ( |
92 | | - self.block_size == 32 |
93 | | - ), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {self.block_size}" |
94 | | - valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4] |
95 | | - assert ( |
96 | | - self.elem_dtype in valid_dtypes |
97 | | - ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}" |
98 | | - assert ( |
99 | | - self.elem_dtype_weight_override is None |
100 | | - ), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels" |
101 | | - assert ( |
102 | | - self.elem_dtype_grad_output_override is None |
103 | | - ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels" |
104 | | - elif self.gemm_kernel_choice == MXGemmKernelChoice.CUBLAS: |
105 | | - assert ( |
106 | | - self.block_size == 32 |
107 | | - ), f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {self.block_size}" |
108 | | - valid_dtypes = [torch.float8_e4m3fn] |
109 | | - assert ( |
110 | | - self.elem_dtype in valid_dtypes |
111 | | - ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}" |
112 | | - assert ( |
113 | | - self.elem_dtype_weight_override is None |
114 | | - ), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels" |
115 | | - assert ( |
116 | | - self.elem_dtype_grad_output_override is None |
117 | | - ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels" |
| 107 | + _validate_elem_dtype(self.elem_dtype_grad_output_override) |
| 108 | + assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "elem_dtype_grad_output_override is only supported with MXGemmKernelChoice.EMULATED" |
118 | 109 |
119 | 110 | @staticmethod |
120 | 111 | def from_recipe_name( |
@@ -162,5 +153,47 @@ def short_str(self) -> str: |
162 | 153 | s += ", use_fp8_dim1_cast_triton_kernel=True" |
163 | 154 | if self.use_fp4_custom_triton_dequant_kernel: |
164 | 155 | s += ", use_fp4_custom_triton_dequant_kernel=True" |
165 | | - # TODO(future PR): split training from inference and add fp6 here |
166 | 156 | return s |
| 157 | + |
| 158 | + |
| 159 | +@dataclass |
| 160 | +class MXInferenceLinearConfig(AOBaseConfig): |
| 161 | + # block size for scaling, default is 32 to match |
| 162 | + # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, |
| 163 | + # section 5.2 |
| 164 | + block_size: int = 32 |
| 165 | + |
| 166 | + # element dtype, used for activations and weights |
| 167 | + elem_dtype: Any = torch.float8_e4m3fn |
| 168 | + # TODO(future PR): support different elem_dtype for activations vs weights |
| 169 | + |
| 170 | + # defines the gemm kernel choice, if the chosen kernel is not supported |
| 171 | + # on the given hardware an exception will be thrown |
| 172 | + gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED |
| 173 | + |
| 174 | + # If True, uses a custom triton kernel for fp4 dequantize |
| 175 | + use_fp4_custom_triton_dequant_kernel: bool = False |
| 176 | + |
| 177 | + # If True, packs 4xFP6 into 3xuint8 containers for inference, using custom triton |
| 178 | + # kernels (fused unpack/dequantize). |
| 179 | + pack_fp6: bool = True |
| 180 | + |
| 181 | + def __post_init__(self): |
| 182 | + _validate_elem_dtype(self.elem_dtype) |
| 183 | + _validate_gemm_kernel_choice( |
| 184 | + self.gemm_kernel_choice, self.block_size, self.elem_dtype |
| 185 | + ) |
| 186 | + |
| 187 | + def short_str(self) -> str: |
| 188 | + """ |
| 189 | + Returns a concise representation of the current config. |
| 190 | + """ |
| 191 | + s = f"bl_sz={self.block_size}, lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}" |
| 192 | + s += f", kernel={self.gemm_kernel_choice.value}" |
| 193 | + if self.use_fp4_custom_triton_dequant_kernel: |
| 194 | + s += ", use_fp4_custom_triton_dequant_kernel=True" |
| 195 | + if self.elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2) and self.pack_fp6: |
| 196 | + s += ", pack_fp6=True" |
| 197 | + return s |
| 198 | + |
| 199 | + # TODO(future PR): add a recipe to config API for inference |
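
For context, here is a minimal sketch of how the extracted validation helpers behave once this diff is applied. The module path `torchao.prototype.mx_formats.config` and the availability of `MXGemmKernelChoice` in that module are assumptions based on the imports shown above, not something this diff states.

```python
# Sketch, not part of this PR: exercising the extracted validators.
# Assumes this file lives at torchao/prototype/mx_formats/config.py.
import torch

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    _validate_elem_dtype,
    _validate_gemm_kernel_choice,
)

# Accepted: fp8 e4m3 with block_size=32 satisfies both the CUTLASS and cuBLAS checks.
_validate_elem_dtype(torch.float8_e4m3fn)
_validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 32, torch.float8_e4m3fn)

# Rejected: the cuBLAS MX gemm path requires block_size == 32.
try:
    _validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 16, torch.float8_e4m3fn)
except AssertionError as err:
    print(err)
```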
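Similarly, a hedged usage sketch for the new `MXInferenceLinearConfig`. The import paths and the exact `short_str()` output are assumptions following the code above; only the field names, defaults, and validation behavior come from this diff.

```python
# Sketch, not part of this PR: constructing the new inference config.
import torch

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    MXInferenceLinearConfig,
)
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3

# fp8 e4m3 on the cuBLAS MX gemm path; __post_init__ enforces block_size == 32.
fp8_config = MXInferenceLinearConfig(
    block_size=32,
    elem_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)
print(fp8_config.short_str())  # e.g. "bl_sz=32, lp_dtype=..., kernel=cublas"

# fp6 e2m3 with the default emulated gemm; pack_fp6 defaults to True, so
# short_str() appends ", pack_fp6=True" per the branch above.
fp6_config = MXInferenceLinearConfig(elem_dtype=DTYPE_FP6_E2M3)
print(fp6_config.short_str())
```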