Merged
14 changes: 14 additions & 0 deletions test/integration/test_integration.py
@@ -943,6 +943,20 @@ def test_weight_only_groupwise_quant(self):
sqnr = compute_error(y_ref, y_wo)
self.assertGreater(sqnr, 45.0)

def test_weight_only_groupwise_embedding_quant(self):
group_size = 64
m = nn.Embedding(4096, 128)
input = torch.randint(0, 4096, (1, 6))

quantize_(m, int8_weight_only(group_size=group_size), filter_fn=lambda x, *args: isinstance(x, nn.Embedding))
y_q = m(input)
y_ref = m.weight.dequantize()[input]

sqnr = compute_error(y_ref, y_q)

self.assertGreater(sqnr, 45.0)


@parameterized.expand(COMMON_DEVICE_DTYPE)
@torch.no_grad()
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
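For readers who want to try the new path outside the test harness, here is a minimal sketch using the same quantize_, int8_weight_only, and compute_error APIs the test exercises. The import paths are assumptions based on torchao's public API, and the model sizes are illustrative.

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only
from torchao.quantization.utils import compute_error  # assumed location, as in the test module

m = nn.Embedding(4096, 128)           # vocab of 4096, 128-dim embeddings
idx = torch.randint(0, 4096, (1, 6))  # a batch of token ids

# quantize only the nn.Embedding modules; everything else is left untouched
quantize_(m, int8_weight_only(group_size=64),
          filter_fn=lambda mod, *args: isinstance(mod, nn.Embedding))

y_q = m(idx)                        # lookup through the quantized weight
y_ref = m.weight.dequantize()[idx]  # reference: dequantize, then index
print(compute_error(y_ref, y_q))    # SQNR in dB; the test asserts > 45.0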
4 changes: 3 additions & 1 deletion torchao/_models/llama/generate.py
@@ -237,6 +237,8 @@ def main(
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
if "fp6" in quantization:
quantize_(model, fpx_weight_only(3, 2))
if "embed-int8wo" in quantization:
quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding))
if quantization.startswith("awq"):
from torchao._models._eval import TransformerEvalWrapper
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
@@ -463,7 +465,7 @@ def callback(x):
parser.add_argument('-q', '--quantization', type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant'
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, embed-int8wo'
)
)
parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
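With the flag wired into the quantization dispatch, the new path can be selected like any other option. A hypothetical invocation follows; the -q/--quantization flag is shown in the diff above, while --checkpoint_path and its value are assumptions based on the script's other arguments.

# quantize the model's nn.Embedding weights to int8 with group size 64 before generation
python torchao/_models/llama/generate.py -q embed-int8wo --checkpoint_path checkpoints/model.pth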
26 changes: 26 additions & 0 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -1840,6 +1840,32 @@ def _(func, types, args, kwargs):
weight_tensor = weight_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@implements(torch.nn.functional.embedding)
def _(func, types, args, kwargs):
# a simpler but slower fallback: dequantize the full weight, then
# return torch.nn.functional.embedding(args[0], args[1].dequantize(), *args[2:], **kwargs)
assert isinstance(args[1].tensor_impl, PlainAQTTensorImpl), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}"
assert kwargs["padding_idx"] is None and kwargs["max_norm"] is None and not kwargs["scale_grad_by_freq"] and not kwargs["sparse"] and kwargs["norm_type"]==2.0
idx = args[0]
int_data, scale, zero_point = args[1].tensor_impl.get_plain()

sliced_data, sliced_scale, sliced_zero_point = int_data[idx], scale[idx], zero_point[idx]
Contributor: Are there any restrictions on idx for this to be valid?

Contributor Author: Not as far as our tests show.

# block_size is expected to have 2 dimensions, [1, group_size], but indexing with
# idx prepends idx's batch/other dims to sliced_data, sliced_scale and
# sliced_zero_point, so we pad block_size with leading 1s to match.
new_blocks = idx.dim()-1
return dequantize_affine(
sliced_data,
new_blocks*[1]+list(args[1].block_size),
sliced_scale,
sliced_zero_point,
sliced_data.dtype,
args[1].quant_min,
args[1].quant_max,
args[1].zero_point_domain,
output_dtype=sliced_scale.dtype,
)
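To make the dimension bookkeeping concrete, here is a small sketch of the shape arithmetic that the leading-1s padding performs. It uses plain tensors rather than torchao types, and all names are illustrative.

import torch

idx = torch.randint(0, 4096, (2, 6))                 # 2 sequences of 6 token ids
int_data = torch.zeros(4096, 128, dtype=torch.int8)  # stand-in for the quantized weight
sliced_data = int_data[idx]             # shape (2, 6, 128): idx's dims are prepended
block_size = [1, 64]                    # the 2-D group-wise block size, [1, group_size]
new_blocks = idx.dim() - 1              # one extra leading dim in this case
padded = new_blocks * [1] + block_size  # [1, 1, 64]: one entry per dim of sliced_data
assert len(padded) == sliced_data.dim()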

@implements(aten.addmm.default)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (