From bf49f410dc8e99807a92ab9c3686af52de2bc4c1 Mon Sep 17 00:00:00 2001
From: Wei Feng
Date: Wed, 18 Sep 2024 20:06:39 -0700
Subject: [PATCH] [float8] fuse abs/max with torch.linalg.vector_norm

---
 torchao/float8/float8_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index 54613e5b05..0d22c95aba 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -99,7 +99,7 @@ def amax_history_to_scale_stack(

 @torch.no_grad()
 def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
-    amax = torch.max(torch.abs(x))
+    amax = torch.linalg.vector_norm(x, ord=float("inf"))

     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
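
A minimal standalone sketch (not part of the patch) of the equivalence the change relies on: the vector infinity norm of a tensor is exactly max(|x_i|), so torch.linalg.vector_norm(x, ord=float("inf")) returns the same value as torch.max(torch.abs(x)), while letting a single fused reduction kernel do the work instead of materializing the intermediate tensor that abs() would otherwise allocate:

    import torch

    x = torch.randn(4, 4)

    # Two-op version: abs() materializes a full intermediate tensor,
    # then max() runs a second reduction over it.
    amax_two_ops = torch.max(torch.abs(x))

    # Fused version: the inf-norm is max(|x_i|), computed in one
    # reduction with no intermediate abs() tensor.
    amax_fused = torch.linalg.vector_norm(x, ord=float("inf"))

    # Both reductions take the max of the same absolute values, so the
    # results should match bit-for-bit.
    assert torch.equal(amax_two_ops, amax_fused)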