|
2 | 2 | from typing import Dict, List, Optional, Tuple, Union |
3 | 3 |
|
4 | 4 | import torch |
5 | | -from compressed_tensors.quantization import disable_quantization |
| 5 | +from compressed_tensors.quantization import ( |
| 6 | + disable_quantization, |
| 7 | + find_name_or_class_matches, |
| 8 | +) |
6 | 9 | from compressed_tensors.utils import ( |
7 | 10 | align_module_device, |
8 | 11 | get_execution_device, |
@@ -308,9 +311,8 @@ def _set_resolved_mappings(self, model: Module) -> None: |
308 | 311 | smooth_names = [ |
309 | 312 | smooth_name |
310 | 313 | for smooth_name in smooth_layers |
311 | | - if ( |
312 | | - smooth_name not in self.ignore |
313 | | - and not smooth_name.endswith("_observer") |
| 314 | + if not find_name_or_class_matches( |
| 315 | + smooth_name, model, self.ignore + ["re:.*_observer$"] |
314 | 316 | ) |
315 | 317 | ] |
316 | 318 |
|
@@ -340,15 +342,15 @@ def _set_resolved_mappings(self, model: Module) -> None: |
340 | 342 | if ( |
341 | 343 | isinstance(smooth_layer, torch.nn.Linear) |
342 | 344 | and isinstance(balance_layer, torch.nn.Linear) |
343 | | - and ".o_proj" in balance_name |
| 345 | + and balance_name.endswith(".o_proj") |
344 | 346 | and ( |
345 | 347 | ( |
346 | | - ".v_proj" in smooth_name |
| 348 | + smooth_name.endswith(".v_proj") |
347 | 349 | and smooth_layer.out_features |
348 | 350 | != balance_layer.in_features |
349 | 351 | ) |
350 | 352 | or ( |
351 | | - ".qkv_proj" in smooth_name |
| 353 | + smooth_name.endswith(".qkv_proj") |
352 | 354 | and smooth_layer.out_features |
353 | 355 | != 3 * balance_layer.in_features |
354 | 356 | ) |
@@ -475,7 +477,7 @@ def _apply_smoothing(self, model: Module) -> None: |
475 | 477 | # [STEP 3]: Compute output of module |
476 | 478 | # could cache from hook, rather than recomputing here |
477 | 479 | fp16_output = self._run_samples(parent_module) |
478 | | - if fp16_output.shape[0] == 0: |
| 480 | + if fp16_output.numel() == 0: |
479 | 481 | logger.info( |
480 | 482 | f"Skipping smooth_layer {mapping.smooth_name}, no activations " |
481 | 483 | "found to scale. This can occasionally occur in MoE models " |
@@ -549,6 +551,7 @@ def _run_samples(self, module: Module) -> torch.Tensor: |
549 | 551 | ] |
550 | 552 | return torch.cat( |
551 | 553 | [ |
| 554 | + # If Tuple, assume that the first element is the output tensor |
552 | 555 | output[0] if isinstance(output, Tuple) else output |
553 | 556 | for output in outputs |
554 | 557 | ], |
@@ -751,9 +754,14 @@ def _accumulate_mean( |
751 | 754 |
|
752 | 755 | def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]: |
753 | 756 | """ |
754 | | - Given a list of names, returns the lowest-scope common parent, |
755 | | - excluding parents of type ModuleList, which don't seem to play |
756 | | - nicely with hooks. |
| 757 | + Given a list of names, returns the lowest-scope common parent. |
| 758 | +
|
| 759 | + NOTE: function excludes parents of type ModuleList, which don't play |
| 760 | + nicely with hooks because their forward method is never directly |
| 761 | + called for MoE models. See Qwen3MoeSparseMoeBlock for an example: |
| 762 | + experts are selected based on the router output, and each expert's |
| 763 | + forward method is then called directly. |
| 764 | + https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233 |
| 764 | +
|
757 | 765 | Returns name of parent and pointer to parent module |
758 | 766 |
|
759 | 767 | Implementation is a small alteration of os.path.commonprefix |
|
0 commit comments