|
6 | 6 | from typing import Any, Optional, Sequence |
7 | 7 |
|
8 | 8 | import coremltools as ct |
9 | | -import torch |
10 | 9 |
|
11 | 10 | from executorch.backends.apple.coreml.compiler import CoreMLBackend |
12 | 11 | from executorch.backends.apple.coreml.partition.coreml_partitioner import ( |
|
19 | 18 |
|
20 | 19 | from executorch.exir import EdgeCompileConfig |
21 | 20 | from executorch.export import ( |
22 | | - AOQuantizationConfig, |
23 | 21 | BackendRecipeProvider, |
24 | 22 | ExportRecipe, |
25 | 23 | LoweringRecipe, |
26 | | - QuantizationRecipe, |
27 | 24 | RecipeType, |
28 | 25 | ) |
29 | | -from torchao.quantization.granularity import PerAxis, PerGroup |
30 | | -from torchao.quantization.quant_api import IntxWeightOnlyConfig |
31 | 26 |
|
32 | 27 |
|
33 | 28 | class CoreMLRecipeProvider(BackendRecipeProvider): |
@@ -55,321 +50,66 @@ def create_recipe( |
55 | 50 | # Validate kwargs |
56 | 51 | self._validate_recipe_kwargs(recipe_type, **kwargs) |
57 | 52 |
|
| 53 | + # Parse recipe type to get precision and compute unit |
| 54 | + precision = None |
58 | 55 | if recipe_type == CoreMLRecipeType.FP32: |
59 | | - return self._build_fp_recipe(recipe_type, ct.precision.FLOAT32, **kwargs) |
| 56 | + precision = ct.precision.FLOAT32 |
60 | 57 | elif recipe_type == CoreMLRecipeType.FP16: |
61 | | - return self._build_fp_recipe(recipe_type, ct.precision.FLOAT16, **kwargs) |
62 | | - elif recipe_type == CoreMLRecipeType.PT2E_INT8_STATIC: |
63 | | - return self._build_pt2e_quantized_recipe( |
64 | | - recipe_type, activation_dtype=torch.quint8, **kwargs |
65 | | - ) |
66 | | - elif recipe_type == CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY: |
67 | | - return self._build_pt2e_quantized_recipe( |
68 | | - recipe_type, activation_dtype=torch.float32, **kwargs |
69 | | - ) |
70 | | - elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL: |
71 | | - return self._build_torchao_quantized_recipe( |
72 | | - recipe_type, |
73 | | - weight_dtype=torch.int4, |
74 | | - is_per_channel=True, |
75 | | - **kwargs, |
76 | | - ) |
77 | | - elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP: |
78 | | - group_size = kwargs.pop("group_size", 32) |
79 | | - return self._build_torchao_quantized_recipe( |
80 | | - recipe_type, |
81 | | - weight_dtype=torch.int4, |
82 | | - is_per_channel=False, |
83 | | - group_size=group_size, |
84 | | - **kwargs, |
85 | | - ) |
86 | | - elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL: |
87 | | - return self._build_torchao_quantized_recipe( |
88 | | - recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs |
89 | | - ) |
90 | | - elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP: |
91 | | - group_size = kwargs.pop("group_size", 32) |
92 | | - return self._build_torchao_quantized_recipe( |
93 | | - recipe_type, |
94 | | - weight_dtype=torch.int8, |
95 | | - is_per_channel=False, |
96 | | - group_size=group_size, |
97 | | - **kwargs, |
98 | | - ) |
99 | | - elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY: |
100 | | - bits = kwargs.pop("bits") |
101 | | - block_size = kwargs.pop("block_size") |
102 | | - return self._build_codebook_quantized_recipe( |
103 | | - recipe_type, bits=bits, block_size=block_size, **kwargs |
104 | | - ) |
| 58 | + precision = ct.precision.FLOAT16 |
105 | 59 |
|
106 | | - return None |
| 60 | + if precision is None: |
| 61 | + raise ValueError(f"Unknown precision for recipe: {recipe_type.value}") |
107 | 62 |
|
108 | | - def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None: |
109 | | - """Validate kwargs for each recipe type""" |
110 | | - expected_keys = self._get_expected_keys(recipe_type) |
| 63 | + return self._build_recipe(recipe_type, precision, **kwargs) |
111 | 64 |
|
| 65 | + def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None: |
| 66 | + if not kwargs: |
| 67 | + return |
| 68 | + expected_keys = {"minimum_deployment_target", "compute_unit"} |
112 | 69 | unexpected = set(kwargs.keys()) - expected_keys |
113 | 70 | if unexpected: |
114 | 71 | raise ValueError( |
115 | | - f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}" |
| 72 | + f"CoreML Recipes only accept 'minimum_deployment_target' or 'compute_unit' as parameter. " |
| 73 | + f"Unexpected parameters: {list(unexpected)}" |
116 | 74 | ) |
117 | | - |
118 | | - self._validate_base_parameters(kwargs) |
119 | | - self._validate_group_size_parameter(recipe_type, kwargs) |
120 | | - self._validate_codebook_parameters(recipe_type, kwargs) |
121 | | - |
122 | | - def _get_expected_keys(self, recipe_type: RecipeType) -> set: |
123 | | - """Get expected parameter keys for a recipe type""" |
124 | | - common_keys = {"minimum_deployment_target", "compute_unit"} |
125 | | - |
126 | | - if recipe_type in [ |
127 | | - CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, |
128 | | - CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP, |
129 | | - ]: |
130 | | - return common_keys | {"group_size", "filter_fn"} |
131 | | - elif recipe_type in [ |
132 | | - CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL, |
133 | | - CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL, |
134 | | - ]: |
135 | | - return common_keys | {"filter_fn"} |
136 | | - elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY: |
137 | | - return common_keys | {"bits", "block_size", "filter_fn"} |
138 | | - else: |
139 | | - return common_keys |
140 | | - |
141 | | - def _validate_base_parameters(self, kwargs: Any) -> None: |
142 | | - """Validate minimum_deployment_target and compute_unit parameters""" |
143 | 75 | if "minimum_deployment_target" in kwargs: |
144 | 76 | minimum_deployment_target = kwargs["minimum_deployment_target"] |
145 | 77 | if not isinstance(minimum_deployment_target, ct.target): |
146 | 78 | raise ValueError( |
147 | 79 | f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(minimum_deployment_target)}" |
148 | 80 | ) |
149 | | - |
150 | 81 | if "compute_unit" in kwargs: |
151 | 82 | compute_unit = kwargs["compute_unit"] |
152 | 83 | if not isinstance(compute_unit, ct.ComputeUnit): |
153 | 84 | raise ValueError( |
154 | 85 | f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(compute_unit)}" |
155 | 86 | ) |
156 | 87 |
|
157 | | - def _validate_group_size_parameter( |
158 | | - self, recipe_type: RecipeType, kwargs: Any |
159 | | - ) -> None: |
160 | | - """Validate group_size parameter for applicable recipe types""" |
161 | | - if ( |
162 | | - recipe_type |
163 | | - in [ |
164 | | - CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, |
165 | | - CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP, |
166 | | - ] |
167 | | - and "group_size" in kwargs |
168 | | - ): |
169 | | - group_size = kwargs["group_size"] |
170 | | - if not isinstance(group_size, int): |
171 | | - raise ValueError( |
172 | | - f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}" |
173 | | - ) |
174 | | - if group_size <= 0: |
175 | | - raise ValueError( |
176 | | - f"Parameter 'group_size' must be positive, got: {group_size}" |
177 | | - ) |
178 | | - |
179 | | - def _validate_codebook_parameters( |
180 | | - self, recipe_type: RecipeType, kwargs: Any |
181 | | - ) -> None: |
182 | | - """Validate bits and block_size parameters for codebook recipe type""" |
183 | | - if recipe_type != CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY: |
184 | | - return |
185 | | - |
186 | | - # Both bits and block_size must be present |
187 | | - if not ("bits" in kwargs and "block_size" in kwargs): |
188 | | - raise ValueError( |
189 | | - "Parameters 'bits' and 'block_size' must be present for codebook recipes" |
190 | | - ) |
191 | | - |
192 | | - if "bits" in kwargs: |
193 | | - bits = kwargs["bits"] |
194 | | - if not isinstance(bits, int): |
195 | | - raise ValueError( |
196 | | - f"Parameter 'bits' must be an integer, got {type(bits).__name__}: {bits}" |
197 | | - ) |
198 | | - if not (1 <= bits <= 8): |
199 | | - raise ValueError( |
200 | | - f"Parameter 'bits' must be between 1 and 8, got: {bits}" |
201 | | - ) |
202 | | - |
203 | | - if "block_size" in kwargs: |
204 | | - block_size = kwargs["block_size"] |
205 | | - if not isinstance(block_size, list): |
206 | | - raise ValueError( |
207 | | - f"Parameter 'block_size' must be a list, got {type(block_size).__name__}: {block_size}" |
208 | | - ) |
209 | | - |
210 | | - def _validate_and_set_deployment_target( |
211 | | - self, kwargs: Any, min_target: ct.target, quantization_type: str |
212 | | - ) -> None: |
213 | | - """Validate or set minimum deployment target for quantization recipes""" |
214 | | - minimum_deployment_target = kwargs.get("minimum_deployment_target", None) |
215 | | - if minimum_deployment_target and minimum_deployment_target < min_target: |
216 | | - raise ValueError( |
217 | | - f"minimum_deployment_target must be {str(min_target)} or higher for {quantization_type} quantization" |
218 | | - ) |
219 | | - else: |
220 | | - # Default to the minimum target for this quantization type |
221 | | - kwargs["minimum_deployment_target"] = min_target |
222 | | - |
223 | | - def _build_fp_recipe( |
| 88 | + def _build_recipe( |
224 | 89 | self, |
225 | 90 | recipe_type: RecipeType, |
226 | 91 | precision: ct.precision, |
227 | 92 | **kwargs: Any, |
228 | 93 | ) -> ExportRecipe: |
229 | | - """Build FP32/FP16 recipe""" |
230 | 94 | lowering_recipe = self._get_coreml_lowering_recipe( |
231 | 95 | compute_precision=precision, |
232 | 96 | **kwargs, |
233 | 97 | ) |
234 | 98 |
|
235 | 99 | return ExportRecipe( |
236 | 100 | name=recipe_type.value, |
237 | | - lowering_recipe=lowering_recipe, |
238 | | - ) |
239 | | - |
240 | | - def _build_pt2e_quantized_recipe( |
241 | | - self, |
242 | | - recipe_type: RecipeType, |
243 | | - activation_dtype: torch.dtype, |
244 | | - **kwargs: Any, |
245 | | - ) -> ExportRecipe: |
246 | | - """Build PT2E-based quantization recipe""" |
247 | | - from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer |
248 | | - |
249 | | - self._validate_and_set_deployment_target(kwargs, ct.target.iOS17, "pt2e") |
250 | | - |
251 | | - # Validate activation_dtype |
252 | | - assert activation_dtype in [ |
253 | | - torch.quint8, |
254 | | - torch.float32, |
255 | | - ], f"activation_dtype must be torch.quint8 or torch.float32, got {activation_dtype}" |
256 | | - |
257 | | - # Create quantization config |
258 | | - config = ct.optimize.torch.quantization.LinearQuantizerConfig( |
259 | | - global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig( |
260 | | - quantization_scheme="symmetric", |
261 | | - activation_dtype=activation_dtype, |
262 | | - weight_dtype=torch.qint8, |
263 | | - weight_per_channel=True, |
264 | | - ) |
265 | | - ) |
266 | | - |
267 | | - quantizer = CoreMLQuantizer(config) |
268 | | - quantization_recipe = QuantizationRecipe(quantizers=[quantizer]) |
269 | | - |
270 | | - lowering_recipe = self._get_coreml_lowering_recipe(**kwargs) |
271 | | - |
272 | | - return ExportRecipe( |
273 | | - name=recipe_type.value, |
274 | | - quantization_recipe=quantization_recipe, |
275 | | - lowering_recipe=lowering_recipe, |
276 | | - ) |
277 | | - |
278 | | - def _build_torchao_quantized_recipe( |
279 | | - self, |
280 | | - recipe_type: RecipeType, |
281 | | - weight_dtype: torch.dtype, |
282 | | - is_per_channel: bool, |
283 | | - group_size: int = 32, |
284 | | - **kwargs: Any, |
285 | | - ) -> ExportRecipe: |
286 | | - """Build TorchAO-based quantization recipe""" |
287 | | - if is_per_channel: |
288 | | - weight_granularity = PerAxis(axis=0) |
289 | | - else: |
290 | | - weight_granularity = PerGroup(group_size=group_size) |
291 | | - |
292 | | - # Use user-provided filter_fn if provided |
293 | | - filter_fn = kwargs.get("filter_fn", None) |
294 | | - config = AOQuantizationConfig( |
295 | | - ao_base_config=IntxWeightOnlyConfig( |
296 | | - weight_dtype=weight_dtype, |
297 | | - granularity=weight_granularity, |
298 | | - ), |
299 | | - filter_fn=filter_fn, |
300 | | - ) |
301 | | - |
302 | | - quantization_recipe = QuantizationRecipe( |
303 | | - quantizers=None, |
304 | | - ao_quantization_configs=[config], |
305 | | - ) |
306 | | - |
307 | | - # override minimum_deployment_target to ios18 for torchao (GH issue #13122) |
308 | | - self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao") |
309 | | - lowering_recipe = self._get_coreml_lowering_recipe(**kwargs) |
310 | | - |
311 | | - return ExportRecipe( |
312 | | - name=recipe_type.value, |
313 | | - quantization_recipe=quantization_recipe, |
314 | | - lowering_recipe=lowering_recipe, |
315 | | - ) |
316 | | - |
317 | | - def _build_codebook_quantized_recipe( |
318 | | - self, |
319 | | - recipe_type: RecipeType, |
320 | | - bits: int, |
321 | | - block_size: list, |
322 | | - **kwargs: Any, |
323 | | - ) -> ExportRecipe: |
324 | | - """Build codebook/palettization quantization recipe""" |
325 | | - from torchao.prototype.quantization.codebook_coreml import ( |
326 | | - CodebookWeightOnlyConfig, |
327 | | - ) |
328 | | - |
329 | | - self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "codebook") |
330 | | - |
331 | | - # Get the appropriate dtype (torch.uint1 through torch.uint8) |
332 | | - dtype = getattr(torch, f"uint{bits}") |
333 | | - |
334 | | - # Use user-provided filter_fn or default to Linear/Embedding layers |
335 | | - filter_fn = kwargs.get( |
336 | | - "filter_fn", |
337 | | - lambda m, fqn: ( |
338 | | - isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear) |
339 | | - ), |
340 | | - ) |
341 | | - |
342 | | - config = AOQuantizationConfig( |
343 | | - ao_base_config=CodebookWeightOnlyConfig( |
344 | | - dtype=dtype, |
345 | | - block_size=block_size, |
346 | | - ), |
347 | | - filter_fn=filter_fn, |
348 | | - ) |
349 | | - |
350 | | - quantization_recipe = QuantizationRecipe( |
351 | | - quantizers=None, |
352 | | - ao_quantization_configs=[config], |
353 | | - ) |
354 | | - |
355 | | - lowering_recipe = self._get_coreml_lowering_recipe(**kwargs) |
356 | | - |
357 | | - return ExportRecipe( |
358 | | - name=recipe_type.value, |
359 | | - quantization_recipe=quantization_recipe, |
| 101 | + quantization_recipe=None, # TODO - add quantization recipe |
360 | 102 | lowering_recipe=lowering_recipe, |
361 | 103 | ) |
362 | 104 |
|
363 | 105 | def _get_coreml_lowering_recipe( |
364 | 106 | self, |
365 | | - compute_precision: ct.precision = ct.precision.FLOAT16, |
| 107 | + compute_precision: ct.precision, |
366 | 108 | **kwargs: Any, |
367 | 109 | ) -> LoweringRecipe: |
368 | | - """Get CoreML lowering recipe with optional precision""" |
369 | 110 | compile_specs = CoreMLBackend.generate_compile_specs( |
370 | 111 | compute_precision=compute_precision, |
371 | | - compute_unit=kwargs.get("compute_unit", ct.ComputeUnit.ALL), |
372 | | - minimum_deployment_target=kwargs.get("minimum_deployment_target", None), |
| 112 | + **kwargs, |
373 | 113 | ) |
374 | 114 |
|
375 | 115 | minimum_deployment_target = kwargs.get("minimum_deployment_target", None) |
|
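For quick reference, a minimal usage sketch of the provider after this simplification. This is an illustrative sketch, not code from the PR: the import path `executorch.backends.apple.coreml.recipes` is inferred from the file touched in this diff, and the chosen `compute_unit`/`minimum_deployment_target` values are arbitrary examples.

```python
import coremltools as ct

# Import path is an assumption based on the module this diff modifies.
from executorch.backends.apple.coreml.recipes import (
    CoreMLRecipeProvider,
    CoreMLRecipeType,
)

provider = CoreMLRecipeProvider()

# FP16 recipe; both kwargs are optional and checked by _validate_recipe_kwargs.
recipe = provider.create_recipe(
    CoreMLRecipeType.FP16,
    minimum_deployment_target=ct.target.iOS17,
    compute_unit=ct.ComputeUnit.CPU_AND_NE,
)
print(recipe.name)  # ExportRecipe is built with name=recipe_type.value

# After this change, quantization-specific kwargs are rejected up front.
try:
    provider.create_recipe(CoreMLRecipeType.FP32, group_size=32)
except ValueError as e:
    print(e)
```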