diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py
index 24154ab703..bd43b06137 100644
--- a/test/prototype/test_parq.py
+++ b/test/prototype/test_parq.py
@@ -54,9 +54,16 @@ def split_param_groups(model) -> tuple[list, list, list]:
     params_quant, params_embed, params_no_quant = [], [], []
 
     def get_param_groups(model):
+        seen_data_ptrs = set()  # avoid duplicates in case of tied weights
         for module in model.children():
             is_linear = _is_linear(module)
             for n, p in module.named_parameters():
+                if n == "weight":
+                    data_ptr = p.data_ptr()
+                    if data_ptr in seen_data_ptrs:
+                        continue
+                    seen_data_ptrs.add(data_ptr)
+
                 if is_linear and n == "weight":
                     params_quant.append(p)
                 elif isinstance(module, nn.Embedding) and n == "weight":
@@ -152,7 +159,12 @@ def compare_parq_convert(
 def check_torchao_tensor_subclass(
     test_case: common_utils.TestCase, model: nn.Module, weight_only: bool = False
 ):
-    for module in model.modules():
+    for name, module in model.named_modules():
+        if not hasattr(module, "weight") or f"{name}.weight" in getattr(
+            model, "_tied_weights_keys", []
+        ):
+            continue
+
         if not weight_only and _is_linear(module):
             test_case.assertTrue(isinstance(module.weight, IntxUnpackedToInt8Tensor))
             test_case.assertTrue(
@@ -163,15 +175,40 @@ def check_torchao_tensor_subclass(
             test_case.assertTrue(module.weight.activation_quantization is None)
 
 
+def apply_activation_quantization(
+    model: nn.Module, optimizer: torch.optim.Optimizer, model_dtype: torch.dtype
+):
+    # apply torchao quantized activations on top
+    activation_config = IntxFakeQuantizeConfig(
+        torch.int8, "per_token", is_symmetric=False, scale_precision=model_dtype
+    )
+    qat_config = QATConfig(activation_config=activation_config, step="prepare")
+    for filter_fn in optimizer.get_filter_fns(model):
+        try:
+            quantize_(model, qat_config, filter_fn=filter_fn)
+        except ValueError as e:
+            if str(e) == "Activation fake quantization is not supported for embedding":
+                pass
+
+
 class M(nn.Module):
-    def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
+    _tied_weights_keys: list[str] = []
+
+    def __init__(
+        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
+    ):
         super().__init__()
-        self.embedding = nn.Embedding(10, m) if embedding else nn.Identity()
+        self.embedding = nn.Embedding(k, m) if embedding else nn.Identity()
         self.linear1 = nn.Linear(m, n, bias=bias)
         self.linear2 = nn.Linear(n, k, bias=bias)
         self.relu = nn.ReLU()
         self.sigmoid = nn.Sigmoid()
+        if embedding and tied_weights:
+            assert self.embedding.weight.shape == self.linear2.weight.shape
+            self.linear2.weight = self.embedding.weight
+            self._tied_weights_keys.append("linear2.weight")
+
 
     def reset_parameters(self):
         for module in (self.linear1, self.linear2):
             nn.init.xavier_uniform_(module.weight)
@@ -179,18 +216,17 @@ def reset_parameters(self):
                 nn.init.zeros_(module.bias)
 
     def example_inputs(self, device=None):
-        return (
-            torch.randint(1, 10, (1, self.linear1.in_features), device=device)
-            if isinstance(self.embedding, nn.Embedding)
-            else torch.randn(1, self.linear1.in_features, device=device)
-        )
+        if isinstance(self.embedding, nn.Identity):
+            inputs = torch.randn(1, self.linear1.in_features, device=device)
+        else:
+            k = self.embedding.num_embeddings
+            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
+        return inputs
 
     def forward(self, x):
         x = self.embedding(x)
-        x = self.linear1(x)
-        x = self.relu(x)
-        x = self.linear2(x)
-        x = self.sigmoid(x)
+        x = self.relu(self.linear1(x))
+        x = self.sigmoid(self.linear2(x))
         return x
@@ -297,7 +333,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
             ProxHardQuant(),
             quant_per_channel=True,
         )
-        compare_parq_convert(model, m_ref, optimizer)
+        compare_parq_convert(model, m_ref, optimizer, weight_only=True)
 
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
     @common_utils.parametrize("b", [2, 3, 4, 8])
@@ -399,6 +435,30 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         compare_parq_convert(model, m_ref, optimizer, weight_only=True)
         check_torchao_tensor_subclass(self, model, weight_only=True)
 
+    @common_utils.parametrize("b", [2, 3])
+    @common_utils.parametrize(
+        "model_dtype", [torch.float16, torch.float32, torch.bfloat16]
+    )
+    def test_intx_weight_only_tied_embed_linear(
+        self, b: int = 2, model_dtype: torch.dtype = torch.float32
+    ):
+        model = M(m=256, n=256, tied_weights=True).to(_DEVICE)
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+        base_optimizer = torch.optim.SGD(build_param_groups(model, b))
+        optimizer = QuantOptimizer(
+            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+        )
+        optimizer.zero_grad()
+        optimizer.step()
+
+        apply_activation_quantization(model, optimizer, model_dtype)
+        optimizer.torchao_convert(model)
+        check_torchao_tensor_subclass(self, model)
+        self.assertTrue(
+            torch.equal(model.embedding.weight.qdata, model.linear2.weight.qdata)
+        )
+
 
 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
@@ -435,16 +495,12 @@ def test_int8_dynamic_activation_intx_e2e(
         optimizer = QuantOptimizer(
             base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
         )
+        optimizer.zero_grad()
         optimizer.step()
 
-        # apply torchao quantized activations on top
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False, scale_precision=model_dtype
-        )
-        qat_config = QATConfig(activation_config=activation_config, step="prepare")
-        for filter_fn in optimizer.get_filter_fns(model):
-            quantize_(model, qat_config, filter_fn=filter_fn)
+        apply_activation_quantization(model, optimizer, model_dtype)
+
         out = model(x)
         torch.testing.assert_close(out, ref_out, atol=0, rtol=0)
@@ -462,7 +518,10 @@ def test_int8_dynamic_activation_intx_e2e(
         check_torchao_tensor_subclass(self, model)
 
         if attach_hf_config:
-            reg_param_names = {n for n, m in model.named_modules() if _is_linear(m)}
+            reg_param_names = {
+                n for n, m in model.named_modules() if isinstance(m, nn.Embedding)
+            }
+            reg_param_names.add("_default")
             module_fqn_to_config = (
                 model.config.quantization_config.quant_type.module_fqn_to_config
             )
diff --git a/torchao/core/config.py b/torchao/core/config.py
index b72ee9d134..330e6a42af 100644
--- a/torchao/core/config.py
+++ b/torchao/core/config.py
@@ -196,6 +196,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
     "torchao.prototype.parq",
     "torchao.dtypes",
     "torchao.prototype.awq",
+    "torchao.prototype.parq.quant",
     "torchao.quantization.quantize_.common",
     "torchao.quantization.quantize_.workflows",
 }
diff --git a/torchao/prototype/parq/optim/quantopt.py b/torchao/prototype/parq/optim/quantopt.py
index 54fcaea3ab..bfa651dcc9 100644
--- a/torchao/prototype/parq/optim/quantopt.py
+++ b/torchao/prototype/parq/optim/quantopt.py
@@ -14,6 +14,7 @@
 from torch.optim import Optimizer
 
 from torchao.quantization import quantize_
+from torchao.quantization.quant_api import _is_linear
 
 from ..quant import Quantizer, UnifTorchaoQuantizer
 from ..quant.config_torchao import (
@@ -158,17 +159,20 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
         self.restore_latent_params()
 
         # TODO(lvj): find more robust way to identify embedding layers
-        embed_data_ptrs = {
-            module.weight.data_ptr()
-            for module in model.modules()
-            if isinstance(module, nn.Embedding)
-        }
+        embed_data_ptrs = set()
+        linear_data_ptrs = set()
+        for module in model.modules():
+            if isinstance(module, nn.Embedding):
+                embed_data_ptrs.add(module.weight.data_ptr())
+            elif _is_linear(module) and module.weight.data_ptr() not in embed_data_ptrs:
+                linear_data_ptrs.add(module.weight.data_ptr())
 
         filter_fns = []
         configs = []
         attach_hf_config = _is_hf_model(model)
-        for group, filter_fn in zip(
-            self.regularized_param_groups(), self.get_filter_fns(model)
+        all_linear_layers_idx = -1
+        for i, (group, filter_fn) in enumerate(
+            zip(self.regularized_param_groups(), self.get_filter_fns(model))
         ):
             filter_fns.append(filter_fn)
             quantizer = group.get("quantizer", self.quantizer)
@@ -176,6 +180,9 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
                 configs.append(None)
                 continue
 
+            if set((p.data_ptr() for p in group["params"])) == linear_data_ptrs:
+                all_linear_layers_idx = i
+
             device = group["params"][0].device
             any_embed = any(p.data_ptr() in embed_data_ptrs for p in group["params"])
             config = _get_config_from_quantizer(
@@ -187,10 +194,21 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
             )
             configs.append(config)
 
+        filter_fns_orig = filter_fns[:]
+        configs_orig = configs[:]
+
+        # If one group has all the linear layers, then set its config as default
+        if all_linear_layers_idx > -1:
+            module_to_config = {"_default": configs[all_linear_layers_idx]}
+            del filter_fns[all_linear_layers_idx]
+            del configs[all_linear_layers_idx]
+        else:
+            module_to_config = None
+
         if attach_hf_config:
-            _attach_hf_quantization_config(model, filter_fns, configs)
+            _attach_hf_quantization_config(model, filter_fns, configs, module_to_config)
 
-        for config, filter_fn in zip(configs, filter_fns):
+        for config, filter_fn in zip(configs_orig, filter_fns_orig):
             quantize_(model, config, filter_fn=filter_fn)
 
     @torch._disable_dynamo
diff --git a/torchao/prototype/parq/quant/config_torchao.py b/torchao/prototype/parq/quant/config_torchao.py
index b546ecb328..2e2ffcba2e 100644
--- a/torchao/prototype/parq/quant/config_torchao.py
+++ b/torchao/prototype/parq/quant/config_torchao.py
@@ -8,27 +8,19 @@
 from torchao.core.config import AOBaseConfig
 from torchao.dtypes import Int4CPULayout, Layout, QDQLayout
 from torchao.quantization import MappingType, PerAxis, PerGroup
-from torchao.quantization.linear_activation_quantized_tensor import (
-    to_linear_activation_quantized,
-)
 from torchao.quantization.quant_api import (
     Granularity,
     Int4WeightOnlyConfig,
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
     ModuleFqnToConfig,
-    _int8_asymm_per_token_quant,
     _linear_extra_repr,
 )
 from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor
 from torchao.quantization.transform_module import register_quantize_module_handler
 from torchao.utils import check_cpu_version
 
-from .quant_api import (
-    choose_qparams_stretched_affine,
-    quantize_stretched_affine,
-    to_stretched_affine_quantized_intx,
-)
+from .quant_api import choose_qparams_stretched_affine, quantize_stretched_affine
 from .uniform_torchao import (
     _BIT_WIDTH_TO_DTYPE,
     Int4UnifTorchaoQuantizer,
@@ -63,6 +55,9 @@ def _int8_dynamic_activation_stretched_intx_transform(
     granularity = config.granularity
     mapping_type = MappingType.ASYMMETRIC
 
+    if config.version != 2:
+        raise NotImplementedError(f"Unsupported {config.version=}")
NotImplementedError(f"Unsupported {config.version=}") + assert weight.dim() == 2, ( f"StretchedIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}" ) @@ -79,47 +74,33 @@ def _int8_dynamic_activation_stretched_intx_transform( block_size = (1, group_size) target_dtype = torch.int8 q_args = (weight, mapping_type, block_size, target_dtype, config.b) - if config.version == 2: - scale, zero_point = choose_qparams_stretched_affine( - *q_args, - quant_min=config.quant_min, - quant_max=config.quant_max, - ) - qdata = quantize_stretched_affine( - weight, - block_size, - scale, - zero_point, - target_dtype, - quant_min=config.quant_min, - quant_max=config.quant_max, - ) - n_blocks = [qdata.shape[i] // block_size[i] for i in range(len(block_size))] - scale = scale.reshape(*n_blocks) - zero_point = zero_point.reshape(*n_blocks) - - weight = IntxUnpackedToInt8Tensor( - qdata=qdata, - scale=scale, - zero_point=zero_point, - target_dtype=getattr(torch, f"int{config.b}"), - block_size=block_size, - dtype=weight.dtype, - activation_quantization=config.activation_quantization, - ) - else: - weight = to_stretched_affine_quantized_intx( - *q_args, - quant_min=config.quant_min, - quant_max=config.quant_max, - scale_dtype=config.scale_dtype, - _layout=config.layout, - ) - if config.activation_quantization == "int8_asym_per_token": - weight = to_linear_activation_quantized(weight, _int8_asymm_per_token_quant) - elif config.activation_quantization is not None: - raise ValueError(f"Unsupported {config.activation_quantization=}") - + scale, zero_point = choose_qparams_stretched_affine( + *q_args, + quant_min=config.quant_min, + quant_max=config.quant_max, + ) + qdata = quantize_stretched_affine( + weight, + block_size, + scale, + zero_point, + target_dtype, + quant_min=config.quant_min, + quant_max=config.quant_max, + ) + n_blocks = [qdata.shape[i] // block_size[i] for i in range(len(block_size))] + scale = scale.reshape(*n_blocks) + zero_point = zero_point.reshape(*n_blocks) + + weight = IntxUnpackedToInt8Tensor( + qdata=qdata, + scale=scale, + zero_point=zero_point, + target_dtype=getattr(torch, f"int{config.b}"), + block_size=block_size, + dtype=weight.dtype, + activation_quantization=config.activation_quantization, + ) module.weight = nn.Parameter(weight, requires_grad=False) if isinstance(module, nn.Linear): @@ -184,6 +165,7 @@ def _attach_hf_quantization_config( model: nn.Module, filter_fns: list[Callable[nn.Module, bool]], configs: list[AOBaseConfig], + module_to_config: Optional[dict[str, AOBaseConfig]] = None, ) -> None: """Attaches torchao quantization config(s) to Hugging Face model. 
@@ -202,11 +184,21 @@ def _attach_hf_quantization_config(
         "filter_fns and configs must have the same length"
     )
 
-    module_to_config = {}
+    if module_to_config is None:
+        module_to_config = {}
+
+    seen_data_ptrs = set()
+    modules_to_not_convert = []
     for name, module in model.named_modules():
         if not hasattr(module, "weight"):
             continue
 
+        data_ptr = module.weight.data_ptr()
+        if data_ptr in seen_data_ptrs:  # do not re-quantize tied weight
+            modules_to_not_convert.append(name)
+            continue
+        seen_data_ptrs.add(data_ptr)
+
         for i, filter_fn in enumerate(filter_fns):
             if filter_fn(module):
                 module_to_config[name] = configs[i]
@@ -214,5 +206,5 @@ def _attach_hf_quantization_config(
     model.config.quantization_config = TorchAoConfig(
         quant_type=ModuleFqnToConfig(module_to_config),
         include_input_output_embeddings=True,
-        modules_to_not_convert=[],
+        modules_to_not_convert=modules_to_not_convert,
     )
diff --git a/torchao/prototype/parq/quant/quant_api.py b/torchao/prototype/parq/quant/quant_api.py
index 7931faa37c..608fd9570e 100644
--- a/torchao/prototype/parq/quant/quant_api.py
+++ b/torchao/prototype/parq/quant/quant_api.py
@@ -8,11 +8,8 @@
 import torch
 
-from torchao.dtypes import AffineQuantizedTensor, Layout, QDQLayout
 from torchao.quantization import (
     MappingType,
-    ZeroPointDomain,
-    dequantize_affine,
 )
 from torchao.quantization.quant_primitives import (
     _SUB_BYTE_UINT_BOUNDS,
@@ -96,80 +93,3 @@ def quantize_stretched_affine(
     quant = torch.round(input_float / scale + zero_point)
     quant = quant.to(dtype=target_dtype).view(original_shape)
     return quant
-
-
-class StretchedAffineQuantizedTensor(AffineQuantizedTensor):
-    @classmethod
-    def from_hp_to_intx(
-        cls,
-        input_float: torch.Tensor,
-        mapping_type: MappingType,
-        block_size: Tuple[int, ...],
-        target_dtype: torch.dtype,
-        b: int,
-        quant_min: Optional[float] = None,
-        quant_max: Optional[float] = None,
-        scale_dtype: Optional[torch.dtype] = None,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
-        _layout: Layout = QDQLayout(),  # noqa: B008
-    ):
-        original_shape = input_float.shape
-        input_float = _layout.pre_process(input_float)
-
-        scale, zero_point = choose_qparams_stretched_affine(
-            input_float,
-            mapping_type,
-            block_size,
-            target_dtype,
-            b,
-            quant_min=quant_min,
-            quant_max=quant_max,
-        )
-        data = quantize_stretched_affine(
-            input_float,
-            block_size,
-            scale,
-            zero_point,
-            target_dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-        )
-        data, scale, zero_point = _layout.post_process(
-            data, scale, zero_point, block_size
-        )
-        tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
-        tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
-        return cls(
-            tensor_impl,
-            block_size,
-            original_shape,
-            quant_min,
-            quant_max,
-            zero_point_domain,
-            dtype=input_float.dtype,
-        )
-
-    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
-        if output_dtype is None:
-            output_dtype = self.dtype
-
-        if not isinstance(self._layout, QDQLayout):
-            raise NotImplementedError(
-                f"StretchedAffineQuantizedTensor only supports QDQLayout but got {self._layout}"
-            )
-
-        data, scale, zero_point = self.tensor_impl.get_plain()
-        dq = dequantize_affine(
-            data,
-            self.block_size,
-            scale,
-            zero_point,
-            data.dtype,
-            self.quant_min,
-            self.quant_max,
-            output_dtype=output_dtype,
-        )
-        return dq
-
-
-to_stretched_affine_quantized_intx = StretchedAffineQuantizedTensor.from_hp_to_intx
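Not part of the patch above: a minimal, self-contained sketch of the data_ptr()-based tied-weight deduplication that the test helpers, torchao_convert, and _attach_hf_quantization_config now rely on, using only stock PyTorch. The collect_unique_weights helper below is hypothetical and for illustration only; the patch implements the same idea inline. Two modules that share one Parameter report the same data_ptr(), so scanning named_modules() and skipping already-seen pointers visits each underlying weight exactly once.

    from torch import nn


    def collect_unique_weights(model: nn.Module) -> dict[str, nn.Parameter]:
        """Return one entry per underlying weight tensor, skipping tied copies."""
        seen_data_ptrs = set()
        unique = {}
        for name, module in model.named_modules():
            weight = getattr(module, "weight", None)
            if not isinstance(weight, nn.Parameter):
                continue
            data_ptr = weight.data_ptr()
            if data_ptr in seen_data_ptrs:  # tied to a weight we already collected
                continue
            seen_data_ptrs.add(data_ptr)
            unique[f"{name}.weight"] = weight
        return unique


    # Tie an embedding to the output projection, as M(tied_weights=True) does.
    embed = nn.Embedding(16, 256)
    proj = nn.Linear(256, 16, bias=False)
    proj.weight = embed.weight  # shared Parameter -> same data_ptr()
    model = nn.Sequential(embed, nn.Linear(256, 256), proj)

    assert len(collect_unique_weights(model)) == 2  # tied pair counted once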