diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py
index f7ce5ef6d4..511a2d2c9f 100644
--- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py
+++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py
@@ -9,6 +9,7 @@
 import unittest
 
 import torch
+from parameterized import param, parameterized
 from torch.testing import FileCheck
 
 from torchao.dtypes import (
@@ -19,8 +20,15 @@
     SharedEmbeddingQuantizer,
 )
 from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    FromIntXQuantizationAwareTrainingConfig,
+    Int4WeightOnlyEmbeddingQATQuantizer,
+    IntXQuantizationAwareTrainingConfig,
+)
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
     MappingType,
     quantize_,
 )
@@ -184,6 +192,184 @@ def test_shared_embedding(self):
             exported_program.graph_module.code
         )
 
+    @parameterized.expand(
+        [
+            param(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                mapping_type=mapping_type,
+                model_dtype=model_dtype,
+            )
+            for weight_dtype in [getattr(torch, f"int{x}") for x in range(1, 9)]
+            for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)]
+            for mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
+            for model_dtype in [torch.float32]
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_identical_to_IntxWeightOnlyConfig(
+        self, weight_dtype, granularity, mapping_type, model_dtype
+    ):
+        """
+        Checks that EmbeddingQuantizer produces the same results as
+        IntxWeightOnlyConfig applied to nn.Embedding
+        """
+        embedding_dim = 4096
+        num_embeddings = 131
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(num_embeddings, embedding_dim)]
+        )
+        model = model.to(model_dtype)
+        indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)
+
+        quantized_model = copy.deepcopy(model)
+        quantizer = EmbeddingQuantizer(
+            weight_dtype=weight_dtype,
+            granularity=granularity,
+            mapping_type=mapping_type,
+        )
+        quantized_model = quantizer.quantize(quantized_model)
+        actual_result = quantized_model(indices)
+
+        reference_model = copy.deepcopy(model)
+        quantize_(
+            reference_model,
+            IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                mapping_type=mapping_type,
+                scale_dtype=None,
+            ),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        expected_result = reference_model(indices)
+        self.assertTrue(torch.allclose(actual_result, expected_result))
+
+    @parameterized.expand(
+        [
+            param(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                mapping_type=mapping_type,
+                scale_dtype=scale_dtype,
+                model_dtype=model_dtype,
+            )
+            for weight_dtype in [getattr(torch, f"int{x}") for x in range(1, 9)]
+            for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)]
+            for mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
+            for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
+            for model_dtype in [torch.float32, torch.bfloat16]
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+        self, weight_dtype, granularity, mapping_type, scale_dtype, model_dtype
+    ):
+        """
+        Checks that fake quantization with IntXQuantizationAwareTrainingConfig
+        produces the same results as IntxWeightOnlyConfig
+        """
+        # ASYMMETRIC in QAT is very different from the PTQ configs
+        if mapping_type == MappingType.ASYMMETRIC:
+            return
+
+        embedding_dim = 4096
+        num_embeddings = 131
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(num_embeddings, embedding_dim)]
+        )
+        model = model.to(model_dtype)
+        indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)
+
+        is_symmetric = mapping_type == MappingType.SYMMETRIC
+        group_size = (
+            granularity.group_size
+            if isinstance(granularity, PerGroup)
+            else embedding_dim
+        )
+
+        embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding)
+        weight_config = FakeQuantizeConfig(
+            weight_dtype,
+            group_size=group_size,
+            is_symmetric=is_symmetric,
+            scale_precision=scale_dtype,
+        )
+        quantize_(
+            model,
+            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
+            embedding_filter,
+        )
+        expected_out = model(indices)
+
+        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                mapping_type=mapping_type,
+                scale_dtype=scale_dtype,
+            ),
+            embedding_filter,
+        )
+        actual_out = model(indices)
+        self.assertTrue(torch.allclose(expected_out, actual_out))
+
+    @parameterized.expand(
+        [
+            param(
+                granularity=granularity,
+                scale_dtype=scale_dtype,
+                model_dtype=model_dtype,
+            )
+            for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)]
+            for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
+            for model_dtype in [torch.float32, torch.bfloat16]
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer(
+        self, granularity, scale_dtype, model_dtype
+    ):
+        """
+        Checks that a model prepared with Int4WeightOnlyEmbeddingQATQuantizer
+        produces the same results after conversion with IntxWeightOnlyConfig
+        """
+        embedding_dim = 4096
+        num_embeddings = 131
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(num_embeddings, embedding_dim)]
+        )
+        model = model.to(model_dtype)
+        indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)
+
+        group_size = (
+            granularity.group_size
+            if isinstance(granularity, PerGroup)
+            else embedding_dim
+        )
+
+        embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding)
+
+        qat_quantizer = Int4WeightOnlyEmbeddingQATQuantizer(
+            group_size=group_size,
+            scale_precision=scale_dtype,
+            zero_point_precision=torch.int32,
+        )
+        model = qat_quantizer.prepare(model)
+        expected_out = model(indices)
+
+        # Convert model method 1
+        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=granularity,
+                mapping_type=MappingType.SYMMETRIC,
+                scale_dtype=scale_dtype,
+            ),
+            embedding_filter,
+        )
+        actual_out1 = model(indices)
+        self.assertTrue(torch.allclose(expected_out, actual_out1))
+
+        # TODO: method 2 does not work because the converted embedding op
+        # incorrectly casts its output to indices.dtype
+        # Convert model method 2
+        # qat_quantizer.convert(prepared_model_copy)
+        # actual_out2 = prepared_model_copy(indices)
+        # self.assertTrue(torch.allclose(expected_out, actual_out2))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
index 3e94aa3bc0..b217aa349e 100644
--- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
+++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
@@ -14,7 +14,14 @@
 
 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    FromIntXQuantizationAwareTrainingConfig,
+    Int8DynActInt4WeightQATQuantizer,
+    IntXQuantizationAwareTrainingConfig,
+)
 from torchao.quantization.quant_api import (
+    Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
     MappingType,
     quantize_,
@@ -418,6 +425,216 @@ def test_moved_error(self):
                 granularity=PerGroup(64),
             )
 
+    @parameterized.expand(
+        [
+            param(
+                group_size=group_size,
+                mapping_type=mapping_type,
+                act_mapping_type=act_mapping_type,
+            )
+            for group_size, mapping_type, act_mapping_type in zip(
+                [32, 64],
+                [MappingType.ASYMMETRIC, MappingType.SYMMETRIC],
+                [MappingType.ASYMMETRIC, MappingType.SYMMETRIC],
+            )
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_identical_to_Int8DynamicActivationInt4WeightConfig(
+        self, group_size, mapping_type, act_mapping_type
+    ):
+        """
+        Checks that Int8DynamicActivationIntxWeightConfig with weight_dtype=torch.int4
+        is identical to Int8DynamicActivationInt4WeightConfig
+        """
+        k0 = 512
+        k1 = 256
+        layers = [
+            torch.nn.Linear(k0, k1),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(3, 1, k0)
+
+        model_copy = copy.deepcopy(model)
+
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(group_size),
+                weight_mapping_type=mapping_type,
+                weight_scale_dtype=None,
+                act_mapping_type=act_mapping_type,
+            ),
+        )
+        quantize_(
+            model_copy,
+            Int8DynamicActivationInt4WeightConfig(
+                group_size=group_size,
+                mapping_type=mapping_type,
+                act_mapping_type=act_mapping_type,
+            ),
+        )
+        with torch.no_grad():
+            self.assertTrue(
+                torch.allclose(model(activations), model_copy(activations))
+            )
+
+    @parameterized.expand(
+        [
+            param(
+                weight_dtype=weight_dtype,
+                group_size=group_size,
+                mapping_type=mapping_type,
+                act_mapping_type=act_mapping_type,
+                scale_dtype=scale_dtype,
+                model_dtype=model_dtype,
+            )
+            for weight_dtype in [getattr(torch, f"int{x}") for x in range(1, 9)]
+            for group_size in [32, 64, 128]
+            for mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
+            for act_mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]
+            for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
+            for model_dtype in [torch.float32, torch.bfloat16]
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+        self,
+        weight_dtype,
+        group_size,
+        mapping_type,
+        act_mapping_type,
+        scale_dtype,
+        model_dtype,
+    ):
+        """
+        Checks that fake quantization with IntXQuantizationAwareTrainingConfig
+        produces the same results as Int8DynamicActivationIntxWeightConfig
+        """
+        # TODO: the QAT logic for asymmetric mapping is very different from PTQ,
+        # so we don't test that case here.  Unify the two?
+        if mapping_type == MappingType.ASYMMETRIC:
+            return
+
+        # TODO: QAT logic for non-float32 models does not match PTQ right now.
+        # QAT's default scale precision is float32, but PTQ's is None (which
+        # defaults to the input's dtype).
+        if model_dtype != torch.float32:
+            return
+
+        assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
+        assert act_mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
+        is_symmetric = mapping_type == MappingType.SYMMETRIC
+        is_act_symmetric = act_mapping_type == MappingType.SYMMETRIC
+
+        k0 = 512
+        k1 = 256
+        layers = [
+            torch.nn.Linear(k0, k1),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(k0)
+
+        model = model.to(model_dtype)
+        activations = activations.to(model_dtype)
+
+        activation_config = FakeQuantizeConfig(
+            torch.int8,
+            "per_token",
+            is_symmetric=is_act_symmetric,
+        )
+        weight_config = FakeQuantizeConfig(
+            weight_dtype,
+            group_size=group_size,
+            is_symmetric=is_symmetric,
+            scale_precision=scale_dtype,
+        )
+
+        quantize_(
+            model,
+            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        )
+        try:
+            expected_out = model(activations)
+        except NotImplementedError as e:
+            # QAT does not support act_mapping_type == MappingType.SYMMETRIC yet
+            if act_mapping_type == MappingType.SYMMETRIC:
+                return
+            raise e
+
+        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=weight_dtype,
+                weight_granularity=PerGroup(group_size),
+                weight_mapping_type=mapping_type,
+                weight_scale_dtype=scale_dtype,
+                act_mapping_type=act_mapping_type,
+            ),
+        )
+        actual_out = model(activations)
+        self.assertTrue(torch.allclose(expected_out, actual_out))
+
+    @parameterized.expand(
+        [
+            param(
+                group_size=group_size,
+                scale_dtype=scale_dtype,
+                model_dtype=model_dtype,
+            )
+            for group_size in [32, 64, 128]
+            for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
+            for model_dtype in [torch.float32, torch.bfloat16]
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_identical_to_Int8DynActInt4WeightQATQuantizer(
+        self, group_size, scale_dtype, model_dtype
+    ):
+        """
+        Checks that a model prepared with Int8DynActInt4WeightQATQuantizer
+        produces the same results after conversion with
+        Int8DynamicActivationIntxWeightConfig
+        """
+        # Currently the results do not match for non-float32 scale or model dtypes
+        # TODO: investigate
+        if scale_dtype != torch.float32:
+            return
+        if model_dtype != torch.float32:
+            return
+
+        k0 = 512
+        k1 = 256
+        layers = [
+            torch.nn.Linear(k0, k1),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(k0)
+
+        model = model.to(model_dtype)
+        activations = activations.to(model_dtype)
+
+        qat_quantizer = Int8DynActInt4WeightQATQuantizer(
+            groupsize=group_size, precision=model_dtype, scales_precision=scale_dtype
+        )
+        model = qat_quantizer.prepare(model)
+        expected_out = model(activations)
+
+        prepared_model_copy = copy.deepcopy(model)
+
+        # Convert model method 1
+        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(group_size),
+                weight_mapping_type=MappingType.SYMMETRIC,
+                weight_scale_dtype=scale_dtype,
+                act_mapping_type=MappingType.ASYMMETRIC,
+            ),
+        )
+        actual_out1 = model(activations)
+        self.assertTrue(torch.allclose(expected_out, actual_out1))
+
+        # Convert model method 2
+        qat_quantizer.convert(prepared_model_copy)
+        actual_out2 = prepared_model_copy(activations)
+        self.assertTrue(torch.allclose(expected_out, actual_out2))
+
 
 if __name__ == "__main__":
     unittest.main()