Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao import quantize_
from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
torch_version_at_least,
)


def get_config(granularity):
return Float8DynamicActivationFloat8WeightConfig(
activation_dtype=torch.float8_e4m3fn,
granularity=granularity,
float8_packing_format="opaque",
)


class ToyLinearModel(torch.nn.Module):
def __init__(self, K=64, N=32, bias=False):
super().__init__()
self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float)
self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float)

def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
return (
torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device)
* 0.1,
)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x


class TestFloat8OpaqueTensor(TestCase):
"""Test cases for Float8OpaqueTensor on CPU"""

@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
reason="cpp kernels not built",
)
@unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
@common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@common_utils.parametrize("x_dim", [2, 3])
@common_utils.parametrize("bias", [True, False])
@common_utils.parametrize("bs", [1, 160])
@common_utils.parametrize(
"x_granularity",
[PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
)
@common_utils.parametrize(
"w_granularity",
[PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
)
def test_dynamic_float8_linear(
self, dtype, x_dim, bias, bs, x_granularity, w_granularity
):
if isinstance(x_granularity, PerGroup):
if not isinstance(w_granularity, PerGroup):
return
if w_granularity.group_size != x_granularity.group_size:
return
device = "cpu"
m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device)
example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
if x_dim == 3:
example_inputs = (example_inputs[0].unsqueeze(0),)
y = m(*example_inputs)

with torch.no_grad():
quantize_(
m,
get_config([x_granularity, w_granularity]),
)
y1 = m(*example_inputs)
assert compute_error(y, y1) > 20
y2, code = torch._inductor.utils.run_and_get_code(
torch.compile(m, fullgraph=True, dynamic=True),
*example_inputs,
)
# ensure the expected op is in the code
assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
assert compute_error(y, y2) > 20

@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
reason="cpp kernels not built",
)
@unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
@common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@common_utils.parametrize("x_dim", [2, 3])
@common_utils.parametrize("bias", [True, False])
@common_utils.parametrize("bs", [4, 128])
def test_dynamic_float8_linear_ref(self, dtype, x_dim, bias, bs):
device = "cpu"
# the shape is not supported by cpp kernel, so the ref path will be used.
m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device)
example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
if x_dim == 3:
example_inputs = (example_inputs[0].unsqueeze(0),)
y = m(*example_inputs)

with torch.no_grad():
quantize_(
m,
get_config(PerRow()),
)
y1 = m(*example_inputs)
assert compute_error(y, y1) > 20
y2, code = torch._inductor.utils.run_and_get_code(
torch.compile(m, fullgraph=True, dynamic=True),
*example_inputs,
)
# ensure the expected op is in the code
assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
assert compute_error(y, y2) > 20

@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
reason="cpp kernels not built",
)
@common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
def test_module_path(self, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype)
quantize_(linear, get_config(PerRow()))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Float8OpaqueTensor'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Float8OpaqueTensor'>",
)


common_utils.instantiate_parametrized_tests(TestFloat8OpaqueTensor)


if __name__ == "__main__":
run_tests()
28 changes: 21 additions & 7 deletions torchao/float8/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
from torchao.float8.types import FP8Granularity
from torchao.quantization.granularity import (
PerGroup,
PerRow,
PerTensor,
)
Expand Down Expand Up @@ -204,28 +205,41 @@ def _normalize_granularity(
list[FP8Granularity],
]
],
supported_granularities: tuple[FP8Granularity] = (PerTensor, PerRow),
support_different_granularities: bool = False,
) -> Tuple[FP8Granularity, FP8Granularity]:
processed_granularity = None
if granularity is None:
processed_granularity = (PerTensor(), PerTensor())
elif isinstance(granularity, (PerTensor, PerRow)):
elif isinstance(granularity, supported_granularities):
processed_granularity = (granularity, granularity)
elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
if not (
isinstance(granularity[0], (PerTensor, PerRow))
and isinstance(granularity[1], (PerTensor, PerRow))
isinstance(granularity[0], supported_granularities)
and isinstance(granularity[1], supported_granularities)
):
raise ValueError(
f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
f"Invalid granularity types: {granularity}, only {supported_granularities} are supported."
)
if not isinstance(granularity[0], type(granularity[1])):
if not support_different_granularities and not isinstance(
granularity[0], type(granularity[1])
):
raise ValueError(
f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
f"Different granularities for activation and weight are not supported: {granularity}, only {supported_granularities} are supported."
)
if isinstance(granularity[0], PerGroup):
if not isinstance(granularity[1], PerGroup):
raise ValueError(
"When granularity for activation is PerGroup, granularity for weight must be PerGroup, too."
)
if granularity[0].group_size != granularity[1].group_size:
raise ValueError(
f"Group sizes for activation and weight must be the same, got {granularity[0].group_size} and {granularity[1].group_size}."
)
processed_granularity = tuple(granularity)
else:
raise ValueError(
f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
f"Invalid granularity specification: {granularity}, only {supported_granularities} are supported."
)
return processed_granularity

Expand Down
4 changes: 2 additions & 2 deletions torchao/float8/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.granularity import PerGroup, PerRow, PerTensor


# Define FP8Granularity type alias to break circular import dependencies
FP8Granularity = Union["PerTensor", "PerRow"]
FP8Granularity = Union["PerTensor", "PerRow", "PerGroup"]
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
quantize_affine,
)
from .quantize_.workflows import (
Float8OpaqueTensor,
Float8Tensor,
Int4MarlinSparseTensor,
Int4OpaqueTensor,
Expand Down Expand Up @@ -172,6 +173,7 @@
"Int4TilePackedTo4dTensor",
"Float8Tensor",
"Int4OpaqueTensor",
"Float8OpaqueTensor",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
Expand Down
Loading
Loading