From c58982f76bed7001c8151388826f341f175cb74d Mon Sep 17 00:00:00 2001 From: 103yiran <1039105206@qq.com> Date: Thu, 26 Jun 2025 19:57:59 +0800 Subject: [PATCH] fix scale shape --- torchao/prototype/mx_formats/mx_tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index e98878af77..b025f5b0dc 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -331,6 +331,8 @@ def to_mx( raise AssertionError("unsupported") scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) + scale_shape = [*orig_shape[:-1], orig_shape[-1] // block_size] + scale_e8m0_biased = scale_e8m0_biased.reshape(scale_shape) return scale_e8m0_biased, data_lp