from typing import Any, Optional, Sequence

import coremltools as ct
+import torch

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition.coreml_partitioner import (
...

from executorch.exir import EdgeCompileConfig
from executorch.export import (
+    AOQuantizationConfig,
    BackendRecipeProvider,
    ExportRecipe,
    LoweringRecipe,
+    QuantizationRecipe,
    RecipeType,
)
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig


class CoreMLRecipeProvider(BackendRecipeProvider):
@@ -50,66 +55,321 @@ def create_recipe
        # Validate kwargs
        self._validate_recipe_kwargs(recipe_type, **kwargs)

-        # Parse recipe type to get precision and compute unit
-        precision = None
        if recipe_type == CoreMLRecipeType.FP32:
-            precision = ct.precision.FLOAT32
+            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT32, **kwargs)
        elif recipe_type == CoreMLRecipeType.FP16:
-            precision = ct.precision.FLOAT16
-
-        if precision is None:
-            raise ValueError(f"Unknown precision for recipe: {recipe_type.value}")
+            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT16, **kwargs)
+        elif recipe_type == CoreMLRecipeType.PT2E_INT8_STATIC:
+            return self._build_pt2e_quantized_recipe(
+                recipe_type, activation_dtype=torch.quint8, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY:
+            return self._build_pt2e_quantized_recipe(
+                recipe_type, activation_dtype=torch.float32, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL:
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int4,
+                is_per_channel=True,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP:
+            group_size = kwargs.pop("group_size", 32)
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int4,
+                is_per_channel=False,
+                group_size=group_size,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL:
+            return self._build_torchao_quantized_recipe(
+                recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP:
+            group_size = kwargs.pop("group_size", 32)
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int8,
+                is_per_channel=False,
+                group_size=group_size,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            bits = kwargs.pop("bits")
+            block_size = kwargs.pop("block_size")
+            return self._build_codebook_quantized_recipe(
+                recipe_type, bits=bits, block_size=block_size, **kwargs
+            )

-        return self._build_recipe(recipe_type, precision, **kwargs)
+        return None

    def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
-        if not kwargs:
-            return
-        expected_keys = {"minimum_deployment_target", "compute_unit"}
+        """Validate kwargs for each recipe type"""
+        expected_keys = self._get_expected_keys(recipe_type)
+
        unexpected = set(kwargs.keys()) - expected_keys
        if unexpected:
            raise ValueError(
-                f"CoreML Recipes only accept 'minimum_deployment_target' or 'compute_unit' as parameter. "
-                f"Unexpected parameters: {list(unexpected)}"
+                f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}"
            )
+
+        self._validate_base_parameters(kwargs)
+        self._validate_group_size_parameter(recipe_type, kwargs)
+        self._validate_codebook_parameters(recipe_type, kwargs)
+
+    def _get_expected_keys(self, recipe_type: RecipeType) -> set:
+        """Get expected parameter keys for a recipe type"""
+        common_keys = {"minimum_deployment_target", "compute_unit"}
+
+        if recipe_type in [
+            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
+            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
+        ]:
+            return common_keys | {"group_size", "filter_fn"}
+        elif recipe_type in [
+            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL,
+            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL,
+        ]:
+            return common_keys | {"filter_fn"}
+        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            return common_keys | {"bits", "block_size", "filter_fn"}
+        else:
+            return common_keys
+
+    def _validate_base_parameters(self, kwargs: Any) -> None:
+        """Validate minimum_deployment_target and compute_unit parameters"""
        if "minimum_deployment_target" in kwargs:
            minimum_deployment_target = kwargs["minimum_deployment_target"]
            if not isinstance(minimum_deployment_target, ct.target):
                raise ValueError(
                    f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(minimum_deployment_target)}"
                )
+
        if "compute_unit" in kwargs:
            compute_unit = kwargs["compute_unit"]
            if not isinstance(compute_unit, ct.ComputeUnit):
                raise ValueError(
                    f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(compute_unit)}"
                )

-    def _build_recipe(
+    def _validate_group_size_parameter(
+        self, recipe_type: RecipeType, kwargs: Any
+    ) -> None:
+        """Validate group_size parameter for applicable recipe types"""
+        if (
+            recipe_type
+            in [
+                CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
+                CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
+            ]
+            and "group_size" in kwargs
+        ):
+            group_size = kwargs["group_size"]
+            if not isinstance(group_size, int):
+                raise ValueError(
+                    f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
+                )
+            if group_size <= 0:
+                raise ValueError(
+                    f"Parameter 'group_size' must be positive, got: {group_size}"
+                )
+
+    def _validate_codebook_parameters(
+        self, recipe_type: RecipeType, kwargs: Any
+    ) -> None:
+        """Validate bits and block_size parameters for codebook recipe type"""
+        if recipe_type != CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            return
+
+        # Both bits and block_size must be present
+        if not ("bits" in kwargs and "block_size" in kwargs):
+            raise ValueError(
+                "Parameters 'bits' and 'block_size' must be present for codebook recipes"
+            )
+
+        if "bits" in kwargs:
+            bits = kwargs["bits"]
+            if not isinstance(bits, int):
+                raise ValueError(
+                    f"Parameter 'bits' must be an integer, got {type(bits).__name__}: {bits}"
+                )
+            if not (1 <= bits <= 8):
+                raise ValueError(
+                    f"Parameter 'bits' must be between 1 and 8, got: {bits}"
+                )
+
+        if "block_size" in kwargs:
+            block_size = kwargs["block_size"]
+            if not isinstance(block_size, list):
+                raise ValueError(
+                    f"Parameter 'block_size' must be a list, got {type(block_size).__name__}: {block_size}"
+                )
+
+    def _validate_and_set_deployment_target(
+        self, kwargs: Any, min_target: ct.target, quantization_type: str
+    ) -> None:
+        """Validate or set minimum deployment target for quantization recipes"""
+        minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
+        if minimum_deployment_target and minimum_deployment_target < min_target:
+            raise ValueError(
+                f"minimum_deployment_target must be {str(min_target)} or higher for {quantization_type} quantization"
+            )
+        else:
+            # Default to the minimum target for this quantization type
+            kwargs["minimum_deployment_target"] = min_target
+
+    def _build_fp_recipe(
        self,
        recipe_type: RecipeType,
        precision: ct.precision,
        **kwargs: Any,
    ) -> ExportRecipe:
+        """Build FP32/FP16 recipe"""
        lowering_recipe = self._get_coreml_lowering_recipe(
            compute_precision=precision,
            **kwargs,
        )

        return ExportRecipe(
            name=recipe_type.value,
-            quantization_recipe=None,  # TODO - add quantization recipe
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _build_pt2e_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        activation_dtype: torch.dtype,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build PT2E-based quantization recipe"""
+        from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
+
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS17, "pt2e")
+
+        # Validate activation_dtype
+        assert activation_dtype in [
+            torch.quint8,
+            torch.float32,
+        ], f"activation_dtype must be torch.quint8 or torch.float32, got {activation_dtype}"
+
+        # Create quantization config
+        config = ct.optimize.torch.quantization.LinearQuantizerConfig(
+            global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
+                quantization_scheme="symmetric",
+                activation_dtype=activation_dtype,
+                weight_dtype=torch.qint8,
+                weight_per_channel=True,
+            )
+        )
+
+        quantizer = CoreMLQuantizer(config)
+        quantization_recipe = QuantizationRecipe(quantizers=[quantizer])
+
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _build_torchao_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        weight_dtype: torch.dtype,
+        is_per_channel: bool,
+        group_size: int = 32,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build TorchAO-based quantization recipe"""
+        if is_per_channel:
+            weight_granularity = PerAxis(axis=0)
+        else:
+            weight_granularity = PerGroup(group_size=group_size)
+
+        # Use user-provided filter_fn if provided
+        filter_fn = kwargs.get("filter_fn", None)
+        config = AOQuantizationConfig(
+            ao_base_config=IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=weight_granularity,
+            ),
+            filter_fn=filter_fn,
+        )
+
+        quantization_recipe = QuantizationRecipe(
+            quantizers=None,
+            ao_quantization_configs=[config],
+        )
+
+        # override minimum_deployment_target to ios18 for torchao (GH issue #13122)
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _build_codebook_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        bits: int,
+        block_size: list,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build codebook/palettization quantization recipe"""
+        from torchao.prototype.quantization.codebook_coreml import (
+            CodebookWeightOnlyConfig,
+        )
+
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "codebook")
+
+        # Get the appropriate dtype (torch.uint1 through torch.uint8)
+        dtype = getattr(torch, f"uint{bits}")
+
+        # Use user-provided filter_fn or default to Linear/Embedding layers
+        filter_fn = kwargs.get(
+            "filter_fn",
+            lambda m, fqn: (
+                isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear)
+            ),
+        )
+
+        config = AOQuantizationConfig(
+            ao_base_config=CodebookWeightOnlyConfig(
+                dtype=dtype,
+                block_size=block_size,
+            ),
+            filter_fn=filter_fn,
+        )
+
+        quantization_recipe = QuantizationRecipe(
+            quantizers=None,
+            ao_quantization_configs=[config],
+        )
+
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
            lowering_recipe=lowering_recipe,
        )

    def _get_coreml_lowering_recipe(
        self,
-        compute_precision: ct.precision,
+        compute_precision: ct.precision = ct.precision.FLOAT16,
        **kwargs: Any,
    ) -> LoweringRecipe:
+        """Get CoreML lowering recipe with optional precision"""
        compile_specs = CoreMLBackend.generate_compile_specs(
            compute_precision=compute_precision,
-            **kwargs,
+            compute_unit=kwargs.get("compute_unit", ct.ComputeUnit.ALL),
+            minimum_deployment_target=kwargs.get("minimum_deployment_target", None),
        )

        minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
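The diff above only defines the recipe provider; for orientation, here is a minimal sketch of how the new recipe types could be requested through `create_recipe`. The import path `executorch.backends.apple.coreml.recipes` is an assumption (that module is not shown in this diff), and the parameter values are illustrative, chosen to exercise the kwargs validated above.

```python
import coremltools as ct

# Assumed import path; not shown in the diff above.
from executorch.backends.apple.coreml.recipes import (
    CoreMLRecipeProvider,
    CoreMLRecipeType,
)

provider = CoreMLRecipeProvider()

# FP16 baseline with an explicit compute unit (validated as ct.ComputeUnit).
fp16_recipe = provider.create_recipe(
    CoreMLRecipeType.FP16,
    compute_unit=ct.ComputeUnit.CPU_AND_NE,
)

# 4-bit weight-only quantization with per-group granularity;
# group_size defaults to 32 when omitted.
int4_recipe = provider.create_recipe(
    CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
    group_size=64,
)

# Codebook (palettization) recipe: 'bits' and 'block_size' are both required.
# The block_size value is illustrative; its shape follows torchao's
# CodebookWeightOnlyConfig.
codebook_recipe = provider.create_recipe(
    CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY,
    bits=4,
    block_size=[-1, 16],
)
```

Each call returns an `ExportRecipe` assembled by the builders in this file; the quantized builders also go through `_validate_and_set_deployment_target`, which rejects `minimum_deployment_target` values below iOS17 (PT2E) or iOS18 (TorchAO and codebook).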
|  | 
0 commit comments
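The weight-only recipes additionally accept an optional `filter_fn`, which is forwarded to `AOQuantizationConfig` to control which modules get quantized. A short, hypothetical sketch restricting int8 per-channel quantization to `torch.nn.Linear` modules, using the same assumed import path as above:

```python
import torch

from executorch.backends.apple.coreml.recipes import (  # assumed import path
    CoreMLRecipeProvider,
    CoreMLRecipeType,
)


def linear_only(module: torch.nn.Module, fqn: str) -> bool:
    # Quantize only Linear layers; embeddings and everything else stay in float.
    return isinstance(module, torch.nn.Linear)


recipe = CoreMLRecipeProvider().create_recipe(
    CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL,
    filter_fn=linear_only,
)
```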