Commit b714026

float8 with delayed scaling: fix autocast handling (#1306)
Summary:

Fixes a bug with delayed scaling + autocast. Before, the last input dtype under autocast was queried from the input to `torch._scaled_mm`:

```
x_hp -> {query_dtype_here} -> to_autocast -> torch._scaled_mm
```

This is incorrect because the dtype was saved before the point where autocast could change it. It happened to work if `x_hp` was already of the correct dtype, but broke in cases such as the new test case added in this PR, and in real models such as the repro from #1297. The reason this went uncaught for so long is that we've been using FSDP's mixed precision rather than single-GPU autocast.

The fix taken here is to query the post-autocast dtype from the output of `torch._scaled_mm`. Since that dtype is derived from the dtype of the input to `torch._scaled_mm`, it properly captures autocasting:

```
x_hp -> to_autocast -> x_autocast_dtype -> to_fp8 -> x_fp8 -> torch._scaled_mm -> {query_dtype_here}
```

Test Plan:

1. The updated test case passes.
2. A modified version of the repro from #1297 now shows a speedup with float8:
   - code: https://gist.github.com/vkuzo/6c53a1deca19856238d38746b1e52ee7
   - logs: https://gist.github.com/vkuzo/60846b1f6b2822f91d2dfa67cab10a10
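For intuition, here is a minimal standalone sketch (not part of the PR; assumes a CUDA device) of why the pre-autocast dtype is the wrong thing to record: under `torch.autocast`, the dtype of a linear layer's output reflects the autocast dtype, while the dtype of the incoming activation does not.

```python
import torch
import torch.nn as nn

m = nn.Linear(32, 32, device="cuda", dtype=torch.float32)
x = torch.randn(16, 32, device="cuda", dtype=torch.float32)

with torch.autocast("cuda", dtype=torch.bfloat16):
    # Querying the dtype on the way in (the old behavior) sees float32,
    # because autocast has not yet cast the matmul inputs.
    dtype_before = x.dtype
    y = m(x)
    # The output dtype reflects what the matmul actually ran in,
    # which is what the amax/scale sync logic needs to know.
    dtype_after = y.dtype

assert dtype_before == torch.float32
assert dtype_after == torch.bfloat16
```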

File tree: 3 files changed (+15, −8 lines)

test/float8/test_base.py

Lines changed: 8 additions & 5 deletions
```diff
@@ -424,33 +424,36 @@ def test_autocast_outputs(
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
+        m_ref = nn.Sequential(
+            nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
+            nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
+        )
         config = Float8LinearConfig(
             cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
             cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
             cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
             emulate=emulate,
         )
-        m = Float8Linear.from_float(copy.deepcopy(m_ref), config)
+        m = convert_to_float8_training(copy.deepcopy(m_ref), config=config)

         # autocast off
         x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
+        y = m(x)
         if linear_requires_sync(config):
             sync_float8_amax_and_scale_history(m)
-        y = m(x)
         assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

         # autocast on
         with torch.autocast("cuda"):
+            y = m(x)
             if linear_requires_sync(config):
                 sync_float8_amax_and_scale_history(m)
-            y = m(x)
         assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

         with torch.autocast("cuda", dtype=torch.bfloat16):
+            y = m(x)
             if linear_requires_sync(config):
                 sync_float8_amax_and_scale_history(m)
-            y = m(x)
         assert (
             y.dtype == torch.bfloat16
         ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
```

torchao/float8/float8_linear.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -336,7 +336,7 @@ def __init__(self, *args, **kwargs):

         # This is needed to properly handle autocast in the amax/scale
         # update function for torch.float16
-        self.last_seen_input_dtype = None
+        self.last_seen_output_dtype = None

         # pre_forward and post_forward are currently broken with FSDP
         # and torch.compile, this option can disable them
@@ -538,11 +538,14 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         return output

     def float8_pre_forward(self, input):
+        # TODO(future PR): deprecate these functions and the corresponding
+        # config setting
         if not self.enable_pre_and_post_forward:
             return
-        self.last_seen_input_dtype = input.dtype

     def float8_post_forward(self):
+        # TODO(future PR): deprecate these functions and the corresponding
+        # config setting
         if not self.enable_pre_and_post_forward:
             return
         self.is_amax_initialized = True
@@ -624,6 +627,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

         if self.has_any_delayed_scaling:
             self.float8_post_forward()
+        self.last_seen_output_dtype = output.dtype
         return output

     def extra_repr(self):
```
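Reduced to its essence, the change to `Float8Linear` swaps "record the input dtype on the way in" for "record the output dtype on the way out". A minimal sketch of the fixed pattern (hypothetical module, not the actual `Float8Linear` code):

```python
import torch
import torch.nn as nn

class RecordsOutputDtype(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(32, 32)
        self.last_seen_output_dtype = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.linear(x)
        # y.dtype reflects any cast performed by torch.autocast, while
        # x.dtype may still be the pre-autocast dtype; recording the output
        # dtype is what lets the later sync step see the correct one.
        self.last_seen_output_dtype = y.dtype
        return y
```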

torchao/float8/float8_linear_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -236,7 +236,7 @@ def inner_func():
             fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight
             fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output

-            x_dtypes.add(child.last_seen_input_dtype)
+            x_dtypes.add(child.last_seen_output_dtype)
             scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)

     # TODO This way to get the activation dtype is not ideal
```
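On the sync side, the function gathers the recorded dtype from every converted module and expects them to agree. A self-contained sketch of that collection pattern (the `FakeFloat8Linear` stand-in and list structure are assumptions; only `last_seen_output_dtype` and `x_dtypes` come from the diff):

```python
from dataclasses import dataclass

import torch

@dataclass
class FakeFloat8Linear:
    # Stand-in for a converted module that recorded its output dtype.
    last_seen_output_dtype: torch.dtype

fp8_layers = [FakeFloat8Linear(torch.bfloat16), FakeFloat8Linear(torch.bfloat16)]

# Gather the post-autocast dtype recorded by each module during the last
# forward pass; delayed-scaling sync assumes a single dtype across layers.
x_dtypes = {child.last_seen_output_dtype for child in fp8_layers}
assert len(x_dtypes) == 1, f"expected one activation dtype, got {x_dtypes}"
```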
