diff --git a/test/prototype/module_swap_quantization/test_kmeans_codebook.py b/test/prototype/module_swap_quantization/test_kmeans_codebook.py new file mode 100644 index 0000000000..c0f71fc788 --- /dev/null +++ b/test/prototype/module_swap_quantization/test_kmeans_codebook.py @@ -0,0 +1,56 @@ +import copy +import unittest +from typing import Union + +import torch +import torch.nn as nn + +from torchao.prototype.quantization.module_swap import ( + CodeBookQuantizer, + QuantizedLinear, +) +from torchao.prototype.quantization.module_swap.algorithms import kmeans_codebook + + +class SimpleTestNetwork(nn.Module): + def __init__(self, weight_group_size: Union[int, str] = "per_channel") -> None: + super().__init__() + if weight_group_size == "per_channel": + weight_group_size = 8 + assert isinstance(weight_group_size, int) + weight_quantizer = CodeBookQuantizer( + n_bits=2, + features=16, + codebook_dim=2, + ) + + self.linear = QuantizedLinear( + in_features=16, + out_features=8, + bias=False, + weight_quantizer=weight_quantizer, + activation_bits=8, + input_quantization=False, + output_quantization=False, + weight_quantization=True, + activation_quantization=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +class TestKmeansCodebook(unittest.TestCase): + @unittest.skip("No module named 'faiss'") + def test_kmeans_codebook(self) -> None: + model = SimpleTestNetwork() + codebook_before = copy.deepcopy(model.linear.weight_quantizer.codebook) + kmeans_codebook(model) + assert not torch.allclose( + codebook_before, + model.linear.weight_quantizer.codebook, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/prototype/module_swap_quantization/test_llm_ptq_data_getter.py b/test/prototype/module_swap_quantization/test_llm_ptq_data_getter.py new file mode 100644 index 0000000000..89bc1d5775 --- /dev/null +++ b/test/prototype/module_swap_quantization/test_llm_ptq_data_getter.py @@ -0,0 +1,35 @@ +import unittest +from typing import Tuple + +import torch +from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM + +from torchao.prototype.quantization.module_swap.data_getters import LLMPTQDataGetter + +test_config = LlamaConfig( + vocab_size=10, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=2, + intermediate_size=64, +) + + +def get_test_llama_model_data() -> Tuple[LlamaForCausalLM, torch.Tensor]: + model = LlamaForCausalLM(test_config) + input_ids = torch.randint(0, test_config.vocab_size, (1, 10)) + return model, input_ids + + +class TestPTQDataGetter(unittest.TestCase): + @unittest.skip("TypeError: cannot unpack non-iterable NoneType object") + def test_data_getter(self) -> None: + model, data = get_test_llama_model_data() + data_getter = LLMPTQDataGetter(model, data, 1) + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + data = data_getter.pop(model, name) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/prototype/module_swap_quantization/test_module_swap.py b/test/prototype/module_swap_quantization/test_module_swap.py new file mode 100644 index 0000000000..24612d55d2 --- /dev/null +++ b/test/prototype/module_swap_quantization/test_module_swap.py @@ -0,0 +1,35 @@ +import unittest + +import torch +import torch.nn as nn + +from torchao.prototype.quantization.module_swap import ( + QuantizationRecipe, + quantize_module_swap, +) + + +class SimpleEmbeddingTestNetwork(nn.Module): + def __init__(self) -> None: + super().__init__() + self.embedding = nn.Embedding(10, 64) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.embedding(x) + + +class TestEmbeddingSwap(unittest.TestCase): + def test_embedding_swap(self) -> None: + model = SimpleEmbeddingTestNetwork() + recipe = QuantizationRecipe() + recipe.embedding_bits = 4 + recipe.embedding_quantization = True + model = quantize_module_swap(model, recipe) + x = torch.randint(0, 10, (10, 64)) + model(x) + assert model.embedding.weight_quantizer.num_bits == 4 + assert model.embedding.weight_quantizer.group_size == 32 + + +if __name__ == "__main__": + unittest.main() diff --git a/test/prototype/module_swap_quantization/test_module_swap_quantization_utils.py b/test/prototype/module_swap_quantization/test_module_swap_quantization_utils.py new file mode 100644 index 0000000000..7c8efbbc8e --- /dev/null +++ b/test/prototype/module_swap_quantization/test_module_swap_quantization_utils.py @@ -0,0 +1,65 @@ +import unittest + +import torch +from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM + +from torchao.prototype.quantization.module_swap import QuantizedLinear +from torchao.prototype.quantization.module_swap.module_swap import ( + QuantizationRecipe, + replace_all_linear_with_quantized_linear, +) +from torchao.prototype.quantization.module_swap.utils import set_bit_widths_by_name + +test_config = LlamaConfig( + vocab_size=10, + hidden_size=32, + num_hidden_layers=1, + num_attention_heads=2, + intermediate_size=64, +) + +base_recipe = QuantizationRecipe( + weight_bits=4, + weight_group_size=32, + weight_quantization=True, + dynamic_weights=False, + activation_bits=8, + activation_group_size="per_token", + activation_quantization=True, + input_quantization=True, + output_quantization=True, + dynamic_activations=True, + range_learning=False, + exclude_layers=["lm_head"], +) + + +def get_test_llama_model_data() -> tuple[LlamaForCausalLM, torch.Tensor]: + model = LlamaForCausalLM(test_config) + input_ids = torch.randint(0, test_config.vocab_size, (1, 10)) + return model, input_ids + + +class TestQuantizedModuleUtils(unittest.TestCase): + def test_set_bit_widths_by_name(self) -> None: + model, _ = get_test_llama_model_data() + replace_all_linear_with_quantized_linear(model, base_recipe) + + bit_width_dict = {} + for name, module in model.named_modules(): + if isinstance(module, QuantizedLinear): + bit_width_dict[name] = {"weight": 7, "activation": 9} + + set_bit_widths_by_name(model, bit_width_dict) + + for _, module in model.named_modules(): + if isinstance(module, QuantizedLinear): + assert module.weight_quantizer.num_bits == 7 + assert module.input_quantizer is not None + assert module.input_quantizer.num_bits == 9 + assert module.output_quantizer is not None + assert module.output_quantizer.num_bits == 9 + + +if __name__ == "__main__": + unittest.main() diff --git a/test/prototype/module_swap_quantization/test_quantized_modules.py b/test/prototype/module_swap_quantization/test_quantized_modules.py new file mode 100644 index 0000000000..62a1c82303 --- /dev/null +++ b/test/prototype/module_swap_quantization/test_quantized_modules.py @@ -0,0 +1,369 @@ +import unittest +from itertools import product + +import torch + +from torchao.prototype.quantization.module_swap import ( + IntQuantizer, + QuantizedEmbedding, + QuantizedLinear, +) + + +class TestQuantizedLinear(unittest.TestCase): + def test_quantized_linear_init(self) -> None: + in_features = 16 + out_features = 8 + weight_bits = 8 + weight_group_size = "per_channel" + activation_bits = 8 + activation_group_size = "per_token" + input_quantization = True + output_quantization = False + weight_quantization = True + activation_quantization = True + dynamic_weights = False + range_learning = False + scale_eps = 1e-6 + weight_quantizer = IntQuantizer( + num_bits=weight_bits, + group_size=weight_group_size, + dynamic=dynamic_weights, + quantization_mode="symmetric", + range_learning=range_learning, + ) + QuantizedLinear( + in_features=in_features, + out_features=out_features, + weight_quantizer=weight_quantizer, + activation_bits=activation_bits, + activation_group_size=activation_group_size, + input_quantization=input_quantization, + output_quantization=output_quantization, + weight_quantization=weight_quantization, + activation_quantization=activation_quantization, + scale_eps=scale_eps, + ) + + def test_quantized_linear(self) -> None: + for ( + weight_group_size, + activation_group_size, + input_quantization, + output_quantization, + weight_quantization, + dynamic_weights, + ) in product( + ["per_channel", "per_tensor", 4], + ["per_token", "per_tensor", 4], + [True, False], + [True, False], + [True, False], + [True, False], + ): + for x in [ + torch.FloatTensor(torch.randn(2, 16)), + torch.FloatTensor(torch.randn(2, 2, 16)), + ]: + weight_quantizer = IntQuantizer( + num_bits=4, + group_size=weight_group_size, + dynamic=dynamic_weights, + quantization_mode="symmetric", + range_learning=False, + ) + linear = QuantizedLinear( + in_features=16, + out_features=8, + weight_quantizer=weight_quantizer, + activation_bits=8, + activation_group_size=activation_group_size, + input_quantization=input_quantization, + output_quantization=output_quantization, + weight_quantization=weight_quantization, + activation_quantization=True, + ) + if not dynamic_weights: + assert isinstance(linear.weight_quantizer, IntQuantizer) + linear.weight_quantizer.set_scale_offset_to_min_max(linear.weight) + linear(x) + + def test_quantized_linear_passes_gradients(self) -> None: + for ( + weight_group_size, + activation_group_size, + input_quantization, + output_quantization, + weight_quantization, + dynamic_weights, + ) in product( + ["per_channel", "per_tensor", 4], + ["per_token", "per_tensor", 4], + [True, False], + [True, False], + [True, False], + [True, False], + ): + for x in [ + torch.FloatTensor(torch.randn(2, 16)), + torch.FloatTensor(torch.randn(2, 2, 16)), + ]: + x = x.requires_grad_(True) + weight_quantizer = IntQuantizer( + num_bits=4, + group_size=weight_group_size, + dynamic=dynamic_weights, + quantization_mode="symmetric", + range_learning=False, + ) + linear = QuantizedLinear( + in_features=16, + out_features=8, + weight_quantizer=weight_quantizer, + activation_bits=8, + activation_group_size=activation_group_size, + input_quantization=input_quantization, + output_quantization=output_quantization, + weight_quantization=weight_quantization, + activation_quantization=True, + ) + if not dynamic_weights: + assert isinstance(linear.weight_quantizer, IntQuantizer) + linear.weight_quantizer.set_scale_offset_to_min_max(linear.weight) + y = linear(x) + (y.sum() ** 2).backward() + assert linear.weight.grad is not None + assert x.grad is not None + + def test_quantized_linear_passes_gradients_to_weight_scale(self) -> None: + in_features = 16 + out_features = 8 + weight_bits = 8 + weight_group_size = "per_channel" + activation_bits = 8 + activation_group_size = "per_token" + input_quantization = True + output_quantization = False + weight_quantization = True + activation_quantization = True + scale_eps = 1e-6 + weight_quantizer = IntQuantizer( + num_bits=weight_bits, + group_size=weight_group_size, + dynamic=False, + quantization_mode="symmetric", + range_learning=True, + ) + linear = QuantizedLinear( + in_features=in_features, + out_features=out_features, + weight_quantizer=weight_quantizer, + activation_bits=activation_bits, + activation_group_size=activation_group_size, + input_quantization=input_quantization, + output_quantization=output_quantization, + weight_quantization=weight_quantization, + activation_quantization=activation_quantization, + scale_eps=scale_eps, + range_learning=True, + ) + x = torch.FloatTensor(torch.randn(2, 6, 16)) + y = linear(x) + loss = y.sum() + loss.backward() + assert linear.weight_quantizer.scale is not None + scale = linear.weight_quantizer.scale + assert isinstance(scale, torch.Tensor) + assert scale.grad is not None + + def test_quantized_linear_passes_gradients_to_activation_scale(self) -> None: + in_features = 16 + out_features = 8 + weight_bits = 8 + weight_group_size = "per_channel" + activation_bits = 8 + activation_group_size = "per_tensor" + input_quantization = True + output_quantization = False + weight_quantization = False + activation_quantization = True + scale_eps = 1e-6 + weight_quantizer = IntQuantizer( + num_bits=weight_bits, + group_size=weight_group_size, + dynamic=False, + quantization_mode="symmetric", + range_learning=True, + ) + linear = QuantizedLinear( + in_features=in_features, + out_features=out_features, + weight_quantizer=weight_quantizer, + activation_bits=activation_bits, + activation_group_size=activation_group_size, + input_quantization=input_quantization, + output_quantization=output_quantization, + weight_quantization=weight_quantization, + activation_quantization=activation_quantization, + scale_eps=scale_eps, + dynamic_activations=False, + range_learning=True, + ) + x = torch.FloatTensor(torch.randn(2, 8, 16)) + assert linear.input_quantizer is not None + linear.input_quantizer.set_scale_offset_to_min_max(x) + y = linear(x) + loss = y.sum() + loss.backward() + assert linear.input_quantizer is not None + assert linear.input_quantizer._range_learning is True + + assert linear.input_quantizer is not None + assert linear.input_quantizer.scale is not None + assert linear.input_quantizer.offset is not None + assert linear.input_quantizer.scale.requires_grad is True + + assert ( + linear.input_quantizer.scale.grad is not None + ), linear.input_quantizer.scale + assert linear.input_quantizer.offset.grad is not None + + def test_set_weight_scale_to_min_max_test_all_options(self) -> None: + for ( + x, + weight_group_size, + ) in product( + [ + torch.FloatTensor(torch.randn(2, 16)), + torch.FloatTensor(torch.randn(2, 2, 16)), + ], + ["per_channel", "per_tensor", 4], + ): + x = x.requires_grad_(True) + weight_quantizer = IntQuantizer( + num_bits=4, + group_size=weight_group_size, + dynamic=False, + quantization_mode="symmetric", + range_learning=True, + ) + linear = QuantizedLinear( + in_features=16, + out_features=8, + weight_quantizer=weight_quantizer, + activation_bits=8, + activation_group_size="per_token", + input_quantization=False, + output_quantization=False, + weight_quantization=True, + activation_quantization=False, + ) + assert isinstance(linear.weight_quantizer, IntQuantizer) + linear.weight_quantizer.set_scale_offset_to_min_max(linear.weight) + + def test_set_weight_scale_to_min_max_test_correct(self) -> None: + weight_group_size = "per_channel" + + weight_quantizer = IntQuantizer( + num_bits=4, + group_size=weight_group_size, + dynamic=False, + quantization_mode="symmetric", + range_learning=True, + ) + linear = QuantizedLinear( + in_features=16, + out_features=1, + weight_quantizer=weight_quantizer, + activation_bits=8, + activation_group_size="per_token", + input_quantization=False, + output_quantization=False, + weight_quantization=True, + activation_quantization=False, + ) + + linear.weight.data = torch.ones_like(linear.weight.data) * 7 + linear.weight.data[0][0] = -8 + + assert isinstance(linear.weight_quantizer, IntQuantizer) + linear.weight_quantizer.set_scale_offset_to_min_max(linear.weight) + assert linear.weight_quantizer.scale is not None + scale = linear.weight_quantizer.scale + assert isinstance(scale, torch.Tensor) + assert torch.allclose(scale, torch.FloatTensor([1.0])) + + def test_quantize_dynamic(self) -> None: + weight_quantizer = IntQuantizer( + num_bits=4, + group_size="per_channel", + dynamic=False, + quantization_mode="symmetric", + range_learning=True, + ) + linear = QuantizedLinear( + in_features=32, + out_features=64, + weight_quantizer=weight_quantizer, + activation_bits=8, + activation_group_size="per_token", + input_quantization=True, + output_quantization=False, + weight_quantization=True, + activation_quantization=True, + ) + + x = torch.FloatTensor([0, 5, 10, 15]) + assert linear.input_quantizer is not None + linear.input_quantizer(x) + + torch.testing.assert_close(x, torch.FloatTensor([0, 5, 10, 15])) + + def test_quantize_dynamic_vectorized(self) -> None: + weight_quantizer = IntQuantizer( + num_bits=4, + group_size="per_channel", + dynamic=False, + quantization_mode="symmetric", + range_learning=True, + ) + linear = QuantizedLinear( # noqa + in_features=32, + out_features=64, + weight_quantizer=weight_quantizer, + activation_bits=8, + activation_group_size="per_token", + input_quantization=True, + output_quantization=False, + weight_quantization=True, + activation_quantization=True, + ) + + x = torch.FloatTensor([0, 5, 10, 15, 0, 10, 20, 30]).reshape([2, 4]) + assert linear.input_quantizer is not None + linear.input_quantizer(x) + + torch.testing.assert_close( + x, torch.FloatTensor([0, 5, 10, 15, 0, 10, 20, 30]).reshape([2, 4]) + ) + + +class TestQuantizedEmbedding(unittest.TestCase): + def test_quantized_embedding(self) -> None: + for weight_group_size in ["per_channel", "per_tensor", 4]: + linear = QuantizedEmbedding( + num_embeddings=16, + embedding_dim=12, + num_bits=4, + group_size=weight_group_size, + quantization_mode="symmetric", + ) + x = torch.Tensor(torch.zeros(2, 16).to(torch.int32)) + x[0][0] = 1 + x[1][0] = 1 + linear.weight_quantizer.set_scale_offset_to_min_max(linear.weight) + linear(x) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/prototype/module_swap_quantization/test_quantizers.py b/test/prototype/module_swap_quantization/test_quantizers.py new file mode 100644 index 0000000000..883e95fd9a --- /dev/null +++ b/test/prototype/module_swap_quantization/test_quantizers.py @@ -0,0 +1,163 @@ +import unittest + +import torch + +from torchao.prototype.quantization.module_swap import IntQuantizer + + +class TestIntQuantizer(unittest.TestCase): + def test_get_scale_param_size(self) -> None: + x = torch.FloatTensor([0, 5, 10, 15]) + group_size = 4 + scale_param_size = IntQuantizer.get_scale_param_size(x, group_size) + assert scale_param_size == torch.Size([1]) + + x = torch.FloatTensor([0, 5, 10, 15]) + group_size = 2 + scale_param_size = IntQuantizer.get_scale_param_size(x, group_size) + assert scale_param_size == torch.Size([2]) + + x = torch.FloatTensor([0, 5, 10, 15, 20, 25, 30, 35]).reshape(2, 4) + group_size = 4 + scale_param_size = IntQuantizer.get_scale_param_size(x, group_size) + assert scale_param_size == torch.Size([2, 1]) + + x = torch.FloatTensor([0, 5, 10, 15, 20, 25, 30, 35]).reshape(2, 4) + group_size = 2 + scale_param_size = IntQuantizer.get_scale_param_size(x, group_size) + assert scale_param_size == torch.Size([2, 2]) + + def test_get_qmin_qmax(self) -> None: + qmin, qmax = IntQuantizer.get_qmin_qmax(4, signed=False) + assert qmin == 0 + assert qmax == 15 + + qmin, qmax = IntQuantizer.get_qmin_qmax(4, signed=True) + assert qmin == -8 + assert qmax == 7 + + def test_get_scale_offset_asymmetric(self) -> None: + x = torch.FloatTensor([0, 5, 10, 15]) + group_size = 4 + quantization_mode = "asymmetric" + q_min = 0 + q_max = 15 + scale, offset = IntQuantizer.get_scale_offset( + x, group_size, quantization_mode, q_min, q_max + ) + assert scale == 1 + assert offset == 0 + + def test_get_scale_offset_symmetric(self) -> None: + x = torch.FloatTensor([-8, -5, -3, -1, 0, 1, 3, 5, 7]) + group_size = 9 + quantization_mode = "symmetric" + q_min = -8 + q_max = 7 + scale, offset = IntQuantizer.get_scale_offset( + x, group_size, quantization_mode, q_min, q_max + ) + assert scale == 1 + assert offset is None + + def test_quantize_forward(self) -> None: + x = torch.FloatTensor([0, 5, 10, 15]) + scale = torch.FloatTensor([1]) + offset = torch.FloatTensor([0]) + group_size = 4 + q_min = 0 + q_max = 15 + output = IntQuantizer.quantize_forward( + x, scale, offset, q_min, q_max, group_size + ) + + torch.testing.assert_close(output, torch.FloatTensor([0, 5, 10, 15])) + + def test_quantize_forward_asymmetric_clipping(self) -> None: + x = torch.FloatTensor([0, 5, 10, 100]) + scale = torch.FloatTensor([1]) + offset = torch.FloatTensor([0]) + group_size = 4 + q_min = 0 + q_max = 15 + output = IntQuantizer.quantize_forward( + x, scale, offset, q_min, q_max, group_size + ) + + torch.testing.assert_close(output, torch.FloatTensor([0, 5, 10, 15])) + + def test_quantize_forward_symmetric(self) -> None: + x = torch.FloatTensor([0, 1, 2, 3]) + scale = torch.FloatTensor([1.0]) + group_size = 4 + q_min = -8 + q_max = 7 + output = IntQuantizer.quantize_forward( + x, scale, offset=None, group_size=group_size, q_min=q_min, q_max=q_max + ) + + torch.testing.assert_close(output, torch.FloatTensor([0, 1, 2, 3])) + + def test_quantize_forward_symmetric_clipping(self) -> None: + x = torch.FloatTensor([0, 1, 2, 10]) + scale = torch.FloatTensor([1.0]) + group_size = 4 + q_min = -8 + q_max = 7 + output = IntQuantizer.quantize_forward( + x, scale, offset=None, group_size=group_size, q_min=q_min, q_max=q_max + ) + + torch.testing.assert_close(output, torch.FloatTensor([0, 1, 2, 7])) + + def test_get_scale_offset_from_min_max(self) -> None: + x_min = torch.FloatTensor([-8]) + x_max = torch.FloatTensor([7]) + q_min = 0 + q_max = 15 + scale, offset = IntQuantizer.get_scale_offset_from_min_max( + x_min, x_max, q_min, q_max + ) + assert scale == 1 + assert offset == 8 + + def test_get_scale_offset_from_min_max_tensorized(self) -> None: + x_min = torch.FloatTensor([-8, 0]) + x_max = torch.FloatTensor([7, 15]) + q_min = 0 + q_max = 15 + scale, offset = IntQuantizer.get_scale_offset_from_min_max( + x_min, x_max, q_min, q_max + ) + assert torch.allclose(scale, torch.FloatTensor([1, 1])) + assert torch.allclose(offset, torch.FloatTensor([8, 0])) + + def test_get_scale_from_min_max(self) -> None: + x_min = torch.FloatTensor([-8]) + x_max = torch.FloatTensor([7]) + q_min = -8 + q_max = 7 + scale = IntQuantizer.get_scale_from_min_max(x_min, x_max, q_min, q_max) + + assert scale == 1 + + def test_get_scale_from_min_max_vectorized(self) -> None: + x_min = torch.FloatTensor([-8, -16]) + x_max = torch.FloatTensor([7, 14]) + q_min = -8 + q_max = 7 + scale = IntQuantizer.get_scale_from_min_max(x_min, x_max, q_min, q_max) + + assert torch.allclose(scale, torch.FloatTensor([1, 2])) + + +class TestCodebookQuantizer(unittest.TestCase): + def test_codebook_quantizer(self) -> None: + pass + + def test_vector_quantizer(self) -> None: + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/test/prototype/module_swap_quantization/test_range_setting_methods.py b/test/prototype/module_swap_quantization/test_range_setting_methods.py new file mode 100644 index 0000000000..bbb497d8a9 --- /dev/null +++ b/test/prototype/module_swap_quantization/test_range_setting_methods.py @@ -0,0 +1,158 @@ +import copy +import unittest +from typing import Union + +import torch +from torch import nn + +from torchao.prototype.quantization.module_swap import ( + IntQuantizer, + QuantizedLinear, +) +from torchao.prototype.quantization.module_swap.range_setting_methods import ( + quantize_per_group_scales, + set_activation_min_max, + set_weight_min_max, + set_weight_mse, + set_weight_range_activation_loss, +) + + +class SimpleTestNetwork(nn.Module): + def __init__(self, weight_group_size: Union[int, str] = "per_channel") -> None: + super().__init__() + weight_quantizer = IntQuantizer( + num_bits=4, + group_size=weight_group_size, + dynamic=False, + quantization_mode="symmetric", + range_learning=False, + ) + self.linear = QuantizedLinear( + in_features=16, + out_features=8, + weight_quantizer=weight_quantizer, + bias=False, + activation_bits=8, + input_quantization=False, + output_quantization=False, + weight_quantization=True, + activation_quantization=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +class SimpleTestNetworkStaticActivation(nn.Module): + def __init__(self, weight_group_size: Union[int, str] = "per_channel") -> None: + super().__init__() + weight_quantizer = IntQuantizer( + num_bits=4, + group_size=weight_group_size, + dynamic=False, + quantization_mode="symmetric", + range_learning=False, + ) + self.linear = QuantizedLinear( + in_features=16, + out_features=8, + weight_quantizer=weight_quantizer, + bias=False, + activation_bits=8, + input_quantization=True, + output_quantization=True, + weight_quantization=True, + activation_quantization=True, + dynamic_activations=False, + activation_group_size="per_tensor", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +class TestSetWeightMinMax(unittest.TestCase): + def test_set_weight_min_max(self) -> None: + model = SimpleTestNetwork() + + set_weight_min_max(model) + + def test_set_weight_min_max_grouped(self) -> None: + model = SimpleTestNetwork(weight_group_size=8) + + set_weight_min_max(model) + + +class TestSetWeightMSE(unittest.TestCase): + def test_set_weight_mse(self) -> None: + model = SimpleTestNetwork() + set_weight_mse(model, num_points=5) + + def test_set_weight_mse_grouped(self) -> None: + model = SimpleTestNetwork(weight_group_size=8) + set_weight_mse(model, num_points=5) + + +class TestSetWeightRangeActivationLoss(unittest.TestCase): + def test_set_weight_range_activation_loss(self) -> None: + model = SimpleTestNetwork() + test_data = torch.rand(2, 16) + set_weight_range_activation_loss( + model, + test_data, + batch_size=1, + num_points=5, + progressive=False, + ) + + def test_set_weight_range_activation_loss_progressive(self) -> None: + model = SimpleTestNetwork(weight_group_size=8) + test_data = torch.rand(2, 16) + set_weight_range_activation_loss( + model, + test_data, + batch_size=1, + num_points=5, + progressive=True, + ) + + +class TestStaticActivationRangeSetting(unittest.TestCase): + def test_static_activation_range_setting(self) -> None: + model = SimpleTestNetworkStaticActivation() + + test_data = torch.rand(2, 16) + set_activation_min_max(model, test_data, batch_size=1) + + def test_static_activation_range_setting_no_input(self) -> None: + model = SimpleTestNetworkStaticActivation() + + test_data = torch.rand(2, 16) + set_activation_min_max(model, test_data, batch_size=1) + + +class TestQuantizePerGroupScales(unittest.TestCase): + def test_quantize_per_group_scales(self) -> None: + model = SimpleTestNetwork(weight_group_size=8) + + set_weight_min_max(model) + assert model.linear.weight_quantizer.scale is not None + scale_before = copy.deepcopy(model.linear.weight_quantizer.scale) + quantize_per_group_scales(model, bit_width=4) + assert model.linear.weight_quantizer.scale is not None + assert not torch.allclose(scale_before, model.linear.weight_quantizer.scale) + + def test_quantize_per_group_scales_dont_change_per_channel(self) -> None: + model = SimpleTestNetwork(weight_group_size="per_channel") + + set_weight_min_max(model) + assert model.linear.weight_quantizer.scale is not None + scale_before = copy.deepcopy(model.linear.weight_quantizer.scale) + quantize_per_group_scales(model, bit_width=4) + assert model.linear.weight_quantizer.scale is not None + assert torch.allclose(scale_before, model.linear.weight_quantizer.scale) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/prototype/quantization/module_swap/README.md b/torchao/prototype/quantization/module_swap/README.md new file mode 100644 index 0000000000..af2035a06a --- /dev/null +++ b/torchao/prototype/quantization/module_swap/README.md @@ -0,0 +1,67 @@ +# Module Swap Quantization (prototype) + +This is an alternative to quantization based on tensor subclasses, +bypassing the entire `AffineQuantizedTensor` stack for simplicity. +Quantized modules supported today include: + +``` +torch.nn.Linear -> QuantizedLinear +torch.nn.Embedding -> QuantizedEmbedding +``` + +Within each of these quantized modules, the user can specify different +quantization settings to quantize the weights and the activations +separately. For example: + +``` +quantized_linear = QuantizedLinear(...) +quantized_linear.input_quantization = IntQuantizer(...) +quantized_linear.weight_quantization = CodeBookQuantizer(...) +``` + +The current entry point API is `quantize_module_swap`, which takes +in a `QuantizationRecipe` and performs module swap on the model, +applying the configured quantizers to weights and activations on +the swapped quantized modules. However, **this API is highly subject +to change and will be replaced by `quantize_` in the future**. +Example usage: + +``` +import torch +import torch.nn as nn +from torchao.prototype.quantization.module_swap import ( + quantize_module_swap, + QuantizationRecipe, +) + +class MyModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.embedding = nn.Embedding(10, 64) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.embedding(x) + + +model = MyModel() +recipe = QuantizationRecipe( + embedding_quantization=True, + embedding_bits=4 +) +model = quantize_module_swap(model, recipe) +``` + +``` +>>> model +MyModel( + (embedding): QuantizedEmbedding( + 10, 64 + (weight_quantizer): IntQuantizer() + ) +) +>>> x = torch.randint(0, 10, (10, 64)) +>>> model(x) +tensor([[[-0.0000, 1.7221, 0.6888, ..., 0.5700, -0.5700, -0.8550], + ... + [ 1.2896, -0.0000, 0.3224, ..., -0.5430, -1.9005, 0.5430]]], + grad_fn=) +``` diff --git a/torchao/prototype/quantization/module_swap/__init__.py b/torchao/prototype/quantization/module_swap/__init__.py new file mode 100644 index 0000000000..af62c3899f --- /dev/null +++ b/torchao/prototype/quantization/module_swap/__init__.py @@ -0,0 +1,21 @@ +from .module_swap import ( + QuantizationRecipe, + quantize_module_swap, +) +from .quantized_modules import ( + QuantizedEmbedding, + QuantizedLinear, +) +from .quantizers import ( + CodeBookQuantizer, + IntQuantizer, +) + +__all__ = [ + "CodeBookQuantizer", + "IntQuantizer", + "QuantizedEmbedding", + "QuantizedLinear", + "QuantizationRecipe", + "quantize_module_swap", +] diff --git a/torchao/prototype/quantization/module_swap/algorithms/__init__.py b/torchao/prototype/quantization/module_swap/algorithms/__init__.py new file mode 100644 index 0000000000..f5e61235c7 --- /dev/null +++ b/torchao/prototype/quantization/module_swap/algorithms/__init__.py @@ -0,0 +1,5 @@ +from .kmeans_codebook import kmeans_codebook + +__all__ = [ + "kmeans_codebook", +] diff --git a/torchao/prototype/quantization/module_swap/algorithms/kmeans_codebook.py b/torchao/prototype/quantization/module_swap/algorithms/kmeans_codebook.py new file mode 100644 index 0000000000..7a47f74200 --- /dev/null +++ b/torchao/prototype/quantization/module_swap/algorithms/kmeans_codebook.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn + +from torchao.prototype.quantization.module_swap.quantized_modules import QuantizedLinear +from torchao.prototype.quantization.module_swap.quantizers import CodeBookQuantizer + + +def kmeans_codebook( + model: nn.Module, + niter: int = 30, + nredo: int = 1, + dtype: torch.dtype = torch.float32, +) -> None: + import faiss + + with torch.no_grad(): + for layer in model.modules(): + if isinstance(layer, QuantizedLinear): + if isinstance(layer.weight_quantizer, CodeBookQuantizer): + weight = layer.weight + codebook_dim = layer.weight_quantizer.codebook_dim + weight = weight.reshape( + weight.shape[0] * (weight.shape[1] // codebook_dim), + codebook_dim, + ) + num_centroids = layer.weight_quantizer.codebook.shape[0] + kmeans = faiss.Kmeans( + weight.shape[1], + num_centroids, + niter=niter, + nredo=nredo, + verbose=True, + gpu=True if torch.cuda.is_available() else False, + ) + kmeans.train(weight.to(device="cpu", dtype=dtype)) + C = kmeans.centroids + + layer.weight_quantizer.codebook.data = torch.FloatTensor(C).to( + weight.dtype + ) diff --git a/torchao/prototype/quantization/module_swap/data_getters/__init__.py b/torchao/prototype/quantization/module_swap/data_getters/__init__.py new file mode 100644 index 0000000000..75b8cfba71 --- /dev/null +++ b/torchao/prototype/quantization/module_swap/data_getters/__init__.py @@ -0,0 +1,13 @@ +from .llm_ptq_data_getter import ( + LLMPTQDataGetter, +) +from .ptq_data_getter import ( + DataGetter, + get_module_input_data, +) + +__all__ = [ + "DataGetter", + "get_module_input_data", + "LLMPTQDataGetter", +] diff --git a/torchao/prototype/quantization/module_swap/data_getters/llm_ptq_data_getter.py b/torchao/prototype/quantization/module_swap/data_getters/llm_ptq_data_getter.py new file mode 100644 index 0000000000..ced42d1676 --- /dev/null +++ b/torchao/prototype/quantization/module_swap/data_getters/llm_ptq_data_getter.py @@ -0,0 +1,133 @@ +from typing import Dict + +import torch +import torch.nn as nn +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +from torchao.prototype.quantization.module_swap.data_getters.ptq_data_getter import ( + DataGetter, + get_module_input_data, +) +from torchao.prototype.quantization.module_swap.utils import ( + get_layer_by_name, +) + + +class LLMPTQDataGetter(DataGetter): + """ + This datagetter can be used to efficiently retrieve layer-wise input data from a LlamaForCausalLM model + The two benefits are + 1) It caches the data in between layer's residuals, so previous layers dont have to be computed + 2) Layers with the same input have their input cached and returned + + Usage is simple, give it a model, and data + then data_getter.pop(model, layer_name) for each batch of data you require + + the actual model passed is used to get the data, so if you want to e.g. quantize + the entire network with weight quantizers, it uses that data + The datagetter has to be called in-order of the layer's occurence in the network, otherwise it will fail + + """ + + def __init__( + self, model: LlamaForCausalLM, data: torch.Tensor, batch_size: int + ) -> None: + super().__init__() + self.initialize(model, data, batch_size) + + def initialize(self, model: nn.Module, data: torch.Tensor, batch_size: int) -> None: + assert isinstance(model, LlamaForCausalLM) + assert isinstance(data, torch.Tensor) + + # set attention_mask and/or position_ids + self.layer_kwargs: Dict[str, torch.Tensor] = self.get_layer_kwargs(model, data) + + self.input_data_cache: torch.Tensor = get_module_input_data( + model, data, model.model.layers[0], batch_size + ) + self.current_layer_idx = 0 + self.previously_called_name: str = "" + self.batch_size = batch_size + self.output_data_cache: torch.Tensor = torch.zeros_like(data) + self.matched_input_layers = [ + ["q_proj", "k_proj", "v_proj"], + ["up_proj", "gate_proj"], + ] + + def pop(self, model: nn.Module, name: str) -> torch.Tensor: + assert isinstance(model, LlamaForCausalLM) + with torch.no_grad(): + # special case for the last layer + if name != "lm_head": + query_layer_idx = int(name.split(".")[2]) + else: + query_layer_idx = len(model.model.layers) + + assert ( + query_layer_idx >= self.current_layer_idx + ), "pop() called out of order, layers have to be called in order" + + # TODO: batch the next two parts + + # progress the progress over layers + while query_layer_idx > self.current_layer_idx: + self.input_data_cache = model.model.layers[self.current_layer_idx]( + self.input_data_cache, **self.layer_kwargs + )[0] + self.current_layer_idx += 1 + + # special case for the final layer + if name == "lm_head": + return self.input_data_cache + + # use cached output data if the outputs are matching + base_name = self.get_base_name(name) + for matching_list in self.matched_input_layers: + if base_name in matching_list: + previous_base_name = self.get_base_name(self.previously_called_name) + if previous_base_name in matching_list: + self.previously_called_name = name + return self.output_data_cache + + # get data from current requested layer + query_layer = get_layer_by_name(model, name) + layer_output_data = get_module_input_data( + model.model.layers[self.current_layer_idx], + self.input_data_cache, + query_layer, + self.batch_size, + self.layer_kwargs, + ) + + # cache the data for next call + self.output_data_cache = layer_output_data + self.previously_called_name = name + return layer_output_data + + def get_layer_kwargs( + self, model: LlamaForCausalLM, data: torch.Tensor + ) -> Dict[str, torch.Tensor]: + # get used attention_mask and position_ids + layer_kwargs: Dict[str, torch.Tensor] = {} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inp, **kwargs): + if kwargs["attention_mask"] is not None: + layer_kwargs["attention_mask"] = kwargs["attention_mask"] + if kwargs["position_ids"] is not None: + layer_kwargs["position_ids"] = kwargs["position_ids"] + raise ValueError + + device = model.parameters().__next__().device + model.model.layers[0] = Catcher(model.model.layers[0]) + try: + model(data[0:2].to(device)) + except ValueError: + pass + model.model.layers[0] = model.model.layers[0].module + + return layer_kwargs diff --git a/torchao/prototype/quantization/module_swap/data_getters/ptq_data_getter.py b/torchao/prototype/quantization/module_swap/data_getters/ptq_data_getter.py new file mode 100644 index 0000000000..6f17ebdc49 --- /dev/null +++ b/torchao/prototype/quantization/module_swap/data_getters/ptq_data_getter.py @@ -0,0 +1,67 @@ +from typing import Dict, List + +import torch +import torch.nn as nn + + +class ExpectedError(Exception): + pass + + +class DataGetter: + def __init__(self) -> None: + return + + def pop(self, model: nn.Module, name: str) -> torch.Tensor: + raise NotImplementedError() + + def get_base_name(self, name: str) -> str: + base_name = name.split(".")[-1] + if base_name.isnumeric(): + base_name = name.split(".")[-2] + return base_name + + def initialize(self, model: nn.Module, data: torch.Tensor, batch_size: int) -> None: + raise NotImplementedError() + + +def get_module_input_data( + model: nn.Module, + data: torch.Tensor, + module: nn.Module, + batch_size: int, + layer_kwargs: Dict[str, torch.Tensor] = {}, # noqa +) -> torch.Tensor: + with torch.no_grad(): + if isinstance(data, list): + num_data = len(data) + else: + num_data = data.shape[0] + num_batches = num_data // batch_size + assert num_data % batch_size == 0 + + input_data: List[torch.Tensor] = [] + + def _input_data_hook( + module: nn.Module, input: List[torch.Tensor], output: List[torch.Tensor] + ) -> None: + input_data.append(input[0].detach()) + assert len(input) == 1 + raise ExpectedError + + hook = module.register_forward_hook(_input_data_hook) + + for i in range(num_batches): + try: + this_batch = data[i * batch_size : (i + 1) * batch_size] + this_batch = this_batch.to(next(model.parameters()).device) + if layer_kwargs: + model(this_batch, **layer_kwargs) + else: + model(this_batch) + except ExpectedError: + pass + + hook.remove() + return_data = torch.cat(input_data, dim=0) + return return_data diff --git a/torchao/prototype/quantization/module_swap/module_swap.py b/torchao/prototype/quantization/module_swap/module_swap.py new file mode 100644 index 0000000000..e56965c8d2 --- /dev/null +++ b/torchao/prototype/quantization/module_swap/module_swap.py @@ -0,0 +1,167 @@ +import logging +from dataclasses import dataclass, field +from typing import List, Union + +import torch +import torch.nn as nn + +from torchao.prototype.quantization.module_swap.algorithms import ( + kmeans_codebook, +) +from torchao.prototype.quantization.module_swap.quantized_modules import ( + QuantizedEmbedding, + QuantizedLinear, +) +from torchao.prototype.quantization.module_swap.quantizers import ( + CodeBookQuantizer, + IntQuantizer, +) +from torchao.prototype.quantization.module_swap.range_setting_methods import ( + set_weight_min_max, +) + +logger: logging.Logger = logging.getLogger(__name__) + + +# TODO: express this using AOBaseConfig +@dataclass +class QuantizationRecipe: + # weights + weight_bits: int = 4 + weight_group_size: Union[int, str] = 32 + weight_quantization: bool = True + dynamic_weights: bool = False + + # weight codebooking settings + weight_codebook: bool = False # if we're using weight codebooks + codebook_dim: int = 1 + + # activations + activation_bits: int = 8 + activation_group_size: Union[int, str] = "per_token" + activation_quantization: bool = False + input_quantization: bool = False + output_quantization: bool = False + dynamic_activations: bool = True + + # general + range_learning: bool = False + embedding_quantization: bool = True + embedding_bits: int = 4 + embedding_group_size: Union[int, str] = 32 + exclude_layers: List[str] = field(default_factory=lambda: ["lm_head"]) + + +def get_layer_parent_by_name(model: nn.Module, input_name: str) -> nn.Module: + parent_name = input_name.rsplit(".", 1)[:-1] + if len(parent_name) == 0: # parent is model itself + return model + else: + parent_name = parent_name[0] + + for name, module in model.named_modules(): + if parent_name == name: + return module + raise ValueError(f"Layer {input_name} not found in model") + + +# TODO: delete this, use quantize_ instead +def quantize_module_swap( + model: nn.Module, recipe: QuantizationRecipe, dtype: torch.dtype = torch.float32 +) -> nn.Module: + model = replace_all_linear_with_quantized_linear(model, recipe) + if recipe.embedding_quantization: + model = replace_all_embedding_with_quantized(model, recipe) + initialize_model_parameters(model, recipe, dtype) + return model + + +def replace_all_embedding_with_quantized( + model: nn.Module, recipe: QuantizationRecipe +) -> nn.Module: + for name, module in model.named_modules(): + if isinstance(module, nn.Embedding): + if name in recipe.exclude_layers: + logger.info(f"skip layer {name} in exclude list") + else: + quantized_embedding = QuantizedEmbedding( + num_embeddings=module.num_embeddings, + embedding_dim=module.embedding_dim, + padding_idx=module.padding_idx, + max_norm=module.max_norm, + norm_type=module.norm_type, + scale_grad_by_freq=module.scale_grad_by_freq, + sparse=module.sparse, + _weight=module.weight, + num_bits=recipe.embedding_bits, + group_size=recipe.embedding_group_size, + quantization_mode="symmetric", + range_learning=recipe.range_learning, + dynamic_weights=recipe.dynamic_weights, + ) + attribute_name = name.rsplit(".", 1)[-1] + parent_of_module = get_layer_parent_by_name(model, name) + setattr(parent_of_module, attribute_name, quantized_embedding) + + logger.info(f"replaced {name} with quantized embedding") + return model + + +def replace_all_linear_with_quantized_linear( + model: nn.Module, recipe: QuantizationRecipe +) -> nn.Module: + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + if name in recipe.exclude_layers: + logger.info(f"skip layer {name} in exclude list") + else: + if recipe.weight_codebook: + weight_quantizer = CodeBookQuantizer( + n_bits=recipe.weight_bits, + features=module.out_features, + codebook_dim=recipe.codebook_dim, + ) + else: + weight_quantizer = IntQuantizer( + num_bits=recipe.weight_bits, + group_size=recipe.weight_group_size, + dynamic=recipe.dynamic_weights, + quantization_mode="symmetric", + range_learning=recipe.range_learning, + ) + quantized_linear = QuantizedLinear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + weight_quantizer=weight_quantizer, + weight_quantization=recipe.weight_quantization, + activation_bits=recipe.activation_bits, + activation_group_size=recipe.activation_group_size, + activation_quantization=recipe.activation_quantization, + input_quantization=recipe.input_quantization, + output_quantization=recipe.output_quantization, + dynamic_activations=recipe.dynamic_activations, + range_learning=recipe.range_learning, + ) + quantized_linear.weight = module.weight + quantized_linear.bias = module.bias + + # replace the module with the quantized linear module + attribute_name = name.rsplit(".", 1)[-1] + parent_of_module = get_layer_parent_by_name(model, name) + setattr(parent_of_module, attribute_name, quantized_linear) + + # logger.info(f"replaced {name} with quantized linear") + return model + + +def initialize_model_parameters( + model: nn.Module, recipe: QuantizationRecipe, dtype: torch.dtype = torch.float32 +) -> None: + """ + Initialize the model weights and/or codebook if codebook quantization is used + """ + if not recipe.dynamic_weights: + set_weight_min_max(model) + if recipe.weight_codebook: + kmeans_codebook(model, dtype=dtype) diff --git a/torchao/prototype/quantization/module_swap/quantized_modules.py b/torchao/prototype/quantization/module_swap/quantized_modules.py new file mode 100644 index 0000000000..8cf9d5fb06 --- /dev/null +++ b/torchao/prototype/quantization/module_swap/quantized_modules.py @@ -0,0 +1,216 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchao.prototype.quantization.module_swap.quantizers import ( + CodeBookQuantizer, + IntQuantizer, + __supported_group_size_strings__, +) + +SupportedQuantizers = Union[CodeBookQuantizer, IntQuantizer] + + +class WeightModuleQuantizerBase: + def set_weight_scale_to_min_max(self) -> None: + if not self.weight_quantizer.dynamic: + self.weight_quantizer.set_scale_offset_to_min_max(self.weight) + else: + raise ValueError( + "Weights are quantized dynamically, no range/scale is used" + ) + + @property + def weight_scale(self) -> torch.Tensor: + return self.weight_quantizer.scale + + @property + def quantized_weight(self) -> torch.Tensor: + return self.weight_quantizer(self.weight) + + +class QuantizedLinear(nn.Linear, WeightModuleQuantizerBase): + def __init__( + self, + activation_bits: int, + weight_quantizer: SupportedQuantizers, + input_quantization: bool = True, + output_quantization: bool = False, + activation_group_size: Union[int, str] = "per_token", + weight_quantization: bool = True, + activation_quantization: bool = True, + dynamic_activations: bool = True, + range_learning: bool = False, + scale_eps: float = 1e-9, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + + weight_group_size, activation_group_size = self.validate_group_sizes( + weight_quantizer.group_size, activation_group_size, dynamic_activations + ) + + self.weight_quantizer: SupportedQuantizers = weight_quantizer + self.weight_quantizer.scale_eps = scale_eps + self.input_quantizer: Optional[IntQuantizer] = None + self.output_quantizer: Optional[IntQuantizer] = None + + if input_quantization: + self.input_quantizer = IntQuantizer( + num_bits=activation_bits, + group_size=activation_group_size, + dynamic=dynamic_activations, + quantization_mode="asymmetric", + range_learning=range_learning, + scale_eps=scale_eps, + ) + else: + self.input_quantizer = None + if output_quantization: + self.output_quantizer = IntQuantizer( + num_bits=activation_bits, + group_size=activation_group_size, + dynamic=dynamic_activations, + quantization_mode="asymmetric", + range_learning=range_learning, + scale_eps=scale_eps, + ) + else: + self.output_quantizer = None + + self.input_quantization = input_quantization + self.output_quantization = output_quantization + self.weight_quantization = weight_quantization + self.activation_quantization = activation_quantization + + if not weight_quantizer.dynamic: + assert isinstance(self.weight_quantizer, IntQuantizer) + self.weight_quantizer.set_scale_offset_to_min_max(self.weight) + + self.pre_transforms: nn.ModuleList = nn.ModuleList() + self.post_transforms: nn.ModuleList = nn.ModuleList() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x = input + + if self.input_quantization: + assert self.input_quantizer is not None + self.input_quantizer.quantize = self.activation_quantization + x = self.input_quantizer(x) + + for transform in self.pre_transforms: + x = transform(x) + + if self.weight_quantization: + weight = self.quantized_weight + else: + weight = self.weight + + x = F.linear(x, weight, self.bias) + + for transform in self.post_transforms: + x = transform(x) + + if self.output_quantization: + assert self.output_quantizer is not None + self.output_quantizer.quantize = self.activation_quantization + x = self.output_quantizer(x) + return x + + @staticmethod + def validate_group_sizes( + weight_group_size: Union[int, str], + activation_group_size: Union[int, str], + dynamic_activations: bool, + ) -> Tuple[Union[int, str], Union[int, str]]: + assert ( + isinstance(weight_group_size, int) + or weight_group_size in __supported_group_size_strings__ + ) + if weight_group_size == "per_token": + raise ValueError( + "per_token is only available for dynamic activation quantization" + ) + + assert ( + isinstance(activation_group_size, int) + or activation_group_size in __supported_group_size_strings__ + ) + if activation_group_size == "per_channel": + raise ValueError("per_channel is not supported for activatins") + if not dynamic_activations and activation_group_size != "per_tensor": + raise ValueError("Only per-tensor supported for static activations") + + return weight_group_size, activation_group_size + + def __repr__(self) -> str: + output_string = "QuantizedLinear(" + empty_space = " " * len(output_string) + output_string += ( + f"weight_quantizer={self.weight_quantizer} - {self.weight_quantization}, \n" + ) + if self.input_quantizer is not None: + output_string += ( + empty_space + + f"input_quant={self.input_quantizer} - {self.activation_quantization}, \n" + ) + if self.output_quantizer is not None: + output_string += ( + empty_space + + f"output_quant={self.output_quantizer} - {self.activation_quantization}, \n" + ) + if self.pre_transforms: + output_string += empty_space + f"pre_transforms={self.pre_transforms}, \n" + if self.post_transforms: + output_string += empty_space + f"post_transforms={self.post_transforms}, \n" + output_string = output_string[:-3] + output_string += ")" + return output_string + + +class QuantizedEmbedding(nn.Embedding, WeightModuleQuantizerBase): + def __init__( + self, + num_bits: int, + group_size: Union[int, str], + quantization_mode: str, + range_learning: bool = False, + scale_eps: float = 1e-9, + dynamic_weights: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.weight_quantizer = IntQuantizer( + num_bits=num_bits, + group_size=group_size, + dynamic=dynamic_weights, + quantization_mode="symmetric", + range_learning=range_learning, + scale_eps=scale_eps, + ) + self.weight_quantization = True + self.dynamic_weights = dynamic_weights + + if not self.dynamic_weights: + self.weight_quantizer.set_scale_offset_to_min_max(self.weight) + self._range_learning = range_learning + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.weight_quantization: + weight = self.weight_quantizer(self.weight) + else: + weight = self.weight + + return torch.nn.functional.embedding( + input, + weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) diff --git a/torchao/prototype/quantization/module_swap/quantizers.py b/torchao/prototype/quantization/module_swap/quantizers.py new file mode 100644 index 0000000000..a6f5e7ad95 --- /dev/null +++ b/torchao/prototype/quantization/module_swap/quantizers.py @@ -0,0 +1,347 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.autograd import Function + +__supported_group_size_strings__ = ["per_token", "per_channel", "per_tensor"] + + +class RoundStraightThrough(Function): + @staticmethod + def forward(ctx, x: torch.Tensor) -> torch.Tensor: + return torch.round(x) + + @staticmethod + def backward(ctx, output_grad: torch.Tensor) -> torch.Tensor: + return output_grad + + +class IntQuantizer(nn.Module): + def __init__( + self, + num_bits: int, + group_size: Union[int, str], + dynamic: bool, + quantization_mode: str, + range_learning: bool = False, + scale_eps: float = 1e-9, + ) -> None: + super().__init__() + + self.num_bits = num_bits + self.group_size = group_size + self.dynamic = dynamic + self.quantization_mode = quantization_mode + self.scale_eps = scale_eps + self.scale: Optional[torch.Tensor] = None + self.offset: Optional[torch.Tensor] = None + self._range_learning = range_learning + self.quant_dtype: torch._C.dtype = torch.float32 + self.quantize = True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.quantize: + return x + + q_min, q_max = self.get_qmin_qmax( + self.num_bits, self.quant_mode_to_signed(self.quantization_mode) + ) + + if self.dynamic: + scale, offset = IntQuantizer.get_scale_offset( + x, + self.group_size, + self.quantization_mode, + q_min, + q_max, + ) + else: + scale = self.scale + offset = self.offset + + if scale is None: + raise ValueError("Initialize scale before first forward pass") + + group_size = self.process_group_size_to_int(self.group_size, x) + + if self._range_learning: + scale = torch.clamp(scale, min=self.scale_eps) + if offset is not None: + offset = RoundStraightThrough.apply(offset) + offset = torch.clamp(offset, q_min, q_max) + + return self.quantize_forward( + x, + scale=scale, + offset=offset, + group_size=group_size, + q_min=q_min, + q_max=q_max, + ) + + @property + def q_min(self) -> int: + q_min, _ = self.get_qmin_qmax( + self.num_bits, self.quant_mode_to_signed(self.quantization_mode) + ) + return q_min + + @property + def q_max(self) -> int: + _, q_max = self.get_qmin_qmax( + self.num_bits, self.quant_mode_to_signed(self.quantization_mode) + ) + return q_max + + @staticmethod + def quant_mode_to_signed(quant_mode: str) -> bool: + if quant_mode == "symmetric": + return True + elif quant_mode == "asymmetric": + return False + else: + raise NotImplementedError + + @staticmethod + def get_scale_param_size( + x: torch.Tensor, group_size: Union[int, str] + ) -> torch.Size: + int_group_size = IntQuantizer.process_group_size_to_int(group_size, x) + if int_group_size is None: + return torch.Size([1]) + else: + if len(x.shape) == 1: + return torch.Size([x.shape[0] // int_group_size]) + else: + size_list = [x.shape[i] for i in range(len(x.shape) - 1)] + size_list.append(x.shape[-1] // int_group_size) + return torch.Size(size_list) + + @staticmethod + def get_qmin_qmax(n_bits: int, signed: bool) -> Tuple[int, int]: + if signed: + qmin = -(2 ** (n_bits - 1)) + qmax = 2 ** (n_bits - 1) - 1 + else: + qmin = 0 + qmax = 2**n_bits - 1 + return qmin, qmax + + @staticmethod + def get_scale_offset( + x: torch.Tensor, + group_size: Union[int, str], + quantization_mode: str, + q_min: int, + q_max: int, + scale_eps: float = 1e-9, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if not isinstance(group_size, int) and group_size == "per_tensor": + x_min = torch.min(x) + x_max = torch.max(x) + else: + int_group_size = IntQuantizer.process_group_size_to_int(group_size, x) + assert int_group_size is not None + reshaped_x = x.reshape( + [*x.shape[:-1], x.shape[-1] // int_group_size, int_group_size] + ) + x_min = torch.min(reshaped_x, dim=-1)[0] + x_max = torch.max(reshaped_x, dim=-1)[0] + + if quantization_mode == "symmetric": + scale = IntQuantizer.get_scale_from_min_max( + x_min, x_max, q_min, q_max, scale_eps + ) + offset = None + elif quantization_mode == "asymmetric": + scale, offset = IntQuantizer.get_scale_offset_from_min_max( + x_min, x_max, q_min, q_max, scale_eps + ) + else: + raise NotImplementedError + + return scale, offset + + @staticmethod + def quantize_forward( + x: torch.Tensor, + scale: torch.Tensor, + offset: Optional[torch.Tensor], + q_min: int, + q_max: int, + group_size: Optional[int] = None, + ) -> torch.Tensor: + # if quantization is per_group, we need to reshape the tensor to apply it over + reshaped = False + orig_shape = x.shape + if group_size is not None and group_size != x.shape[-1]: + x = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size) + scale = torch.unsqueeze(scale, -1) + if offset is not None: + offset = torch.unsqueeze(offset, -1) + reshaped = True + + if scale.device != x.device: + scale = scale.to(x.device) + if offset is not None: + offset = offset.to(x.device) + + input_dtype = x.dtype + + if offset is None or offset.numel() == 0: + x = torch.clamp(RoundStraightThrough.apply(x / scale), q_min, q_max) + x = x * scale + else: + x = torch.clamp( + RoundStraightThrough.apply(x / scale) + offset, + q_min, + q_max, + ) + x = (x - offset) * scale + + # reshape back to original shape if groups were used + if reshaped: + x = x.reshape(orig_shape) + + return x.to(input_dtype) + + @staticmethod + def process_group_size_to_int( + group_size: Union[int, str], x: torch.Tensor + ) -> Optional[int]: + if isinstance(group_size, int): + return group_size + elif group_size == "per_channel" or group_size == "per_token": + return x.shape[-1] + elif group_size == "per_tensor": + return None + else: + raise NotImplementedError + + @staticmethod + def get_scale_from_min_max( + x_min: torch.Tensor, + x_max: torch.Tensor, + q_min: int, + q_max: int, + eps: float = 1e-9, + ) -> torch.Tensor: + smin = x_min / float(q_min) + smax = x_max / float(q_max) + mask = smin > smax + scale = torch.where(mask, smin, smax) + scale = torch.clamp(scale, min=eps) + return scale + + @staticmethod + def get_scale_offset_from_min_max( + x_min: torch.Tensor, + x_max: torch.Tensor, + q_min: int, + q_max: int, + eps: float = 1e-9, + ) -> Tuple[torch.Tensor, torch.Tensor]: + scale = ((x_max - x_min) / (abs(q_min) + abs(q_max))).detach() + offset = (-x_min / scale).detach() + + offset = torch.round(offset) + offset = torch.clamp(offset, q_min, q_max) + + scale = torch.clamp(scale, min=eps) + return scale, offset + + def set_scale_offset_to_min_max(self, x: torch.Tensor) -> None: + assert not self.dynamic + + x = x.detach() + x = x.to(self.quant_dtype) + + signed = self.quant_mode_to_signed(self.quantization_mode) + + q_min, q_max = self.get_qmin_qmax(self.num_bits, signed=signed) + scale, offset = self.get_scale_offset( + x, self.group_size, self.quantization_mode, q_min, q_max + ) + + # with per-tensor we get empty tensors sometimes + if scale.data.shape == torch.Size([]): + scale_size_fix = torch.ones([1]).to(scale.device) + scale_size_fix *= scale + scale.data = scale_size_fix.data + + if self._range_learning: + self.scale = torch.nn.Parameter(scale, requires_grad=True).to( + self.quant_dtype + ) + if offset is not None: + self.offset = torch.nn.Parameter(offset, requires_grad=True).to( + self.quant_dtype + ) + else: + self.scale = scale + self.offset = offset + + +class CodeBookQuantizer(nn.Module): + def __init__( + self, + n_bits: int, + features: int, + codebook_dim: int = 1, + seed: int = 1337, + ) -> None: + super().__init__() + self.group_size = codebook_dim + self.codebook_dim = codebook_dim + self.dynamic = True + self.scale_eps = 1e-9 + assert features % codebook_dim == 0 + torch.manual_seed(seed) + + N = 2 ** (n_bits * codebook_dim) + codebook_shape = (N, codebook_dim) + self.codebook = nn.Parameter( + torch.Tensor(codebook_shape[0], codebook_shape[1]), requires_grad=False + ) + + self._codebook_shape: Tuple[int, int] = codebook_shape + + def forward(self, x: torch.Tensor) -> torch.Tensor: + scaling_factor = torch.mean(abs(x), dim=1, keepdim=True).detach() + output = scaling_factor * VectorQuantizerFunction.apply( + x / scaling_factor, self.codebook + ) + return output + + +class VectorQuantizerFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[2, 14] + def forward(ctx, inputs: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor: + flat_inputs = inputs.view(-1, codebook.shape[1]) + distances = torch.cdist(flat_inputs, codebook) + indices = torch.argmin(distances, dim=1) + sums = torch.zeros_like(codebook).float().cuda() + counts = torch.zeros(codebook.size(0), dtype=torch.float).cuda() + + # Accumulate sums using index_add_ + sums.index_add_(0, indices, flat_inputs.float()) + + # Accumulate counts using index_add_ + ones = torch.ones(flat_inputs.size(0), dtype=torch.float).cuda() + counts.index_add_(0, indices, ones) + + # Avoid division by zero + counts = counts.unsqueeze(1).clamp(min=1.0) + + # Update each centroid with the average of assigned inputs + codebook.copy_(sums / counts) + + quantized = codebook[indices].view_as(inputs) + return quantized + + @staticmethod + # pyre-ignore[2, 14] + def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: + return grad_output, None diff --git a/torchao/prototype/quantization/module_swap/range_setting_methods.py b/torchao/prototype/quantization/module_swap/range_setting_methods.py new file mode 100644 index 0000000000..2ba93321fb --- /dev/null +++ b/torchao/prototype/quantization/module_swap/range_setting_methods.py @@ -0,0 +1,216 @@ +import copy +import logging +from typing import Callable, Optional + +import torch +import torch.nn as nn + +from torchao.prototype.quantization.module_swap.data_getters import ( + DataGetter, + get_module_input_data, +) +from torchao.prototype.quantization.module_swap.quantized_modules import ( + QuantizedLinear, +) +from torchao.prototype.quantization.module_swap.utils import ( + all_activation_quantizers_off, + all_quantizers_off, + all_weight_quantizers_on, +) + +logger: logging.Logger = logging.getLogger(__name__) + + +def set_weight_min_max(model: nn.Module) -> None: + for _, module in model.named_modules(): + if isinstance(module, QuantizedLinear): + module.set_weight_scale_to_min_max() + + +def set_weight_mse( + model: nn.Module, num_points: int = 100, max_shrink: float = 0, norm: float = 2.0 +) -> None: + for _, module in model.named_modules(): + if isinstance(module, QuantizedLinear): + loss_fn = lambda m: torch.sum( # noqa + torch.pow(torch.abs(m.weight - m.quantized_weight), norm), dim=-1 + ) + best_scale = find_optimal_scales_with_loss( + module, loss_fn, num_points, max_shrink + ) + module.weight_scale.data = best_scale + + +def get_batched_output( + module: nn.Module, input_data: torch.Tensor, batch_size: int +) -> torch.Tensor: + device = module.weight.device + dtype = module.weight.dtype + + num_samples = input_data.shape[0] + num_batches = num_samples // batch_size + + output_data = [] + for i in range(num_batches): + this_batch = input_data[i * batch_size : (i + 1) * batch_size] + this_batch = this_batch.to(device).to(dtype) + output_data.append(module(this_batch).to(torch.float32).to("cpu")) + + return torch.vstack(output_data) + + +def set_weight_range_activation_loss( + model: nn.Module, + data: torch.Tensor, + batch_size: int, + num_points: int = 100, + progressive: bool = True, + data_getter: Optional[DataGetter] = None, +) -> None: + # store quantization settings so this algorithm does not change those implicitly + quantization_setting_mapping_dict = { + name: [module.weight_quantization, module.activation_quantization] + for name, module in model.named_modules() + if isinstance(module, QuantizedLinear) + } + + data_getter_progressive = None + if data_getter is not None: + data_getter.initialize(model, data, batch_size) + if progressive: + data_getter_progressive = copy.deepcopy(data_getter) + + # TODO: This can all be optimized for efficiency (keep data on GPU) or Memory (keep data on CPU) + # Do the actual range setting + with torch.no_grad(): + for name, module in model.named_modules(): + if isinstance(module, QuantizedLinear): + logger.info(f"Range setting for {name}") + model.apply(all_quantizers_off) + # TODO: Some form of smart subsampling from all this sequential data + if data_getter is not None: + input_data = data_getter.pop(model, name) + else: + input_data = get_module_input_data(model, data, module, batch_size) + output_data = get_batched_output(module, input_data, batch_size) + + if progressive: + model.apply(all_weight_quantizers_on) + if data_getter_progressive is not None: + input_data = data_getter_progressive.pop(model, name) + else: + input_data = get_module_input_data( + model, data, module, batch_size + ) + + input_data = input_data.to(module.weight.device).to(module.weight.dtype) + output_data = output_data.to(module.weight.device).to( + module.weight.dtype + ) + + module.weight_quantization = True + dim = tuple(range(input_data.dim() - 1)) # all but last + + # TODO: batched loss getting + loss_fn = lambda m: torch.mean( # noqa + torch.pow(m(input_data) - output_data, 2), + dim=dim, # noqa + ) + + best_scale = find_optimal_scales_with_loss(module, loss_fn, num_points) + module.weight_scale.data = best_scale + + # reset quantization settings to original values + for name, module in model.named_modules(): + if isinstance(module, QuantizedLinear): + module.weight_quantization, module.activation_quantization = ( + quantization_setting_mapping_dict[name] + ) + + +def set_activation_min_max( + model: nn.Module, data: torch.Tensor, batch_size: int +) -> None: + # store quantization settings so this algorithm does not change those implicitly + quantization_setting_mapping_dict = { + name: [module.weight_quantization, module.activation_quantization] + for name, module in model.named_modules() + if isinstance(module, QuantizedLinear) + } + + model.apply(all_activation_quantizers_off) + for name, module in model.named_modules(): + if isinstance(module, QuantizedLinear): + logger.info(f"Activation min/max setting for {name}") + input_data = None + if module.input_quantization: + input_data = get_module_input_data(model, data, module, batch_size) + assert module.input_quantizer is not None + module.input_quantizer.set_scale_offset_to_min_max(input_data) + if module.output_quantization: + if input_data is None: + input_data = get_module_input_data(model, data, module, batch_size) + output_data = module(input_data) + assert module.output_quantizer is not None + module.output_quantizer.set_scale_offset_to_min_max(output_data) + + # reset quantization settings to original values + for name, module in model.named_modules(): + if isinstance(module, QuantizedLinear): + module.weight_quantization, module.activation_quantization = ( + quantization_setting_mapping_dict[name] + ) + + +def find_optimal_scales_with_loss( + module: QuantizedLinear, + loss_fn: Callable[[nn.Module], torch.Tensor], + num_points: int, + max_shrink: float = 0, +) -> torch.Tensor: + assert max_shrink >= 0 and max_shrink < 1.0 + assert num_points > 0 + + with torch.no_grad(): + grid = torch.linspace(max_shrink, 1, num_points + 1) + module.set_weight_scale_to_min_max() + + orig_scales = module.weight_scale.clone() + best_scale = module.weight_scale.clone() + best_loss = loss_fn(module) + + for i in range(0, num_points - 1): + test_scale = orig_scales * grid[i] + module.weight_scale.data = test_scale + loss = loss_fn(module) + mask = loss < best_loss + best_loss[mask] = loss[mask] + best_scale[mask] = test_scale[mask] + return best_scale + + +def quantize_per_group_scales(model: nn.Module, bit_width: int) -> None: + for name, module in model.named_modules(): + if isinstance(module, QuantizedLinear): + scale = module.weight_quantizer.scale + assert isinstance(scale, torch.Tensor) + + if len(scale.shape) < 2 or scale.shape[-1] == 1: + logger.warning( + f"Module {name} is not quantized with group_wise quantization" + ) + continue + + per_channel_scales = torch.max(scale, dim=-1, keepdim=True).values + scale = scale / per_channel_scales # scale to [0, 1] + + # quantize the per_group scale to bit_width + quant_max = 2.0**bit_width + scale = ( + torch.clamp(torch.ceil(scale * quant_max), min=1.0, max=quant_max) + / quant_max + ) # ceil to make sure clipping error is certainly 0 + + scale = scale * per_channel_scales # fuse the fp16 scale + + module.weight_quantizer.scale.data = scale diff --git a/torchao/prototype/quantization/module_swap/utils.py b/torchao/prototype/quantization/module_swap/utils.py new file mode 100644 index 0000000000..b39dd9479a --- /dev/null +++ b/torchao/prototype/quantization/module_swap/utils.py @@ -0,0 +1,71 @@ +from typing import Dict + +import torch.nn as nn + +from torchao.prototype.quantization.module_swap.quantized_modules import QuantizedLinear +from torchao.prototype.quantization.module_swap.quantizers import IntQuantizer + + +def get_layer_by_name(model: nn.Module, query_name: str) -> nn.Module: + """ + Retrieves a layer from a PyTorch model by its name. + + Args: + model (nn.Module): The PyTorch model. + name (str): The name of the layer to retrieve. + + Returns: + nn.Module: The retrieved layer. + """ + for name, module in model.named_modules(): + if name == query_name: + return module + raise ValueError(f"Layer '{query_name}' not found in model") + + +def all_quantizers_off(module: nn.Module) -> None: + if isinstance(module, QuantizedLinear): + module.weight_quantization = False + module.activation_quantization = False + + +def all_quantizers_on(module: nn.Module) -> None: + if isinstance(module, QuantizedLinear): + module.weight_quantization = True + module.activation_quantization = True + + +def all_activation_quantizers_off(module: nn.Module) -> None: + if isinstance(module, QuantizedLinear): + module.activation_quantization = False + + +def all_activation_quantizers_on(module: nn.Module) -> None: + if isinstance(module, QuantizedLinear): + module.activation_quantization = True + + +def all_weight_quantizers_on(module: nn.Module) -> None: + if isinstance(module, QuantizedLinear): + module.weight_quantization = True + + +def set_bit_widths_by_name( + model: nn.Module, bit_width_dict: Dict[str, Dict[str, int]] +) -> None: + for name, bit_width_assignment in bit_width_dict.items(): + this_layer = get_layer_by_name(model, name) + for quantizer, bit_width in bit_width_assignment.items(): + assert isinstance(this_layer, QuantizedLinear) + if quantizer == "weight": + assert isinstance(this_layer.weight_quantizer, IntQuantizer) + this_layer.weight_quantizer.num_bits = bit_width + elif quantizer == "activation": + if this_layer.input_quantizer is not None: + this_layer.input_quantizer.num_bits = bit_width + if this_layer.output_quantizer is not None: + this_layer.output_quantizer.num_bits = bit_width + else: + raise ValueError( + f"Unknown quantizer {quantizer}, should be either 'weight' or 'activation'" + )