@@ -59,16 +59,18 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return
 
     # inf-norm is equivalent to max(abs(w))
-    # keep consistent with float8_utils.amax_to_scale
-    # torch.compile and eager show different numerics for 1.0 / float32,
-    # upcast to float64 to ensure same numeric between compile and eager
     max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
     amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
-    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor.to(torch.float64)  # Replicate
-    if amax_tensor.dtype is torch.float16:
+    # keep consistent with float8_utils.amax_to_scale
+    # torch.compile and eager show different numerics for 1.0 / float32,
+    # upcast to float64 to ensure same numeric between compile and eager
+    origin_dtype = amax_tensor.dtype
+    amax_tensor = amax_tensor.to(torch.float64)
+    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    if origin_dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     local_scale_tensor = scale_tensor.to_local().to(torch.float32)
     for i, float8_linear in enumerate(float8_linears):
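
For reference, a minimal standalone sketch of the amax-to-scale pattern this change adopts: upcast the amax to float64 before the division so torch.compile and eager produce the same scale, clamp the result for float16 inputs, and return the scale in float32. The function name, the `eps` default, and operating on a plain tensor (rather than a stacked DTensor under FSDP2) are illustrative assumptions, not the library's API.

```python
import torch

def amax_to_scale_sketch(amax: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Hypothetical helper, for illustration only; the real code operates on a
    # stacked DTensor of per-parameter amax values.
    origin_dtype = amax.dtype          # remember the input dtype before upcasting
    amax = torch.clamp(amax, eps)      # avoid division by zero
    amax = amax.to(torch.float64)      # upcast so compile and eager agree on the division
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    if origin_dtype is torch.float16:  # fp16 amax can produce scales above fp16 max
        scale = torch.clamp(scale, max=torch.finfo(torch.float16).max)
    return scale.to(torch.float32)     # scales are kept in float32
```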