Commit 3d0da20

use LinearMMConfig
1 parent 9346afd commit 3d0da20

2 files changed: +33, -7 lines

test/float8/test_compile.py

Lines changed: 26 additions & 2 deletions
@@ -32,6 +32,7 @@
 from torchao.float8.float8_tensor import (
     LinearMMConfig,
     GemmInputRole,
+    ScaledMMConfig,
 )
 from torchao.float8.float8_utils import e4m3_dtype
 

@@ -379,17 +380,40 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
     float8_config = Float8LinearConfig(
         cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
     )
+    linear_mm_config = LinearMMConfig(
+        # output
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_output.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+        # grad_input
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_grad_input.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+        # grad_weight
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_grad_weight.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+    )
     float8_eager = hp_tensor_to_float8_dynamic(
         hp_tensor1,
         torch.float8_e4m3fn,
-        float8_config,
+        linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
     )
     torch._dynamo.reset()
     float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
         hp_tensor2,
         torch.float8_e4m3fn,
-        float8_config,
+        linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
     )
     assert torch.equal(float8_eager._scale, float8_compile._scale)
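
The test now builds the `LinearMMConfig` by hand because `hp_tensor_to_float8_dynamic` takes the per-gemm mm config rather than the full `Float8LinearConfig`. Below is a minimal standalone sketch of the same construction; the four positional `ScaledMMConfig` fields are assumed to be (emulate, use_fast_accum, fp8_output, pad_inner_dim), inferred from the positional arguments in the diff above, and the import path for `Float8LinearConfig` is likewise an assumption:

import torch
from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig

# One ScaledMMConfig per gemm in a linear layer's forward/backward:
# output, grad_input, and grad_weight.
config = Float8LinearConfig()
linear_mm_config = LinearMMConfig(
    # output gemm
    ScaledMMConfig(False, config.gemm_config_output.use_fast_accum, False, config.pad_inner_dim),
    # grad_input gemm
    ScaledMMConfig(False, config.gemm_config_grad_input.use_fast_accum, False, config.pad_inner_dim),
    # grad_weight gemm
    ScaledMMConfig(False, config.gemm_config_grad_weight.use_fast_accum, False, config.pad_inner_dim),
)

This is the object both the eager and compiled `hp_tensor_to_float8_dynamic` calls receive, which lets the test compare the resulting `_scale` tensors on equal footing with `torch.equal`.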

torchao/float8/fsdp_utils.py

Lines changed: 7 additions & 5 deletions
@@ -59,16 +59,18 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return
 
     # inf-norm is equivalent to max(abs(w))
-    # keep consistent with float8_utils.amax_to_scale
-    # torch.compile and eager show different numerics for 1.0 / float32,
-    # upcast to float64 to ensure same numeric between compile and eager
     max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
     amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
-    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor.to(torch.float64)  # Replicate
-    if amax_tensor.dtype is torch.float16:
+    # keep consistent with float8_utils.amax_to_scale
+    # torch.compile and eager show different numerics for 1.0 / float32,
+    # upcast to float64 to ensure same numeric between compile and eager
+    origin_dtype = amax_tensor.dtype
+    amax_tensor = amax_tensor.to(torch.float64)
+    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    if origin_dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     local_scale_tensor = scale_tensor.to_local().to(torch.float32)
     for i, float8_linear in enumerate(float8_linears):
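
Note on the reorder: `amax_tensor` is now reassigned to its float64 upcast, so the original dtype has to be captured in `origin_dtype` beforehand for the float16 clamp check to keep working. The compile/eager concern the comments describe can be illustrated with a small sketch; this is illustrative only, since whether float32 division diverges bitwise between eager and compiled code is backend-dependent and may not reproduce everywhere:

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def scale_fp32(amax):
    # scale computed directly in float32
    return E4M3_MAX / amax

def scale_fp64(amax):
    # upcast before dividing, mirroring float8_utils.amax_to_scale,
    # then cast the result back down
    return (E4M3_MAX / amax.to(torch.float64)).to(torch.float32)

amax = torch.rand(4096) + 1e-3
# Bitwise comparison, as the parity test above does with torch.equal.
print(torch.equal(scale_fp32(amax), torch.compile(scale_fp32)(amax)))
print(torch.equal(scale_fp64(amax), torch.compile(scale_fp64)(amax)))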
