diff --git a/docs/source/api_ref_qat.rst b/docs/source/api_ref_qat.rst index 0179af2f3d..e0cacab667 100644 --- a/docs/source/api_ref_qat.rst +++ b/docs/source/api_ref_qat.rst @@ -26,10 +26,12 @@ Custom QAT APIs FakeQuantizeConfigBase IntxFakeQuantizeConfig + Float8FakeQuantizeConfig FakeQuantizedLinear FakeQuantizedEmbedding FakeQuantizerBase IntxFakeQuantizer + Float8FakeQuantizer linear.enable_linear_fake_quant linear.disable_linear_fake_quant diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index bb613d2c99..489ae2758b 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -14,17 +14,22 @@ import torch import torch.nn.functional as F -from parameterized import parameterized from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, +) from torchao import quantize_ -from torchao.float8.config import ScalingGranularity -from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic -from torchao.float8.float8_training_tensor import LinearMMConfig +from torchao.core.config import AOBaseConfig +from torchao.quantization import Float8Tensor from torchao.quantization.granularity import ( + Granularity, PerAxis, PerGroup, PerRow, + PerTensor, PerToken, ) from torchao.quantization.linear_quant_modules import ( @@ -43,11 +48,12 @@ FakeQuantizedEmbedding, ) from torchao.quantization.qat.fake_quantize_config import ( + Float8FakeQuantizeConfig, IntxFakeQuantizeConfig, ) from torchao.quantization.qat.fake_quantizer import ( + Float8FakeQuantizer, IntxFakeQuantizer, - _Float8RowwiseActivationFakeQuantizer, ) from torchao.quantization.qat.linear import ( FakeQuantizedLinear, @@ -58,10 +64,11 @@ from torchao.quantization.qat.utils import ( _fake_quantize_per_channel_group, _fake_quantize_per_token, - _Float8RowwiseFakeQuantize, _get_qmin_qmax, ) from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt4WeightConfig, ) from torchao.quantization.quant_primitives import ( @@ -83,6 +90,10 @@ get_groupwise_affine_qparams, groupwise_affine_quantize_tensor, ) +from torchao.utils import ( + _is_fbgemm_genai_gpu_available, + is_sm_at_least_89, +) # TODO: put this in a common test utils file _CUDA_IS_AVAILABLE = torch.cuda.is_available() @@ -193,7 +204,7 @@ def forward(self, x): return x -class TestQAT(unittest.TestCase): +class TestQAT(TestCase): SEED = 123 def test_fake_quantize_per_channel_group(self): @@ -1420,7 +1431,7 @@ def test_qat_linear_bias(self): example_inputs = m.example_inputs() m(*example_inputs) - @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)]) + @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype): """ Test that the following produce the exact same numerics: @@ -1437,7 +1448,7 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype): baseline_out = per_token_dynamic_quant(x) torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0) - @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)]) + @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype): """ Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces @@ -1548,7 
+1559,7 @@ def test_qat_8da4w_eps(self): actual_out = converted_model.linear1(x) torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0) - @parameterized.expand([(True,), (False,)]) + @parametrize("is_symmetric", [True, False]) def test_fake_quantizer_range_learning(self, is_symmetric): """ Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly. @@ -1589,7 +1600,7 @@ def test_fake_quantizer_range_learning(self, is_symmetric): self.assertTrue(fake_quantizer.zero_point.requires_grad) fake_quantizer(*example_inputs) - @parameterized.expand([(True,), (False,)]) + @parametrize("is_symmetric", [True, False]) def test_qat_range_learning(self, is_symmetric): """ Test end-to-end QAT flow with range learning. @@ -1664,24 +1675,6 @@ def test_qat_range_learning(self, is_symmetric): self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0) self.assertFalse(torch.equal(new_weight, prev_weight)) - def test_float8_rowwise_fake_quantize(self): - """ - Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8TrainingTensor`. - """ - torch.manual_seed(self.SEED) - dtype = torch.float8_e4m3fn - x = torch.randn(32, 64) - axiswise_dim = 0 - out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim) - out_expected = hp_tensor_to_float8_dynamic( - x, - dtype, - LinearMMConfig(), - scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=axiswise_dim, - ).to_original_precision() - torch.testing.assert_close(out, out_expected, atol=0, rtol=0) - def test_qat_fp8a4w_quantizer(self): """ Test basic model training with `Float8ActInt4WeightQATQuantizer`. @@ -1693,7 +1686,8 @@ def test_qat_fp8a4w_quantizer(self): for linear in [m.linear1, m.sub.linear, m.linear2]: self.assertIsInstance(linear, FakeQuantizedLinear) self.assertIsInstance( - linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer + linear.activation_fake_quantizer, + Float8FakeQuantizer, ) self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer) prev_weight = copy.deepcopy(m.linear1.weight) @@ -1833,6 +1827,113 @@ def test_qat_api_convert_no_quantization(self): baseline_out = baseline_model(*x2) torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) + def test_float8_fake_quantize_config(self): + """ + Test that the correct errors are thrown if `Float8FakeQuantizeConfig` is not instantiated properly. + """ + # OK + Float8FakeQuantizeConfig(torch.float8_e4m3fn) + Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerRow()) + Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerTensor()) + + with self.assertRaisesRegex(ValueError, "not a float8 dtype"): + Float8FakeQuantizeConfig(torch.int8) + with self.assertRaisesRegex( + ValueError, "Please specify the granularity object instead of the class" + ): + Float8FakeQuantizeConfig(granularity=PerRow) + with self.assertRaisesRegex( + ValueError, "Expected PerRow or PerTensor granularity" + ): + Float8FakeQuantizeConfig(granularity=PerToken()) + + @parametrize("granularity", [PerTensor(), PerRow()]) + def test_float8_fake_quantize(self, granularity: Granularity): + """ + Test that `Float8FakeQuantizer` is numerically close to `Float8Tensor`. 
+ """ + dtype = torch.float8_e4m3fn + fq_config = Float8FakeQuantizeConfig(dtype, granularity) + fake_quantizer = Float8FakeQuantizer(fq_config) + torch.manual_seed(self.SEED) + x = torch.randn(32, 64) + out = fake_quantizer(x) + out_expected = Float8Tensor.to_float8(x, dtype, granularity).dequantize() + sqnr = compute_error(out, out_expected) + self.assertGreater(sqnr, 16) + + def _test_quantize_api_against_ptq( + self, + base_config: AOBaseConfig, + target_prepare_sqnr: float, + target_convert_sqnr: float, + ): + """ + Test the following: + + quantize_(model, QATConfig(base_config, step="prepare")) + quantize_(model, QATConfig(base_config, step="convert")) + + and compare model outputs of each step against: + + quantize_(model, base_config) + """ + torch.manual_seed(self.SEED) + m = M().to(torch.bfloat16).cuda() + example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),) + + # baseline + m_baseline = copy.deepcopy(m) + quantize_(m_baseline, base_config) + out_baseline = m_baseline(*example_inputs) + + # compare prepare + quantize_(m, QATConfig(base_config, step="prepare")) + out_prepared = m(*example_inputs) + prepare_sqnr = compute_error(out_prepared, out_baseline) + self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr) + + # compare convert + quantize_(m, QATConfig(base_config, step="convert")) + out_converted = m(*example_inputs) + convert_sqnr = compute_error(out_converted, out_baseline) + self.assertGreaterEqual(convert_sqnr, target_convert_sqnr) + + @parametrize("granularity", [PerTensor(), PerRow()]) + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") + def test_quantize_api_fp8_fp8(self, granularity: Granularity): + """ + Test the following: + quantize_(model, QATConfig(Float8DynamicActivationFloat8Weight(), step="prepare")) + quantize_(model, QATConfig(Float8DynamicActivationFloat8Weight(), step="convert")) + """ + self._test_quantize_api_against_ptq( + Float8DynamicActivationFloat8WeightConfig(granularity=granularity), + target_prepare_sqnr=15, + target_convert_sqnr=float("inf"), + ) + + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") + @unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" + ) + def test_quantize_api_fp8_int4(self): + """ + Test the following: + quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="prepare")) + quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="convert")) + """ + self._test_quantize_api_against_ptq( + Float8DynamicActivationInt4WeightConfig(group_size=128), + target_prepare_sqnr=15, + target_convert_sqnr=float("inf"), + ) + + +instantiate_parametrized_tests(TestQAT) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py index 9a7338623d..4218c763e2 100644 --- a/torchao/quantization/qat/__init__.py +++ b/torchao/quantization/qat/__init__.py @@ -15,11 +15,13 @@ from .fake_quantize_config import ( FakeQuantizeConfig, FakeQuantizeConfigBase, + Float8FakeQuantizeConfig, IntxFakeQuantizeConfig, ) from .fake_quantizer import ( FakeQuantizer, FakeQuantizerBase, + Float8FakeQuantizer, IntxFakeQuantizer, ) from .linear import ( @@ -34,6 +36,8 @@ "QATStep", "FakeQuantizeConfigBase", "FakeQuantizerBase", + "Float8FakeQuantizeConfig", + "Float8FakeQuantizer", "IntxFakeQuantizeConfig", 
"IntxFakeQuantizer", "FakeQuantizedLinear", diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py index 554ed2a065..167cb1f7a2 100644 --- a/torchao/quantization/qat/fake_quantize_config.py +++ b/torchao/quantization/qat/fake_quantize_config.py @@ -11,10 +11,17 @@ import torch from torchao.core.config import AOBaseConfig +from torchao.float8.config import e4m3_dtype +from torchao.float8.inference import ( + FP8Granularity, + _normalize_granularity, +) from torchao.quantization.granularity import ( Granularity, PerAxis, PerGroup, + PerRow, + PerTensor, PerToken, ) from torchao.quantization.quant_primitives import ( @@ -24,6 +31,7 @@ TorchAODType, ZeroPointDomain, ) +from torchao.utils import _is_float8_type from .utils import _log_deprecation_warning @@ -36,6 +44,39 @@ class FakeQuantizeConfigBase(abc.ABC): pass +@dataclass +class Float8FakeQuantizeConfig(FakeQuantizeConfigBase): + """ + Config for float8 fake quantization, targeting :class:`~torchao.quantization.Float8Tensor`. + + Args: + dtype (torch.dtype): the dtype for float8 Tensor + granularity (FP8Granularity): the granularity for the Tensor, currently either PerRow() or PerTensor() + hp_value_lb (Optional[float]): the lower bound for high precision floating point value for calculating scale + hp_value_ub (Optional[float]): the upper bound for high precision floating point value for calculating scale + """ + + dtype: torch.dtype = e4m3_dtype + granularity: FP8Granularity = PerRow() + hp_value_lb: Optional[float] = None + hp_value_ub: Optional[float] = None + + def __post_init__(self): + """ + Verify dtype and granularity are the ones we support. + """ + if not _is_float8_type(self.dtype): + raise ValueError(f"{self.dtype} is not a float8 dtype") + if isinstance(self.granularity, type): + raise ValueError( + "Please specify the granularity object instead of the class, e.g. PerRow() instead of PerRow" + ) + if type(self.granularity) not in [PerRow, PerTensor]: + raise ValueError( + f"Expected PerRow or PerTensor granularity, got {self.granularity}" + ) + + @dataclass class IntxFakeQuantizeConfig(FakeQuantizeConfigBase): """ @@ -279,6 +320,7 @@ def __post_init__(self): _log_deprecation_warning(self) +# TODO: rewrite using registration API? 
def _infer_fake_quantize_configs( base_config: AOBaseConfig, ) -> Tuple[Optional[FakeQuantizeConfigBase], Optional[FakeQuantizeConfigBase]]: @@ -291,6 +333,8 @@ def _infer_fake_quantize_configs( """ # avoid circular imports from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, ) @@ -302,18 +346,45 @@ def _infer_fake_quantize_configs( is_symmetric=base_config.act_mapping_type == MappingType.SYMMETRIC, ) weight_config = IntxFakeQuantizeConfig( - dtype=TorchAODType.INT4, + dtype=torch.int4, group_size=base_config.group_size, is_symmetric=base_config.mapping_type == MappingType.SYMMETRIC, ) - return (act_config, weight_config) elif isinstance(base_config, Int4WeightOnlyConfig): + if base_config.version != 2: + raise ValueError(f"Only version 2 of {type(base_config)} is supported") + act_config = None + weight_config = IntxFakeQuantizeConfig( + dtype=torch.int4, + group_size=base_config.group_size, + is_symmetric=True, + ) + elif isinstance(base_config, Float8DynamicActivationFloat8WeightConfig): + if base_config.version != 2: + raise ValueError(f"Only version 2 of {type(base_config)} is supported") + (act_granularity, weight_granularity) = _normalize_granularity( + base_config.granularity + ) + act_config = Float8FakeQuantizeConfig( + dtype=base_config.activation_dtype, + granularity=act_granularity, + hp_value_lb=base_config.activation_value_lb, + hp_value_ub=base_config.activation_value_ub, + ) + weight_config = Float8FakeQuantizeConfig( + dtype=base_config.weight_dtype, + granularity=weight_granularity, + ) + elif isinstance(base_config, Float8DynamicActivationInt4WeightConfig): + act_config = Float8FakeQuantizeConfig( + dtype=torch.float8_e4m3fn, + granularity=PerRow(), + ) weight_config = IntxFakeQuantizeConfig( - dtype=torch.uint4, + dtype=torch.int4, group_size=base_config.group_size, - is_symmetric=False, - zero_point_domain=base_config.zero_point_domain, + is_symmetric=True, ) - return (None, weight_config) else: raise ValueError("Unexpected base config: %s" % base_config) + return (act_config, weight_config) diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py index b63dbdb309..6f7e729f7d 100644 --- a/torchao/quantization/qat/fake_quantizer.py +++ b/torchao/quantization/qat/fake_quantizer.py @@ -13,10 +13,14 @@ PerGroup, PerToken, ) +from torchao.quantization.observer import get_block_size from torchao.quantization.quant_primitives import ( _DTYPE_TO_BIT_WIDTH, _DTYPE_TO_QVALUE_BOUNDS, MappingType, + _choose_scale_float8, + _dequantize_affine_float8, + _quantize_affine_float8, _Round, choose_qparams_affine, ) @@ -28,12 +32,12 @@ from .fake_quantize_config import ( FakeQuantizeConfigBase, + Float8FakeQuantizeConfig, IntxFakeQuantizeConfig, ) from .utils import ( _fake_quantize_per_channel_group, _fake_quantize_per_token, - _Float8RowwiseFakeQuantize, _log_deprecation_warning, ) @@ -55,10 +59,38 @@ def __repr__(self) -> str: def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase": if isinstance(config, IntxFakeQuantizeConfig): return IntxFakeQuantizer(config) + if isinstance(config, Float8FakeQuantizeConfig): + return Float8FakeQuantizer(config) else: raise ValueError(f"Unknown config type: {config}") +class Float8FakeQuantizer(FakeQuantizerBase): + """ + Generic module for applying float8 fake quantization to a tensor, as specified in the config. 
+ """ + + def __init__(self, config: Float8FakeQuantizeConfig): + super().__init__() + self.config = config + + def forward(self, x: torch.Tensor) -> torch.Tensor: + original_dtype = x.dtype + block_size = get_block_size(x.shape, self.config.granularity) + scale = _choose_scale_float8( + x, + block_size, + self.config.dtype, + hp_value_lb=self.config.hp_value_lb, + hp_value_ub=self.config.hp_value_ub, + ) + q = _quantize_affine_float8( + x, scale, self.config.dtype, cast_to_float8_dtype=False + ) + dq = _dequantize_affine_float8(q, scale, original_dtype) + return dq + + class IntxFakeQuantizer(FakeQuantizerBase): """ Generic module for applying integer fake quantization to a tensor, as specified in the config. @@ -218,24 +250,3 @@ class FakeQuantizer(IntxFakeQuantizer): def __init__(self, config: FakeQuantizeConfigBase): super().__init__(config) _log_deprecation_warning(self) - - -# TODO: make this a FakeQuantizerBase -class _Float8RowwiseActivationFakeQuantizer(torch.nn.Module): - """ - Simple fake quantizer for float8 rowwise fake quantization, intended for activations only. - """ - - def __init__(self): - super().__init__() - self.enabled = True - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.enabled: - return _Float8RowwiseFakeQuantize.apply( - x, - torch.float8_e4m3fn, - -1, - ) - else: - return x diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index f94ec6f272..9c13ed1d95 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from torchao.dtypes.utils import is_device -from torchao.quantization.granularity import PerGroup +from torchao.quantization.granularity import PerGroup, PerRow from torchao.quantization.linear_quant_modules import ( Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, @@ -28,11 +28,11 @@ from .fake_quantize_config import ( FakeQuantizeConfigBase, + Float8FakeQuantizeConfig, IntxFakeQuantizeConfig, ) from .fake_quantizer import ( FakeQuantizerBase, - _Float8RowwiseActivationFakeQuantizer, ) from .utils import ( _get_qmin_qmax, @@ -598,6 +598,10 @@ def __init__( weight_granularity = "per_group" else: weight_granularity = "per_channel" + self._activation_config = Float8FakeQuantizeConfig( + dtype=torch.float8_e4m3fn, + granularity=PerRow(), + ) self._weight_config = IntxFakeQuantizeConfig( dtype=torch.int4, granularity=weight_granularity, @@ -616,14 +620,11 @@ def prepare( """ for name, child in model.named_children(): if isinstance(child, torch.nn.Linear): - # TODO: add a config for float8? new_linear = FakeQuantizedLinear.from_linear( child, + activation_config=self._activation_config, weight_config=self._weight_config, ) - new_linear.activation_fake_quantizer = ( - _Float8RowwiseActivationFakeQuantizer() - ) setattr(model, name, new_linear) else: self.prepare(child) diff --git a/torchao/quantization/qat/utils.py b/torchao/quantization/qat/utils.py index e2f425a1d5..c5f339c945 100644 --- a/torchao/quantization/qat/utils.py +++ b/torchao/quantization/qat/utils.py @@ -18,38 +18,6 @@ ) -class _Float8RowwiseFakeQuantize(torch.autograd.Function): - """ - Implementation of float8 rowwise fake quantize with backward STE. 
- """ - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - x: torch.Tensor, - float8_dtype: torch.dtype, - axiswise_dim: int, - ): - # compute rowwise scale based on `torchao.float8.float8_utils.tensor_to_scale` - eps = 1e-12 - amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) - amax = amax.to(torch.float64) - scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=eps) - scale = scale.to(torch.float32) - - # fake quantize - max_value = torch.finfo(float8_dtype).max - x_fq = x.to(torch.float32) * scale - x_fq = x_fq.clamp(min=-max_value, max=max_value) - x_fq = x_fq.to(float8_dtype).to(x.dtype) - x_fq = x_fq / scale - return x_fq.to(x.dtype) - - @staticmethod - def backward(ctx, gy): - return gy, None, None - - def _fake_quantize_per_channel_group( input: torch.Tensor, scales: torch.Tensor, diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index ebd2c7ecd8..c118e0b4ce 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -2181,7 +2181,7 @@ def _choose_scale_float8( hp_value_ub: Optional[float] = None, ) -> torch.Tensor: """ - Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity. + Calculates float8 scaling factor for the given high precision tensor. Args: tensor (torch.Tensor): Input tensor to be quantized. @@ -2192,8 +2192,8 @@ def _choose_scale_float8( hp_value_ub (Optional[float]): the upper bound for high precision floating point value for calculating scale """ quant_max = torch.finfo(float8_dtype).max - # only tensorwise scaling is supported for now: if len(block_size) == 0: + # tensorwise max_abs = tensor.abs().max() if hp_value_lb is not None or hp_value_ub is not None: max_abs = torch.clamp(max_abs, min=hp_value_lb, max=hp_value_ub) @@ -2275,6 +2275,7 @@ def _quantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn, + cast_to_float8_dtype: bool = True, ) -> torch.Tensor: """ Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. @@ -2287,10 +2288,12 @@ def _quantize_affine_float8( tensor_scaled = tensor_fp32 / scale_expanded max_value = torch.finfo(float8_dtype).max tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value) - fp8_tensor = tensor_clamped.to(float8_dtype) - return fp8_tensor + if cast_to_float8_dtype: + tensor_clamped = tensor_clamped.to(float8_dtype) + return tensor_clamped +# TODO: don't register as custom op? @_register_custom_op(quant_lib, False) def _dequantize_affine_float8( tensor: torch.Tensor,
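
Usage note: taken together, these changes let `QATConfig` infer float8 fake quantization directly from the float8 PTQ configs, which is the flow exercised by `_test_quantize_api_against_ptq` above. A minimal sketch of that flow follows; the toy `Sequential` model and the `QATConfig` import path are assumptions for illustration and are not part of this diff.

import copy
import torch

from torchao import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.qat import QATConfig  # import path assumed

# Toy model; shapes are illustrative. Requires CUDA with float8 support (sm89+),
# matching the skip conditions in the tests above.
model = torch.nn.Sequential(torch.nn.Linear(64, 128)).to(torch.bfloat16).cuda()
base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# Prepare: linears are swapped for FakeQuantizedLinear modules whose activation and
# weight fake quantizers are inferred from the base config via _infer_fake_quantize_configs.
quantize_(model, QATConfig(base_config, step="prepare"))

# ... fine-tune the model here ...

# Convert: fake quantization is removed and the base PTQ config is applied for real.
quantize_(model, QATConfig(base_config, step="convert"))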
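
For reference, this is roughly the module-level wiring that the prepare step produces for the new `Float8DynamicActivationInt4WeightConfig` branch of `_infer_fake_quantize_configs`: per-row float8 fake quantization for activations and symmetric int4 per-group fake quantization for weights. Constructing it by hand, as below, is only an illustration; the layer sizes and the sample input are made up.

import torch

from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import (
    FakeQuantizedLinear,
    Float8FakeQuantizeConfig,
    IntxFakeQuantizeConfig,
)

# These mirror the configs inferred from Float8DynamicActivationInt4WeightConfig above.
act_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn, granularity=PerRow())
weight_config = IntxFakeQuantizeConfig(dtype=torch.int4, group_size=128, is_symmetric=True)

linear = torch.nn.Linear(256, 512, bias=False)
fq_linear = FakeQuantizedLinear.from_linear(
    linear,
    activation_config=act_config,
    weight_config=weight_config,
)
out = fq_linear(torch.randn(8, 256))  # high-precision compute with fake-quantized numerics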
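
Numerically, the new `Float8FakeQuantizer` performs the same scale / clamp / rescale round trip as the removed `_Float8RowwiseFakeQuantize`, only expressed through the shared `_choose_scale_float8` / `_quantize_affine_float8` / `_dequantize_affine_float8` primitives (and, in the forward above, without the final cast to the float8 dtype, via `cast_to_float8_dtype=False`). A self-contained per-row sketch of that round trip; the eps value and the explicit float8 cast are assumptions, not taken from this diff.

import torch

def float8_fake_quantize_rowwise(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # Per-row scale: row amax divided by the largest representable float8 value.
    fp8_max = torch.finfo(dtype).max
    amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float64)
    scale = (amax.clamp(min=1e-12) / fp8_max).to(torch.float32)  # eps is an assumption
    # Quantize: scale down, clamp to the representable range, cast to float8.
    q = (x.to(torch.float32) / scale).clamp(-fp8_max, fp8_max).to(dtype)
    # Dequantize: scale back up and return in the input dtype.
    return (q.to(torch.float32) * scale).to(x.dtype)

x = torch.randn(32, 64)
out = float8_fake_quantize_rowwise(x)  # close to Float8Tensor.to_float8(x, ...).dequantize()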