186 changes: 186 additions & 0 deletions torchao/experimental/tests/test_embedding_xbit_quantizer.py
@@ -9,6 +9,7 @@
import unittest

import torch
from parameterized import param, parameterized
from torch.testing import FileCheck

from torchao.dtypes import (
@@ -19,8 +20,15 @@
SharedEmbeddingQuantizer,
)
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.qat import (
FakeQuantizeConfig,
FromIntXQuantizationAwareTrainingConfig,
Int4WeightOnlyEmbeddingQATQuantizer,
IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
MappingType,
quantize_,
)
@@ -184,6 +192,184 @@ def test_shared_embedding(self):
exported_program.graph_module.code
)

@parameterized.expand(
[
param(
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
model_dtype=model_dtype,
)
for weight_dtype in [getattr(torch, f"int{x}") for x in range(1, 9)]
for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)]
for mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
for model_dtype in [torch.float32]
],
name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
)
def test_identical_to_IntxWeightOnlyConfig(
self, weight_dtype, granularity, mapping_type, model_dtype
):
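# EmbeddingQuantizer should produce the same output as quantize_ applied with an
# equivalent IntxWeightOnlyConfig (scale_dtype=None) on the embedding layers.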
embedding_dim = 4096
num_embeddings = 131
model = torch.nn.Sequential(
*[torch.nn.Embedding(num_embeddings, embedding_dim)]
)
model = model.to(model_dtype)
indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)

quantized_model = copy.deepcopy(model)
quantizer = EmbeddingQuantizer(
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
)
quantized_model = quantizer.quantize(quantized_model)
actual_result = quantized_model(indices)

reference_model = copy.deepcopy(model)
quantize_(
reference_model,
IntxWeightOnlyConfig(
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
scale_dtype=None,
),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
)
expected_result = reference_model(indices)
self.assertTrue(torch.allclose(actual_result, expected_result))

@parameterized.expand(
[
param(
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
scale_dtype=scale_dtype,
model_dtype=model_dtype,
)
for weight_dtype in [getattr(torch, f"int{x}") for x in range(1, 9)]
for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)]
for mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
for model_dtype in [torch.float32, torch.bfloat16]
],
name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
)
def test_identical_to_IntXQuantizationAwareTrainingConfig(
self, weight_dtype, granularity, mapping_type, scale_dtype, model_dtype
):
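# The fake-quantized QAT output should match the PTQ path: convert out of QAT with
# FromIntXQuantizationAwareTrainingConfig, then quantize with IntxWeightOnlyConfig.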
# ASYMMETRIC in QAT is very different from the PTQ configs
if mapping_type == MappingType.ASYMMETRIC:
return

embedding_dim = 4096
num_embeddings = 131
model = torch.nn.Sequential(
*[torch.nn.Embedding(num_embeddings, embedding_dim)]
)
model = model.to(model_dtype)
indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)

is_symmetric = mapping_type == MappingType.SYMMETRIC
group_size = (
granularity.group_size
if isinstance(granularity, PerGroup)
else embedding_dim
)

embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding)
weight_config = FakeQuantizeConfig(
weight_dtype,
group_size=group_size,
is_symmetric=is_symmetric,
scale_precision=scale_dtype,
)
quantize_(
model,
IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
embedding_filter,
)
expected_out = model(indices)

quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
scale_dtype=scale_dtype,
),
embedding_filter,
)
actual_out = model(indices)
self.assertTrue(torch.allclose(expected_out, actual_out))

@parameterized.expand(
[
param(
granularity=granularity,
scale_dtype=scale_dtype,
model_dtype=model_dtype,
)
for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)]
for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
for model_dtype in [torch.float32, torch.bfloat16]
],
name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
)
def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer(
self, granularity, scale_dtype, model_dtype
):
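# Output of the QAT-prepared model should match converting out of QAT and then
# quantizing with an equivalent int4 symmetric IntxWeightOnlyConfig.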
embedding_dim = 4096
num_embeddings = 131
model = torch.nn.Sequential(
*[torch.nn.Embedding(num_embeddings, embedding_dim)]
)
model = model.to(model_dtype)
indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)

group_size = (
granularity.group_size
if isinstance(granularity, PerGroup)
else embedding_dim
)

embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding)

qat_quantizer = Int4WeightOnlyEmbeddingQATQuantizer(
group_size=group_size,
scale_precision=scale_dtype,
zero_point_precision=torch.int32,
)
model = qat_quantizer.prepare(model)
expected_out = model(indices)

# Convert model method 1
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=granularity,
mapping_type=MappingType.SYMMETRIC,
scale_dtype=scale_dtype,
),
embedding_filter,
)
actual_out1 = model(indices)
self.assertTrue(torch.allclose(expected_out, actual_out1))

# TODO: method 2 does not work because the converted embedding op
# incorrectly casts its output to indices.dtype
# Convert model method 2
# qat_quantizer.convert(prepared_model_copy)
# actual_out2 = prepared_model_copy(indices)
# self.assertTrue(torch.allclose(expected_out, actual_out2))


if __name__ == "__main__":
unittest.main()