diff --git a/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py
new file mode 100644
index 0000000000..5872e91086
--- /dev/null
+++ b/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py
@@ -0,0 +1,164 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import tempfile
+import unittest
+
+import torch
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+
+from torchao import quantize_
+from torchao.quantization import PerGroup, PerRow, PerTensor
+from torchao.quantization.quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+)
+from torchao.quantization.utils import compute_error
+from torchao.utils import (
+    torch_version_at_least,
+)
+
+
+def get_config(granularity):
+    return Float8DynamicActivationFloat8WeightConfig(
+        activation_dtype=torch.float8_e4m3fn,
+        granularity=granularity,
+        float8_packing_format="opaque",
+    )
+
+
+class ToyLinearModel(torch.nn.Module):
+    def __init__(self, K=64, N=32, bias=False):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float)
+        self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float)
+
+    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
+        return (
+            torch.rand(
+                batch_size, self.linear1.in_features, dtype=dtype, device=device
+            )
+            * 0.1,
+        )
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
+class TestFloat8OpaqueTensor(TestCase):
+    """Test cases for Float8OpaqueTensor on CPU"""
+
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
+        reason="cpp kernels not built",
+    )
+    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
+    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
+    @common_utils.parametrize("x_dim", [2, 3])
+    @common_utils.parametrize("bias", [True, False])
+    @common_utils.parametrize("bs", [1, 160])
+    @common_utils.parametrize(
+        "x_granularity",
+        [PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
+    )
+    @common_utils.parametrize(
+        "w_granularity",
+        [PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
+    )
+    def test_dynamic_float8_linear(
+        self, dtype, x_dim, bias, bs, x_granularity, w_granularity
+    ):
+        if isinstance(x_granularity, PerGroup):
+            if not isinstance(w_granularity, PerGroup):
+                return
+            if w_granularity.group_size != x_granularity.group_size:
+                return
+        device = "cpu"
+        m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device)
+        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
+        if x_dim == 3:
+            example_inputs = (example_inputs[0].unsqueeze(0),)
+        y = m(*example_inputs)
+
+        with torch.no_grad():
+            quantize_(
+                m,
+                get_config([x_granularity, w_granularity]),
+            )
+            y1 = m(*example_inputs)
+            assert compute_error(y, y1) > 20
+            y2, code = torch._inductor.utils.run_and_get_code(
+                torch.compile(m, fullgraph=True, dynamic=True),
+                *example_inputs,
+            )
+            # ensure the expected op is in the generated code
+            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
+            assert compute_error(y, y2) > 20
+
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
+        reason="cpp kernels not built",
+    )
+    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
+    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
+    @common_utils.parametrize("x_dim", [2, 3])
+    @common_utils.parametrize("bias", [True, False])
+    @common_utils.parametrize("bs", [4, 128])
+    def test_dynamic_float8_linear_ref(self, dtype, x_dim, bias, bs):
+        device = "cpu"
+        # This shape is not supported by the cpp kernel, so the reference path is used.
+        m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device)
+        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
+        if x_dim == 3:
+            example_inputs = (example_inputs[0].unsqueeze(0),)
+        y = m(*example_inputs)
+
+        with torch.no_grad():
+            quantize_(
+                m,
+                get_config(PerRow()),
+            )
+            y1 = m(*example_inputs)
+            assert compute_error(y, y1) > 20
+            y2, code = torch._inductor.utils.run_and_get_code(
+                torch.compile(m, fullgraph=True, dynamic=True),
+                *example_inputs,
+            )
+            # ensure the expected op is in the generated code
+            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
+            assert compute_error(y, y2) > 20
+
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
+        reason="cpp kernels not built",
+    )
+    @common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+    def test_module_path(self, dtype):
+        linear = torch.nn.Linear(128, 256, dtype=dtype)
+        quantize_(linear, get_config(PerRow()))
+        self.assertEqual(
+            str(type(linear.weight)),
+            "<class 'torchao.quantization.Float8OpaqueTensor'>",
+        )
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Float8OpaqueTensor'>",
+            )
+
+
+common_utils.instantiate_parametrized_tests(TestFloat8OpaqueTensor)
+
+
+if __name__ == "__main__":
+    run_tests()
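Note: as a quick reference outside the test harness, the snippet below mirrors one parametrization of test_dynamic_float8_linear. It is only a sketch and assumes the torchao CPU kernels (torchao::float8_linear_cpu) are built; the string "opaque" is accepted because Float8PackingFormat is a str-valued Enum.

import torch
from torchao import quantize_
from torchao.quantization import PerGroup
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.utils import compute_error

# Two-layer toy model on CPU; feature sizes are multiples of 32/64 as the opaque path requires.
m = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 256)).eval()
x = torch.rand(16, 256) * 0.1
y_ref = m(x)

config = Float8DynamicActivationFloat8WeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    granularity=[PerGroup(64), PerGroup(64)],
    float8_packing_format="opaque",
)
with torch.no_grad():
    quantize_(m, config)
    y_q = torch.compile(m, fullgraph=True, dynamic=True)(x)

# An SQNR above ~20 dB is what the tests treat as acceptable accuracy.
print(compute_error(y_ref, y_q))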
diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py
index f15d38576c..eed63d4b03 100644
--- a/torchao/float8/inference.py
+++ b/torchao/float8/inference.py
@@ -14,6 +14,7 @@
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
 from torchao.float8.types import FP8Granularity
 from torchao.quantization.granularity import (
+    PerGroup,
     PerRow,
     PerTensor,
 )
@@ -204,28 +205,41 @@ def _normalize_granularity(
             list[FP8Granularity],
         ]
     ],
+    supported_granularities: tuple[FP8Granularity] = (PerTensor, PerRow),
+    support_different_granularities: bool = False,
 ) -> Tuple[FP8Granularity, FP8Granularity]:
     processed_granularity = None
     if granularity is None:
         processed_granularity = (PerTensor(), PerTensor())
-    elif isinstance(granularity, (PerTensor, PerRow)):
+    elif isinstance(granularity, supported_granularities):
         processed_granularity = (granularity, granularity)
     elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
         if not (
-            isinstance(granularity[0], (PerTensor, PerRow))
-            and isinstance(granularity[1], (PerTensor, PerRow))
+            isinstance(granularity[0], supported_granularities)
+            and isinstance(granularity[1], supported_granularities)
         ):
             raise ValueError(
-                f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
+                f"Invalid granularity types: {granularity}, only {supported_granularities} are supported."
             )
-        if not isinstance(granularity[0], type(granularity[1])):
+        if not support_different_granularities and not isinstance(
+            granularity[0], type(granularity[1])
+        ):
             raise ValueError(
-                f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
+                f"Different granularities for activation and weight are not supported: {granularity}, only {supported_granularities} are supported."
             )
+        if isinstance(granularity[0], PerGroup):
+            if not isinstance(granularity[1], PerGroup):
+                raise ValueError(
+                    "When granularity for activation is PerGroup, granularity for weight must be PerGroup, too."
+                )
+            if granularity[0].group_size != granularity[1].group_size:
+                raise ValueError(
+                    f"Group sizes for activation and weight must be the same, got {granularity[0].group_size} and {granularity[1].group_size}."
+                )
         processed_granularity = tuple(granularity)
     else:
         raise ValueError(
-            f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
+            f"Invalid granularity specification: {granularity}, only {supported_granularities} are supported."
         )
     return processed_granularity
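Note: to make the new validation rules concrete, the sketch below shows how _normalize_granularity behaves with the extended arguments. It is illustrative only: _normalize_granularity is a private helper, the supported tuple shown here is the one the opaque CPU path passes in, and the outputs follow from the logic in the hunk above.

from torchao.float8.inference import _normalize_granularity
from torchao.quantization import PerGroup, PerRow, PerTensor

supported = (PerTensor, PerRow, PerGroup)

# A single granularity is applied to both activation and weight.
print(_normalize_granularity(PerRow(), supported, True))

# Mixed granularities pass only when support_different_granularities=True.
print(_normalize_granularity((PerRow(), PerTensor()), supported, True))

# PerGroup must be used on both sides and the group sizes must match.
print(_normalize_granularity((PerGroup(64), PerGroup(64)), supported, True))
try:
    _normalize_granularity((PerGroup(32), PerGroup(64)), supported, True)
except ValueError as e:
    print(e)  # group sizes for activation and weight must be the same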
diff --git a/torchao/float8/types.py b/torchao/float8/types.py
index b332a9629a..63cabc9582 100644
--- a/torchao/float8/types.py
+++ b/torchao/float8/types.py
@@ -12,8 +12,8 @@
 from typing import TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
-    from torchao.quantization.granularity import PerRow, PerTensor
+    from torchao.quantization.granularity import PerGroup, PerRow, PerTensor
 
 # Define FP8Granularity type alias to break circular import dependencies
-FP8Granularity = Union["PerTensor", "PerRow"]
+FP8Granularity = Union["PerTensor", "PerRow", "PerGroup"]
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index c8774e9426..a334cd80ac 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -90,6 +90,7 @@
     quantize_affine,
 )
 from .quantize_.workflows import (
+    Float8OpaqueTensor,
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
@@ -172,6 +173,7 @@
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "Int4OpaqueTensor",
+    "Float8OpaqueTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 15caddcadc..b5abb331e5 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -69,6 +69,8 @@
     KernelPreference,
 )
 from torchao.quantization.quantize_.workflows import (
+    Float8OpaqueTensor,
+    Float8PackingFormat,
     Float8Tensor,
     Int4ChooseQParamsAlgorithm,
     Int4MarlinSparseTensor,
@@ -93,6 +95,7 @@
 )
 from torchao.utils import (
     _ConfigDeprecationWrapper,
+    check_cpu_version,
     is_MI300,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -1722,6 +1725,22 @@ def _input_activation_quant_func_fp8(
     return activation
 
 
+def _input_activation_quant_cpu_fp8(
+    x: torch.Tensor,
+    activation_granularity: FP8Granularity,
+    activation_dtype: torch.dtype,
+):
+    """Dynamically quantize activation to fp8 for CPU."""
+    block_size = get_block_size(x.shape, activation_granularity)
+    return to_affine_quantized_floatx(
+        input_float=x,
+        block_size=block_size,
+        target_dtype=activation_dtype,
+        scale_dtype=torch.float32,
+        _layout=PlainLayout(),
+    )
+
+
 def _fp8_mm_compat(weight: torch.Tensor) -> bool:
     """
     Check if a weight tensor meets float8 quantization requirements.
@@ -1780,6 +1799,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     kernel_preference: KernelPreference = KernelPreference.AUTO
     set_inductor_config: bool = True
     version: int = 2
+    float8_packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN
 
     def __post_init__(self):
         torch._C._log_api_usage_once(
@@ -1787,8 +1807,16 @@ def __post_init__(self):
         )
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)
+        supported_granularities = ()
+        if self.float8_packing_format == Float8PackingFormat.PLAIN:
+            supported_granularities = (PerTensor, PerRow)
+        elif self.float8_packing_format == Float8PackingFormat.OPAQUE:
+            supported_granularities = (PerTensor, PerRow, PerGroup)
+        support_different_granularities = (
+            self.float8_packing_format == Float8PackingFormat.OPAQUE
+        )
         activation_granularity, weight_granularity = _normalize_granularity(
-            self.granularity
+            self.granularity, supported_granularities, support_different_granularities
         )
         self.granularity = [activation_granularity, weight_granularity]
 
@@ -1807,17 +1835,12 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_value_lb = config.activation_value_lb
     activation_value_ub = config.activation_value_ub
     kernel_preference = config.kernel_preference
+    float8_packing_format = config.float8_packing_format
 
     # Ensure works on device
-    _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity
 
-    if not _fp8_mm_compat(weight):
-        # TODO(future PR): this should really throw an exception instead of silently
-        # not doing what the user asked
-        return weight
-
-    if isinstance(weight_granularity, PerRow):
+    if weight.device.type != "cpu" and isinstance(weight_granularity, PerRow):
         assert weight.dtype == torch.bfloat16, (
             "PerRow quantization only works for bfloat16 precision input weight"
         )
@@ -1827,6 +1850,12 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
             "Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
         )
 
+        _check_hardware_support(granularity)
+        if not _fp8_mm_compat(weight):
+            # TODO(future PR): this should really throw an exception instead of silently
+            # not doing what the user asked
+            return weight
+
         block_size = get_block_size(weight.shape[-2:], weight_granularity)
         if weight.dim() == 3:
             block_size = tuple([1] + list(block_size))
@@ -1857,14 +1886,26 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         kernel_preference=kernel_preference,
     )
 
-    quantized_weight = Float8Tensor.from_hp(
-        weight,
-        float8_dtype=weight_dtype,
-        granularity=weight_granularity,
-        mm_config=mm_config,
-        kernel_preference=kernel_preference,
-        act_quant_kwargs=act_quant_kwargs,
-    )
+    if float8_packing_format == Float8PackingFormat.PLAIN:
+        quantized_weight = Float8Tensor.from_hp(
+            weight,
+            float8_dtype=weight_dtype,
+            granularity=weight_granularity,
+            mm_config=mm_config,
+            kernel_preference=kernel_preference,
+            act_quant_kwargs=act_quant_kwargs,
+        )
+    elif float8_packing_format == Float8PackingFormat.OPAQUE:
+        block_size = get_block_size(weight.shape, weight_granularity)
+        quantized_weight = Float8OpaqueTensor.from_hp(
+            weight,
+            block_size=block_size,
+            act_quant_kwargs=act_quant_kwargs,
+        )
+    else:
+        raise ValueError(
+            f"Unsupported float8 packing format: {float8_packing_format}"
+        )
 
     return quantized_weight
@@ -1873,16 +1914,19 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
 def _float8_dynamic_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig
 ):
-    assert is_sm_at_least_89() or is_MI300(), (
-        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-    )
-    if config.set_inductor_config:
-        torchao.quantization.utils.recommended_inductor_config_setter()
-
     assert hasattr(module, "weight"), (
         "applying float8 dynamic activation quant requires module to have weight attribute"
+        + f" but {module} does not have one"
     )
+    assert (
+        check_cpu_version(module.weight.device, "2.6.0")
+        or is_sm_at_least_89()
+        or is_MI300()
+    ), (
+        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+ or on CPU with PyTorch >= 2.6.0"
+    )
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
     quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor(
         module.weight, config
     )
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index 229c94c73a..2c033c6425 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -1,3 +1,7 @@
+from .float8.float8_opaque_tensor import (
+    Float8OpaqueTensor,
+)
+from .float8.float8_packing_format import Float8PackingFormat
 from .float8.float8_tensor import (
     Float8Tensor,
     QuantizeTensorToFloat8Kwargs,
@@ -36,7 +40,9 @@
     "Int4MarlinSparseTensor",
     "Int4PlainInt32Tensor",
     "Int4TilePackedTo4dTensor",
+    "Float8OpaqueTensor",
     "Float8Tensor",
+    "Float8PackingFormat",
     "QuantizeTensorToFloat8Kwargs",
     "Int4OpaqueTensor",
     "Int4ChooseQParamsAlgorithm",
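Note: for orientation, get_block_size is what turns a granularity into the block_size consumed by Float8OpaqueTensor.from_hp in the opaque branch above. The expected mapping for a 2D weight is sketched below; the printed values are assumed from the granularity semantics rather than quoted from a specific torchao version.

import torch
from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.observer import get_block_size

w = torch.randn(256, 512)  # [N, K] weight

print(get_block_size(w.shape, PerTensor()))    # (256, 512): one scale for the whole tensor
print(get_block_size(w.shape, PerRow()))       # (1, 512): one scale per output row
print(get_block_size(w.shape, PerGroup(128)))  # (1, 128): one scale per group of 128 along K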
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py
new file mode 100644
index 0000000000..39d949f02a
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py
@@ -0,0 +1,227 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List, Optional
+
+import torch
+
+from torchao.quantization.granularity import (
+    PerGroup,
+)
+from torchao.quantization.observer import get_block_size
+from torchao.quantization.quant_primitives import (
+    _choose_scale_float8,
+    _quantize_affine_float8,
+)
+from torchao.utils import (
+    TorchAOBaseTensor,
+)
+
+from .float8_tensor import QuantizeTensorToFloat8Kwargs
+
+__all__ = [
+    "Float8OpaqueTensor",
+]
+
+aten = torch.ops.aten
+
+
+class Float8OpaqueTensor(TorchAOBaseTensor):
+    """
+    Float8 dynamic activation and float8 weight on CPU. The weight tensor is reordered from [N, K]
+    to a blocked layout [N/block_n, K/block_k, block_k, block_n] for better memory locality, where
+    block_n = 32 and block_k depends on the quantization group size (32/64/128). The innermost
+    block of shape [block_k, block_n] may be further reordered to a VNNI layout depending on the
+    supported CPU ISA.
+
+    Tensor Attributes:
+        qdata: Reordered float8 weight on CPU with shape = [N/block_n, K/block_k, block_k, block_n].
+        scale: Scale tensor for the weight, dtype = float32. For per-group/per-row quantization,
+            shape = [N / block_n, num_groups, block_n]. For per-tensor quantization, shape = [1].
+
+    Non-Tensor Attributes:
+        block_size: the block size for quantization, representing the granularity. For group-wise
+            quantization, block_size is (1, group_size); only group_size = 32/64/128 is supported.
+            For per-row quantization, block_size is (1, K). For per-tensor quantization,
+            block_size is (N, K).
+        shape: shape of the original Tensor
+        act_quant_kwargs: the kwargs used to dynamically quantize the activation to float8 at
+            runtime (see QuantizeTensorToFloat8Kwargs)
+    """
+
+    tensor_data_names = ["qdata", "scale"]
+    tensor_attribute_names = ["block_size", "act_quant_kwargs"]
+
+    def __new__(
+        cls,
+        qdata,
+        scale,
+        block_size,
+        act_quant_kwargs,
+    ):
+        if qdata.ndim == 2:
+            shape = qdata.shape
+        else:
+            assert qdata.ndim == 4
+            shape = torch.Size(
+                [qdata.size(0) * qdata.size(3), qdata.size(1) * qdata.size(2)]
+            )
+        kwargs = {}
+        kwargs["device"] = qdata.device
+        kwargs["dtype"] = scale.dtype
+        kwargs["requires_grad"] = False
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        qdata: torch.Tensor,
+        scale: torch.Tensor,
+        block_size: List[int],
+        act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
+    ):
+        self.qdata = qdata
+        self.scale = scale
+        self.block_size = block_size
+        self.act_quant_kwargs = act_quant_kwargs
+
+    def _quantization_type(self):
+        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, {self.act_quant_kwargs=}"
+
+    @classmethod
+    def from_hp(
+        cls,
+        hp_tensor: torch.Tensor,
+        block_size: List[int],
+        act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
+    ):
+        assert hp_tensor.ndim == 2 and hp_tensor.device.type == "cpu", (
+            f"Expecting 2D tensor on CPU, but got: {hp_tensor.shape} on {hp_tensor.device.type}"
+        )
+        assert len(block_size) == hp_tensor.ndim
+        N = hp_tensor.size(0)
+        K = hp_tensor.size(-1)
+        assert (block_size[0] == 1 or block_size[0] == N) and block_size[1] in (
+            32,
+            64,
+            128,
+            K,
+        ), f"Unsupported block_size: {block_size} for tensor shape {hp_tensor.shape}"
+        assert act_quant_kwargs is not None, (
+            "Activation quantization args must be provided for Float8OpaqueTensor"
+        )
+        act_per_group_quant = isinstance(act_quant_kwargs.granularity, PerGroup)
+        wei_per_group_quant = block_size[1] < K
+        if act_per_group_quant:
+            group_size = act_quant_kwargs.granularity.group_size
+            if wei_per_group_quant:
+                # weight_tensor is also per-group quantized
+                assert block_size[1] == group_size, (
+                    "input and weight should have the same group size but got"
+                    f" {block_size[1]} and {group_size}"
+                )
+        if act_per_group_quant or wei_per_group_quant:
+            assert N % 32 == 0, (
+                f"Expecting out_features {N} to be multiple of 32, but got {N}"
+            )
+            assert K % block_size[1] == 0, (
+                f"Expecting in_features {K} to be multiple of group_size {block_size[1]}, but got {K}"
+            )
+        scale = _choose_scale_float8(
+            hp_tensor,
+            float8_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+        )
+        data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
+        # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n].
+        # Pack scales from [N, num_groups] to [N / block_n, num_groups, block_n].
+        packed_weight, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(
+            data, scale
+        )
+
+        return Float8OpaqueTensor(
+            qdata=packed_weight,
+            scale=packed_scale,
+            block_size=block_size,
+            act_quant_kwargs=act_quant_kwargs,
+        )
+
+
+implements = Float8OpaqueTensor.implements
+
+
+@implements([torch.nn.functional.linear, aten.linear.default])
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    assert input_tensor.device.type == "cpu", (
+        f"For CPU device only but got: {input_tensor.device}"
+    )
+    assert isinstance(weight_tensor, Float8OpaqueTensor), (
+        f"Expected weight_tensor to be Float8OpaqueTensor, got: {type(weight_tensor)}"
+    )
+    assert weight_tensor.ndim in [2, 4]
+    assert input_tensor.size(-1) == weight_tensor.size(-1), (
+        f"Shapes of input and weight do not match, input: {input_tensor.shape}, weight: {weight_tensor.shape}"
+    )
+
+    act_mat = input_tensor.contiguous()
+    packed_weight = weight_tensor.qdata
+    scale = weight_tensor.scale
+
+    orig_act_size = act_mat.size()
+    orig_dtype = act_mat.dtype
+    # reshape to 2D
+    act_mat = act_mat.reshape(-1, act_mat.shape[-1])
+
+    # activation float8 quantization
+    if (
+        weight_tensor.act_quant_kwargs is not None
+        and weight_tensor.act_quant_kwargs.granularity is not None
+    ):
+        granularity = weight_tensor.act_quant_kwargs.granularity
+        if isinstance(granularity, PerGroup):
+            group_size = granularity.group_size
+            if weight_tensor.block_size[1] < weight_tensor.size(-1):
+                # weight_tensor is also per-group quantized
+                assert weight_tensor.block_size[1] == group_size, (
+                    "input and weight should have the same group size but got"
+                    f" {weight_tensor.block_size[1]} and {group_size}"
+                )
+        act_block_size = get_block_size(act_mat.shape, granularity)
+        act_scale = _choose_scale_float8(
+            act_mat,
+            float8_dtype=torch.float8_e4m3fn,
+            block_size=act_block_size,
+        )
+        act_mat = _quantize_affine_float8(act_mat, act_scale, torch.float8_e4m3fn)
+    else:
+        raise NotImplementedError(
+            "Activation quantization args not provided for Float8OpaqueTensor"
+        )
+
+    # float8 quantized linear operation
+    y = torch.ops.torchao.float8_linear_cpu.default(
+        act_mat,
+        act_scale,
+        packed_weight,
+        scale,
+        bias.float() if bias is not None else bias,  # requires bias to be float
+        torch.float,  # out_dtype
+    )
+
+    # remove out_feature padding
+    orig_out_features = weight_tensor.shape[-2]
+    y = y[:, :orig_out_features]
+    y = y.reshape(*orig_act_size[:-1], orig_out_features)
+
+    return y.to(orig_dtype)
+
+
+Float8OpaqueTensor.__module__ = "torchao.quantization"
+
+# Allow a model with Float8OpaqueTensor weights to be loaded with `weights_only=True`
+torch.serialization.add_safe_globals([Float8OpaqueTensor])
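Note: the actual packing is performed by the torchao::float8_linear_prepack_cpu C++ op, which may also pad N/K and apply a VNNI reordering of the innermost [block_k, block_n] tile. The pure-PyTorch sketch below (a hypothetical helper, plain layout, no VNNI) only illustrates the blocked layout described in the class docstring.

import torch


def pack_weight_reference(w: torch.Tensor, block_n: int = 32, block_k: int = 64) -> torch.Tensor:
    # Reorder a [N, K] tensor into [N/block_n, K/block_k, block_k, block_n].
    N, K = w.shape
    assert N % block_n == 0 and K % block_k == 0
    w = w.reshape(N // block_n, block_n, K // block_k, block_k)
    return w.permute(0, 2, 3, 1).contiguous()


w = torch.randn(256, 512)  # stands in for the already-quantized fp8 weight data
packed = pack_weight_reference(w)
print(packed.shape)  # torch.Size([8, 8, 64, 32])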
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py b/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py
new file mode 100644
index 0000000000..04ae64241c
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+
+
+# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum)
+# after python 3.10 is end of life (https://devguide.python.org/versions/)
+class Float8PackingFormat(str, Enum):
+    """Packing format for quantized data in Float8 Tensor subclasses in torchao, representing how
+    the values in the quantized data are packed and laid out in memory.
+    """
+
+    """
+    plain means the quantized Tensor data lays out its elements sequentially,
+    for example, for a Tensor of shape (4, 6):
+    a_0_0, a_0_1, ..., a_0_5,
+    ...
+    a_3_0, a_3_1, ..., a_3_5
+    """
+    PLAIN = "plain"
+
+    """
+    Opaque packing format, used for tensors that do not have a predefined packing format
+    (it may depend on hardware, tensor shape, library availability, etc.) and whose specific
+    layout does not need to be understood by the rest of the system.
+    """
+    OPAQUE = "opaque"
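Note: putting the pieces together, a minimal end-to-end sketch of the opaque CPU path, mirroring test_module_path above. It assumes the torchao CPU kernels are available; passing the string "opaque" works because Float8PackingFormat is a str-valued Enum.

import tempfile

import torch
from torchao import quantize_
from torchao.quantization import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

linear = torch.nn.Linear(128, 256)
quantize_(
    linear,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
        float8_packing_format="opaque",
    ),
)
print(type(linear.weight))  # Float8OpaqueTensor

# Float8OpaqueTensor is registered via add_safe_globals, so weights_only loading works.
with tempfile.NamedTemporaryFile() as f:
    torch.save(linear.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f, weights_only=True)
print(type(state_dict["weight"]))  # Float8OpaqueTensor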