diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py index d160d7241d..a7b1e17934 100644 --- a/benchmarks/float8/bench_linear_float8.py +++ b/benchmarks/float8/bench_linear_float8.py @@ -23,10 +23,6 @@ ScalingType, ) from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_linear_utils import ( - linear_requires_sync, - sync_float8_amax_and_scale_history, -) from torchao.float8.float8_tensor import ScaledMMConfig # estimating TOPs for matmuls in fp32, fp16, fp8 @@ -122,39 +118,18 @@ def main( scaling_type_grad_output = ScalingType(scaling_type_grad_output) scaling_granularity = ScalingGranularity(scaling_granularity) - if scaling_type_input is ScalingType.STATIC: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, @@ -185,7 +160,7 @@ def main( copy.deepcopy(linear_ref), config=config, ) - scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}" + scaling_repr = linear_float8.extra_repr() if fast_accum: linear_float8.forward_config = ScaledMMConfig(False, True, False) @@ -196,8 +171,6 @@ def main( ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward() def float8_forw_backward(): - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(linear_float8) linear_float8(input_tensor).sum().backward() def n_times(n, fn, *args, **kwargs): diff --git a/benchmarks/float8/bench_multi_gpu.py b/benchmarks/float8/bench_multi_gpu.py deleted file mode 100644 index 34a690edbe..0000000000 --- a/benchmarks/float8/bench_multi_gpu.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -import os -from typing import Callable - -import fire -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import torch.nn as nn -import torch.utils.benchmark as benchmark -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType -from torchao.float8.float8_linear_utils import ( - convert_to_float8_training, - sync_float8_amax_and_scale_history, -) - -torch.manual_seed(0) - -# TODO: Add more shapes for the benchmark -B, M, K, N = 32, 1024, 1024, 1024 -lr = 0.01 - -config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), -) - - -def benchmark_torch_function_in_microseconds( - func: Callable, - *args, - **kwargs, -) -> float: - t0 = benchmark.Timer( - stmt="func(*args, **kwargs)", - globals={"args": args, "kwargs": kwargs, "func": func}, - ) - return t0.blocked_autorange().median * 1e6 - - -def setup(rank, world_size): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - - # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) - - -def cleanup(): - dist.destroy_process_group() - - -def get_model(K, N, is_fp8, base_dtype=torch.float32): - modules = [ - nn.Linear(K, N, dtype=base_dtype), - nn.ReLU(), - ] - N_LAYERS = 20 - # N linear layers - for _ in range(N_LAYERS - 1): - modules.append(nn.Linear(N, N, dtype=base_dtype)) - modules.append(nn.ReLU()) - m = nn.Sequential(*modules) - if is_fp8: - convert_to_float8_training( - m, - config=config, - ) - return m - - -def fsdp_main(rank, world_size, args): - setup(rank, world_size) - torch.cuda.set_device(rank) - - base_dtype, input_global, compile = args - - # basic distributed data sampling - assert B % world_size == 0 - bsz_local_start = int(rank / world_size * B) - bsz_local_end = int((rank + 1) / world_size * B) - input_tensor = input_global[bsz_local_start:bsz_local_end].to(rank) - - fp8_model = get_model(K, N, is_fp8=True, base_dtype=base_dtype).to(rank) - # Need use_orig_params=True to compile FSDP - fp8_model = FSDP(fp8_model, use_orig_params=True) - fp8_optimizer = torch.optim.SGD(fp8_model.parameters(), lr=lr * world_size) - - # Run one iteration to make compile work, see experiments doc for more context of this issue. - fp8_optimizer.zero_grad() - y_local = fp8_model(input_tensor) - y_local.sum().backward() - fp8_optimizer.step() - sync_float8_amax_and_scale_history(fp8_model) - - sync_float8_func = sync_float8_amax_and_scale_history - if compile: - # TODO: Need to fix issues with compile - fp8_model = torch.compile(fp8_model) - sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) - - def float8_forw_backward(): - fp8_optimizer.zero_grad() - y_local = fp8_model(input_tensor) - y_local.sum().backward() - fp8_optimizer.step() - sync_float8_func(fp8_model) - - ref_model = get_model(K, N, is_fp8=False, base_dtype=base_dtype).to(rank) - ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size) - if compile: - ref_model = torch.compile(ref_model) - - ref_model = FSDP(ref_model, use_orig_params=True) - - def ref_forw_backward(): - ref_optimizer.zero_grad() - ref_model(input_tensor).sum().backward() - ref_optimizer.step() - - def run_n_iterations(n, fn): - for _ in range(n): - fn() - # make sure training is done on all ranks - dist.barrier() - - # warmup - run_n_iterations(50, ref_forw_backward) - run_n_iterations(50, float8_forw_backward) - - N_ITER = 50 - ref_time = ( - benchmark_torch_function_in_microseconds( - run_n_iterations, N_ITER, ref_forw_backward - ) - * 1e-6 - / N_ITER - ) - float8_time = ( - benchmark_torch_function_in_microseconds( - run_n_iterations, N_ITER, float8_forw_backward - ) - * 1e-6 - / N_ITER - ) - - if rank == 0: - print("ref_time", ref_time) - print("float8_time", float8_time) - print("float8 speedup", ref_time / float8_time) - - cleanup() - - -def run(compile: bool): - base_dtype = torch.bfloat16 - WORLD_SIZE = torch.cuda.device_count() - print(f"{base_dtype = }") - print(f"{compile = }") - print(f"{WORLD_SIZE = }") - - # generate input data - ref_input = torch.randn(B, M, K).cuda().to(base_dtype) - # run fsdp model - args = (base_dtype, ref_input, compile) - mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) - - -# Usgae: -# CUDA_VISIBLE_DEVICES=0,1 python benchmarks/bench_multi_gpu.py -if __name__ == "__main__": - fire.Fire(run) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 684ed0af2a..6f30e5eff7 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -58,9 +58,7 @@ ) from torchao.float8 import ( - CastConfig, Float8LinearConfig, - ScalingType, convert_to_float8_training, ) from torchao.float8.roofline_utils import ( @@ -219,24 +217,6 @@ def run( scaling_type_weight="dynamic", scaling_type_grad_output="dynamic", ) - fp8_mem_time_sympy_del_limit = get_float8_mem_sympy( - M, - K, - N, - model_torch_compile_limitations=True, - scaling_type_input="delayed", - scaling_type_weight="delayed", - scaling_type_grad_output="delayed", - ) - fp8_mem_time_sympy_del_nolimit = get_float8_mem_sympy( - M, - K, - N, - model_torch_compile_limitations=False, - scaling_type_input="delayed", - scaling_type_weight="delayed", - scaling_type_grad_output="delayed", - ) if gemm_time_strategy == "roofline": bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16) @@ -258,16 +238,12 @@ def run( # roofline memory overhead estimates "fp8_oh_dyn_limit", "fp8_oh_dyn_nolimit", - "fp8_oh_del_limit", - "fp8_oh_del_nolimit", # actual e2e measurements "bf16_s", "fp8_dyn_s", - "fp8_del_s", "fp8_dyn_axs_s", # 'fp8_lw_s', "fp8_dyn_sp", - "fp8_del_sp", "fp8_dyn_axs_sp", # 'fp8_lw_sp', ] @@ -309,12 +285,6 @@ def run( fp8_mem_time_dyn_nolimit_s = ( fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val) ) - fp8_mem_time_del_limit_s = ( - fp8_mem_time_sympy_del_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val) - ) - fp8_mem_time_del_nolimit_s = ( - fp8_mem_time_sympy_del_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val) - ) # create the model m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16() @@ -333,19 +303,6 @@ def run( m_fp8_dyn = torch.compile(m_fp8_dyn) fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x) - # get the float8 delayed scaling gpu kernel time - torch._dynamo.reset() - config = Float8LinearConfig( - enable_amax_init=False, - enable_pre_and_post_forward=False, - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), - ) - m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config) - m_fp8_del = torch.compile(m_fp8_del) - fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x) - # get the float8 dynamic axiswise scaling gpu kernel time torch._dynamo.reset() config = Float8LinearConfig.from_recipe_name("rowwise") @@ -374,16 +331,12 @@ def run( # roofline overhead estimates fp8_mem_time_dyn_limit_s, fp8_mem_time_dyn_nolimit_s, - fp8_mem_time_del_limit_s, - fp8_mem_time_del_nolimit_s, # e2e numbers bf16_time_actual_s, fp8_dyn_time_actual_s, - fp8_del_time_actual_s, fp8_dyn_axs_time_actual_s, # fp8_lw_time_actual_s, bf16_time_actual_s / fp8_dyn_time_actual_s, - bf16_time_actual_s / fp8_del_time_actual_s, bf16_time_actual_s / fp8_dyn_axs_time_actual_s, # bf16_time_actual_s / fp8_lw_time_actual_s, ] diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 687684d4e2..e28ed6dcc2 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -33,19 +33,15 @@ kernel_name_to_category, parse_bw_and_kernel_name, profiler_output_to_filtered_time_by_kernel_name, - profiler_output_to_gpu_time_for_key, update_triton_kernels_in_prof_chome_trace_with_torch_logs, ) -from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( Float8LinearConfig, ScalingType, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -286,9 +282,7 @@ def main( model_type: str = "linear", dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, - enable_sync_amax_history: bool = True, enable_activation_checkpointing: bool = False, - enable_float8_delayed_scaling_inductor_passes: bool = False, ): assert model_type in ( "linear", @@ -325,12 +319,6 @@ def main( print( f"enable_activation_checkpointing is set to {enable_activation_checkpointing}" ) - print( - f"enable_float8_delayed_scaling_inductor_passes is set to {enable_float8_delayed_scaling_inductor_passes}" - ) - - if enable_float8_delayed_scaling_inductor_passes: - _prototype_register_float8_delayed_scaling_inductor_passes() device = "cuda" ref_dtype = torch.bfloat16 @@ -388,17 +376,9 @@ def float8_forw(x): out = m_float8(x) return out - sync_amax_history = sync_float8_amax_and_scale_history - def float8_forw_backward_wrapper(x): - # sync_float8_amax_and_scale_history is not full graph torch - # compile friendly, so we add a high level wrapper to allow - # inspection of the fw+bw torch.compile without the scale - # syncing code - # TODO(future): make this better - if linear_requires_sync(config) and enable_sync_amax_history: - with record_function("scale_amax_and_scales"): - sync_amax_history(m_float8) + # TODO(future PR): this wrapper is for delayed scaling, we can clean it + # up now that delayed scaling is deprecated. out = float8_forw(x) # out.sum().backward() is also not torch.compile fullgraph @@ -409,11 +389,6 @@ def float8_forw_backward_wrapper(x): if compile: m_ref = torch.compile(m_ref, fullgraph=True) float8_forw = torch.compile(float8_forw, fullgraph=True) - # Note: it's faster to compile the combination of sync_amax_history wit - # forward because we only look up from dynamo cache once. - # However, compiling the sync function separately makes it more - # convenient to analyze the total time spent on it. - sync_amax_history = torch.compile(sync_amax_history) # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output # to populate triton kernel bandwidth further down in the script @@ -529,13 +504,6 @@ def float8_forw_backward_wrapper(x): ] ) - # get the time spent per user annotation - sync_time_us = profiler_output_to_gpu_time_for_key( - p, "scale_amax_and_scales" - ) - sync_time_ms = sync_time_us / profile_iters / 1e3 - print(f"Sync time ms: {sync_time_ms}") - finally: if f is not None: # print the redirected stdout back to regular stdout @@ -586,14 +554,6 @@ def float8_forw_backward_wrapper(x): df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"] df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"] - # calculate sync time as pct of total float time - # note: this time is not useful if TORCHINDUCTOR_PROFILE is on - total_float8_ms = df_p.iloc[3]["1_float8"] - sync_approx_ratio = sync_time_ms / total_float8_ms - print( - f"\nFloat8 amax/scale sync approx ratio of total time: {sync_approx_ratio:.3f}" - ) - print("\nSummary of time (ms) by kernel category\n\n", df_p) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 156c8abe87..64a96b47de 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -25,7 +25,6 @@ from torchao.float8.config import ( - CastConfig, Float8LinearConfig, Float8LinearRecipeName, ScalingGranularity, @@ -36,8 +35,6 @@ from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_python_api import addmm_float8_unwrapped from torchao.float8.float8_scaling_utils import ( @@ -54,11 +51,9 @@ from torchao.float8.float8_utils import ( FP8_TYPES, compute_error, - config_has_stateful_scaling, fp8_tensor_statistics, tensor_to_scale, ) -from torchao.float8.stateful_float8_linear import StatefulFloat8Linear from torchao.testing.float8.test_utils import get_test_float8_linear_config random.seed(0) @@ -284,16 +279,10 @@ def _test_linear_impl( config: Float8LinearConfig, use_ac: bool = False, ): - if config_has_stateful_scaling(config): - m_fp8 = StatefulFloat8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) - else: - m_fp8 = Float8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) + m_fp8 = Float8Linear.from_float( + copy.deepcopy(m_ref), + config, + ) for _ in range(2): if use_ac: @@ -301,8 +290,6 @@ def _test_linear_impl( else: y_fp8 = m_fp8(x) y_fp8.sum().backward() - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m_fp8) if use_ac: y_ref = torch.utils.checkpoint.checkpoint(m_ref, x, use_reentrant=False) @@ -320,65 +307,21 @@ def _test_linear_impl( if m_ref.bias is not None: torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad) - # verify all of the amax buffers got updated - if linear_requires_sync(config): - # only check buffers that are actually used, based on per-tensor - # scaling settings - amax_buffer_names = [] - amax_history_buffer_names = [] - scale_buffer_names = [] - if config.cast_config_input.scaling_type is ScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_input") - amax_history_buffer_names.append("fp8_amax_history_input") - scale_buffer_names.append("fp8_scale_input") - if config.cast_config_weight.scaling_type is ScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_weight") - amax_history_buffer_names.append("fp8_amax_history_weight") - scale_buffer_names.append("fp8_scale_weight") - if config.cast_config_grad_output.scaling_type is ScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_grad_output") - amax_history_buffer_names.append("fp8_amax_history_grad_output") - scale_buffer_names.append("fp8_scale_grad_output") - - # verify all of the amax buffers got updated - max_float8_pos = {torch.finfo(dtype).max for dtype in FP8_TYPES} - for buffer_name in amax_buffer_names: - buffer_value = getattr(m_fp8, buffer_name) - for init_val in max_float8_pos: - assert torch.ne( - buffer_value, torch.tensor(init_val) - ), f"{buffer_name} not filled, current value {buffer_value}" - - # verify all of the amax history buffers got updated - for buffer_name in amax_history_buffer_names: - buffer_value = getattr(m_fp8, buffer_name) - assert torch.max(buffer_value) > 0.0, f"{buffer_name} not filled" - - # verify all of the scale buffers got updated - for buffer_name in scale_buffer_names: - buffer_value = getattr(m_fp8, buffer_name) - assert torch.ne( - buffer_value, torch.tensor(1.0) - ), f"{buffer_name} not filled, current value {buffer_value}" - - # verify initialization flags got updated - assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - @pytest.mark.parametrize( "emulate", [True, False] if is_sm_at_least_89() else [True] ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( "scaling_type_input", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @@ -465,9 +408,6 @@ def test_autocast_outputs( nn.Linear(32, 32, device="cuda", dtype=linear_dtype), ) config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), emulate=emulate, ) m = convert_to_float8_training(copy.deepcopy(m_ref), config=config) @@ -475,21 +415,15 @@ def test_autocast_outputs( # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) y = m(x) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on with torch.autocast("cuda"): y = m(x) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" with torch.autocast("cuda", dtype=torch.bfloat16): y = m(x) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) assert ( y.dtype == torch.bfloat16 ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}" @@ -508,40 +442,18 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): # Cast the module to dtype m = m.to(dtype=linear_dtype) - if linear_requires_sync(config): - # Check amax buffer types - for key in [ - "fp8_amax_input", - "fp8_amax_history_input", - "fp8_scale_input", - "fp8_amax_weight", - "fp8_amax_history_weight", - "fp8_scale_weight", - "fp8_amax_grad_output", - "fp8_amax_history_grad_output", - "fp8_scale_grad_output", - ]: - assert ( - m._buffers[key].dtype == torch.float32 - ), f"{key}.dtype is {m._buffers[key].dtype}, expected torch.float32" # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on with torch.autocast("cuda"): - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" with torch.autocast("cuda", dtype=torch.bfloat16): - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) y = m(x) assert ( y.dtype == torch.bfloat16 @@ -550,7 +462,6 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): def test_repr(self): m = nn.Linear(32, 16) config = Float8LinearConfig( - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), emulate=True, ) m = Float8Linear.from_float( @@ -558,7 +469,7 @@ def test_repr(self): config=config, ) s = m.__repr__() - assert "i:dyn_ten_e4m3,w:del_ten_e4m3,go:dyn_ten_e5m2" in s + assert "i:dyn_ten_e4m3,w:dyn_ten_e4m3,go:dyn_ten_e5m2" in s @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available") def test_inference_mode(self): diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 0c02db26a6..7c31bf6f08 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -7,7 +7,6 @@ import random import sys import unittest -from dataclasses import replace from io import StringIO import pytest @@ -26,7 +25,6 @@ from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend -from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( CastConfig, Float8LinearConfig, @@ -35,20 +33,11 @@ e4m3_dtype, ) from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_linear_utils import ( - convert_to_float8_training, - get_float8_layers, - sync_float8_amax_and_scale_history, -) from torchao.float8.float8_scaling_utils import ( - hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig -from torchao.float8.float8_utils import config_has_stateful_scaling -from torchao.float8.stateful_float8_linear import StatefulFloat8Linear from torchao.testing.float8.test_utils import get_test_float8_linear_config -from torchao.utils import is_fbcode def _test_compile_base( @@ -66,16 +55,10 @@ def _test_compile_base( x_ref = copy.deepcopy(x) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - if config_has_stateful_scaling(config): - m_fp8 = StatefulFloat8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) - else: - m_fp8 = Float8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) + m_fp8 = Float8Linear.from_float( + copy.deepcopy(m_ref), + config, + ) m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph) m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph) @@ -94,16 +77,14 @@ def _test_compile_base( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] -) +@pytest.mark.parametrize("scaling_type_input", [ScalingType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @@ -133,16 +114,14 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) -@pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] -) +@pytest.mark.parametrize("scaling_type_input", [ScalingType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -171,16 +150,14 @@ def test_aot_eager( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) -@pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] -) +@pytest.mark.parametrize("scaling_type_input", [ScalingType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @unittest.skipIf( not torch.cuda.is_available() or not is_sm_at_least_89(), @@ -241,16 +218,12 @@ class TestGraphBreaks(DynamoTestCase): class MockLinear(torch.nn.Module): def __init__(self, graph_break: bool): super().__init__() - self.register_buffer("fp8_amax_x", torch.tensor(1.0)) - self.register_buffer("fp8_scale_x", torch.tensor(1.0)) self.graph_break = graph_break def forward(self, x): - x_fp8 = hp_tensor_to_float8_delayed( + x_fp8 = hp_tensor_to_float8_dynamic( x, - self.fp8_scale_x, e4m3_dtype, - self.fp8_amax_x, LinearMMConfig(), ) if self.graph_break: @@ -330,30 +303,6 @@ def test_float8_graph_output(self): ) -@unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), - "CUDA with float8 support not available", -) -def test_sync_amax_func(): - torch._dynamo.reset() - cnts = CompileCounterWithBackend("inductor") - module = torch.nn.Sequential( - nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) - ) - config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), - ) - float8_mod = convert_to_float8_training( - module, - config=config, - ) - compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts) - compiled_swap_func(float8_mod) - assert cnts.frame_count == 1, "Compiled graph should have 1 frame!" - - class capture_stderr(list): """ Replace sys.stderr with a temporary StringIO @@ -371,38 +320,6 @@ def __exit__(self, *args): sys.stderr = self.sys_stderr -@unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), - "CUDA with float8 support not available", -) -def test_sync_amax_func_cuda_graph_success(): - torch._dynamo.reset() - with capture_stderr() as stderr: - my_module = nn.Sequential( - nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) - ).to("cuda") - config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), - ) - convert_to_float8_training( - my_module, - config=config, - ) - inpt = torch.randn( - 16, 16, device="cuda", dtype=torch.float32, requires_grad=True - ) - sync_func = torch.compile( - sync_float8_amax_and_scale_history, mode="reduce-overhead", fullgraph=True - ) - fp8_layers = get_float8_layers(my_module) - my_module(inpt) - sync_func(my_module, fp8_layers) - - assert "skipping cudagraphs due to mutaton on input" not in stderr[0] - - @unittest.skipIf( not is_sm_at_least_89(), "CUDA not available", @@ -475,70 +392,5 @@ def test_dynamic_scale_numeric_parity( assert torch.equal(float8_eager._data, float8_compile._data) -@unittest.skipIf( - not is_sm_at_least_89() or not is_fbcode(), - "CUDA with float8 support not available; or not on fbcode (the test needs be run with the latest pytorch package)", -) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) -def test_delayed_scaling_pattern_replacement(dtype: torch.dtype): - from torch._inductor import config as inductor_config - from torch._inductor import metrics - - inductor_config.loop_ordering_after_fusion = True - - def clear_all(): - metrics.reset() - from torch._inductor.fx_passes.post_grad import ( - pass_patterns as post_grad_patterns_all, - ) - - post_grad_patterns_all[1].clear() - post_grad_patterns_all[1].seen_patterns.clear() - - def compile_and_run_single_layer(): - random.seed(0) - torch.manual_seed(0) - x_shape = (2048, 3072) - linear_dtype = dtype - - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_() - m_ref = nn.Linear(3072, 2048, bias=True, device="cuda", dtype=linear_dtype) - - config = get_test_float8_linear_config( - ScalingType.DELAYED, - ScalingType.DELAYED, - ScalingType.DELAYED, - False, - ) - - config = replace(config, enable_amax_init=False) - - m_fp8 = StatefulFloat8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) - - m_fp8 = torch.compile(m_fp8, backend="inductor", fullgraph=True) - m_ref = torch.compile(m_ref, backend="inductor", fullgraph=True) - - y_fp8 = m_fp8(x) - y_fp8.sum().backward() - - return m_fp8.weight.grad - - clear_all() - ref_output = compile_and_run_single_layer() - ref_count_kernel = metrics.generated_kernel_count - - clear_all() - _prototype_register_float8_delayed_scaling_inductor_passes() - new_output = compile_and_run_single_layer() - new_count_kernel = metrics.generated_kernel_count - - torch.equal(ref_output, new_output) - # With the pattern replacement workaround, amax reduction kernels for the 3 tensors (weight, activation, gradient) are fused. - assert ref_count_kernel == new_count_kernel + 3 - - if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/float8/test_fsdp.py b/test/float8/test_fsdp.py index 863256dc35..3017c8b539 100644 --- a/test/float8/test_fsdp.py +++ b/test/float8/test_fsdp.py @@ -35,11 +35,9 @@ FullyShardedDataParallel as FSDP, ) -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import Float8LinearConfig from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_utils import compute_error @@ -77,19 +75,13 @@ def get_model(K, N, base_dtype=torch.float32): def fsdp_main(rank, world_size, args): setup(rank, world_size) torch.cuda.set_device(rank) + print("args", args) - emulate, base_dtype, compile, use_weight_dynamic_scaling = args + emulate, base_dtype, compile = args model = get_model(K, N, base_dtype=base_dtype).to(rank) model_fp8 = copy.deepcopy(model) - scaling_type_weight = ( - ScalingType.DYNAMIC if use_weight_dynamic_scaling else ScalingType.DELAYED - ) - config = Float8LinearConfig( - cast_config_weight=CastConfig(scaling_type=scaling_type_weight), - # TODO(future): delete this arg as it's always False - emulate=False, - ) + config = Float8LinearConfig() # Note: we only iterate over `scaling_type_weight` because FSDP only interacts # with weights. @@ -110,6 +102,7 @@ def fsdp_main(rank, world_size, args): # Note: we need two different inputs to properly measure the impact of # delayed scaling, before the first input uses dynamic scaling to # populate the buffers + # TODO(future PR): delete ^, since we deleted delayed scaling ref_input_global = [ torch.randn(B, M, K).cuda().to(base_dtype), torch.randn(B, M, K).cuda().to(base_dtype), @@ -133,16 +126,10 @@ def fsdp_main(rank, world_size, args): ref_grad_global[idx][bsz_local_start:bsz_local_end].to(rank) ) - sync_float8_func = sync_float8_amax_and_scale_history - if compile: - sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) - def forward_backward(model, optim, is_fp8, i): optim.zero_grad() y_local = model(ref_input_local[i]) y_local.backward(ref_grad_local[i]) - if is_fp8 and linear_requires_sync(config): - sync_float8_func(model) optim.step() return y_local @@ -193,7 +180,7 @@ def forward_backward(model, optim, is_fp8, i): cleanup() -def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False): +def run(compile_fsdp: bool = False): base_dtype = torch.bfloat16 emulate = False @@ -207,7 +194,7 @@ def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False): emulate = True WORLD_SIZE = torch.cuda.device_count() - args = (emulate, base_dtype, compile_fsdp, use_weight_dynamic_scaling) + args = (emulate, base_dtype, compile_fsdp) mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) diff --git a/test/float8/test_fsdp.sh b/test/float8/test_fsdp.sh index 3ff19d917d..6f135a2e76 100755 --- a/test/float8/test_fsdp.sh +++ b/test/float8/test_fsdp.sh @@ -4,12 +4,12 @@ set -e launch() { - echo "launching compile_fsdp $COMPILE, use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING" + echo "launching compile_fsdp $COMPILE" # the NCCL_DEBUG setting is to avoid log spew # the CUDA_VISIBLE_DEVICES setting is for easy debugging NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp.py \ - --compile_fsdp $COMPILE --use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING + --compile_fsdp $COMPILE echo "✅ All Tests Passed ✅" } @@ -19,10 +19,5 @@ if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; exit fi -# COMPILE, USE_WEIGHT_DYNAMIC_SCALING -for i in False,False False,True True,False True,True -do - IFS=","; set -- $i; - COMPILE=$1; USE_WEIGHT_DYNAMIC_SCALING=$2 - launch -done +COMPILE=False launch +COMPILE=True launch diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index fbe5c9b508..8e35c13506 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -101,7 +101,6 @@ def test_transformer_parity(self): "precompute": [False, True], "scaling_type_weight": [ ScalingType.DYNAMIC, - ScalingType.DELAYED, ], "compile_transformer_block": [False, True], "dtype": [torch.float32, torch.bfloat16], @@ -119,8 +118,6 @@ def _test_transformer_parity( ): if not enable_fsdp_float8_all_gather and precompute: return - elif scaling_type_weight is ScalingType.DELAYED and precompute: - return # NOTE: Weight-tying does not compose with fp8 all-gather because the # embedding weight and output linear weight are tied but only the @@ -462,16 +459,10 @@ def test_fp32_fp8_single_module_parity(self): """ choices = itertools.product( [False, True], - [ScalingType.DYNAMIC, ScalingType.DELAYED, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) for enable_fsdp_float8_all_gather, scaling_type_weight in choices: - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig(scaling_type=scaling_type_weight) float8_linear_config1 = Float8LinearConfig( enable_fsdp_float8_all_gather=False, @@ -514,7 +505,7 @@ def test_fp32_fp8_multi_module_parity(self): """ choices = itertools.product( [False, True], - [ScalingType.DYNAMIC, ScalingType.DELAYED], + [ScalingType.DYNAMIC], ) for enable_fsdp_float8_all_gather, scaling_type_weight in choices: float8_linear_config1 = Float8LinearConfig( @@ -584,26 +575,6 @@ def test_bf16_mp_fp8_dynamic_multi_parity(self): self.get_local_inp(torch.bfloat16), ) - @unittest.skipIf(not TEST_CUDA, "no cuda") - def test_delayed_scaling_inplace_update(self): - """ - Verify that `WeightWithDelayedFloat8CastTensor` updates buffers inplace - """ - module = self.init_single_module() - float8_linear_config = Float8LinearConfig( - enable_fsdp_float8_all_gather=True, - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - ) - m_fp8 = convert_to_float8_training( - module, - config=float8_linear_config, - ) - - fp8_amax_weight_old = m_fp8.fp8_amax_weight.clone().detach() - dummy_mesh = None - data, scale = m_fp8.weight.fsdp_pre_all_gather(dummy_mesh) - self.assertNotEqual(fp8_amax_weight_old.item(), m_fp8.fp8_amax_weight.item()) - if __name__ == "__main__": run_tests() diff --git a/test/float8/test_fsdp_compile.py b/test/float8/test_fsdp_compile.py index 1d95801f67..a78a30925c 100644 --- a/test/float8/test_fsdp_compile.py +++ b/test/float8/test_fsdp_compile.py @@ -26,10 +26,8 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torchao.float8 import Float8LinearConfig -from torchao.float8.config import CastConfig, ScalingType from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - sync_float8_amax_and_scale_history, ) torch.manual_seed(0) @@ -63,10 +61,6 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): # https://gist.github.com/vkuzo/ed8e168fd9f7463f1fce34301334ab55 # to get around this, we can disable amax init config = Float8LinearConfig( - enable_amax_init=False, - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), emulate=emulate, ) @@ -102,7 +96,6 @@ def fsdp_main(rank, world_size, args): optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size) input_local = torch.randn(B, M, K, N, device="cuda") - sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) model = torch.compile(model) @@ -111,7 +104,6 @@ def fsdp_main(rank, world_size, args): with torch.autocast("cuda"): y_local = model(input_local) y_local.sum().backward() - sync_float8_func(model) optimizer.step() print("done!") diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 01e4cbb20d..f25c876189 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -31,8 +31,6 @@ ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_utils import IS_ROCM, compute_error from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -115,7 +113,7 @@ def _test_impl(self, config: Float8LinearConfig) -> None: # Note: you need two different inputs to properly test numerics # of delayed scaling, because the first time around the initialization # logic of delayed scaling behaves as dynamic scaling - # TODO(future): also make unit tests do this properly + # TODO(future PR): delete ^, since we deleted delayed scaling shape = (1, 8192, 4096) data1 = torch.randn(*shape, device="cuda", dtype=data_dtype) data2 = torch.randn(*shape, device="cuda", dtype=data_dtype) @@ -127,36 +125,21 @@ def _test_impl(self, config: Float8LinearConfig) -> None: model_ref_out = model_ref(data2) model_ref_out.sum().backward() - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(model_fp8) model_fp8(data1).sum().backward() # zero out grads without stepping, since we just want to compare grads # of the second datum optim_fp8.zero_grad() - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(model_fp8) model_fp8_out = model_fp8(data2) model_fp8_out.sum().backward() out_sqnr = compute_error(model_ref_out, model_fp8_out) - any_static_scaling = ( - config.cast_config_input.scaling_type is ScalingType.STATIC - or config.cast_config_weight.scaling_type is ScalingType.STATIC - or config.cast_config_grad_output.scaling_type is ScalingType.STATIC - ) - if any_static_scaling: - assert out_sqnr > 10.0 - else: - assert out_sqnr > 20.0 + assert out_sqnr > 20.0 ref_name_to_grad = { name: param.grad for name, param in model_ref.named_parameters() } - if any_static_scaling: - grad_sqnr_threshold = 10.0 - else: - grad_sqnr_threshold = 20.0 + grad_sqnr_threshold = 20.0 for name, param in model_fp8.named_parameters(): ref_grad = ref_name_to_grad[name] @@ -166,15 +149,15 @@ def _test_impl(self, config: Float8LinearConfig) -> None: @pytest.mark.parametrize( "scaling_type_input", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.skipif( not is_sm_at_least_89(), reason="requires SM89 compatible machine" diff --git a/torchao/float8/README.md b/torchao/float8/README.md index 4dbc556d83..65105d1f89 100644 --- a/torchao/float8/README.md +++ b/torchao/float8/README.md @@ -15,8 +15,6 @@ throughput speedups of up to 1.5x on 128 GPU LLaMa 3 70B pretraining jobs. # Single GPU User API -We provide three per-tensor scaling strategies: dynamic, delayed and static. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`). - ## float8 linear with dynamic tensorwise scaling This is the default recipe, with a good balance of performance and accuracy. @@ -114,67 +112,6 @@ for _ in range(10): optimizer.step() ``` -## float8 linear with delayed scaling - -:warning: We plan to deprecate delayed scaling in a future release, see https://github.com/pytorch/ao/issues/1680 for more details. - -This is theoretically the most performant recipe as it minimizes memory reads. - -```python -import torch -import torch.nn as nn -from torchao.float8 import ( - convert_to_float8_training, - sync_float8_amax_and_scale_history, - Float8LinearConfig, - ScalingType, - CastConfig, -) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") - -# Recommended: enable additional torchinductor passes to improve the performance of delayed scaling -torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes() - -# create model and sample input -m = nn.Sequential( - nn.Linear(2048, 4096), - nn.Linear(4096, 128), -).bfloat16().cuda() -x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16) -optimizer = torch.optim.SGD(m.parameters(), lr=0.1) - -# configure delayed scaling -config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), -) - -# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior -convert_to_float8_training(m, config=config) - -# enable torch.compile for competitive performance -m = torch.compile(m) - -# toy training loop -for _ in range(10): - optimizer.zero_grad() - y = m(x) - y.sum().backward() - - # Specific to delayed scaling: separate step to sync scales/amaxes. - # On the first call, this function also sets the `is_amax_initialized` flag to - # mark the amax and scale buffers as initialized. - # Make sure you run this after every model forward+backward pass. - # In the future, this may move to a context manager. - sync_float8_amax_and_scale_history(m) - - optimizer.step() -``` - # Multi GPU User API We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html), @@ -226,10 +163,6 @@ There are three observations we can make about the formula above: For small shapes, a combination of (2) and (3) leads to speedup < 1. For medium shapes, (1) and (3) are of similar magnitude and the speedup depends on M, K, N and framework and compiler behavior. For large shapes, (1) leads to speedup > 1. -## Scaling type vs speedup - -Delayed scaling is theoretically faster than dynamic scaling because of reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling without workarounds. We have a prototype workaround (API subject to change) with the `torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes()` API to improve delayed scaling performance. - ## torch.compile behavior vs speedup There are a couple of limitations in how torch.compile generates float8 scaling and casting kernels (see the performance section of https://github.com/pytorch/ao/issues/556). As the limitations get resolved, we expect to reach improved performance. diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 258db53be0..18ef82a507 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -6,15 +6,12 @@ # Lets define a few top level things here from torchao.float8.config import ( CastConfig, - DelayedScalingConfig, Float8GemmConfig, Float8LinearConfig, ScalingType, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_tensor import ( Float8Tensor, @@ -23,11 +20,7 @@ ScaledMMConfig, ) from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp -from torchao.float8.inductor_utils import ( - _prototype_register_float8_delayed_scaling_inductor_passes, -) from torchao.float8.inference import Float8MMConfig -from torchao.float8.stateful_float8_linear import WeightWithDelayedFloat8CastTensor from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 if TORCH_VERSION_AT_LEAST_2_5: @@ -41,22 +34,17 @@ GemmInputRole, LinearMMConfig, Float8MMConfig, - WeightWithDelayedFloat8CastTensor, ] ) __all__ = [ # configuration - "DelayedScalingConfig", "ScalingType", "Float8GemmConfig", "Float8LinearConfig", "CastConfig", # top level UX "convert_to_float8_training", - "linear_requires_sync", - "sync_float8_amax_and_scale_history", "precompute_float8_dynamic_scale_for_fsdp", - "_prototype_register_float8_delayed_scaling_inductor_passes", # note: Float8Tensor and Float8Linear are not public APIs ] diff --git a/torchao/float8/config.py b/torchao/float8/config.py index fa03d55b11..d2998d890f 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -15,20 +15,14 @@ class ScalingType(enum.Enum): - DELAYED = "delayed" DYNAMIC = "dynamic" - STATIC = "static" # ScalingType.DISABLED means "skip scaling for this tensor, leave it in # its original precision. DISABLED = "disabled" def short_str(self): - if self is ScalingType.DELAYED: - return "del" - elif self is ScalingType.DYNAMIC: + if self is ScalingType.DYNAMIC: return "dyn" - elif self is ScalingType.STATIC: - return "sta" else: assert self is ScalingType.DISABLED return "dis" @@ -90,7 +84,6 @@ class CastConfig: scaling_type: ScalingType = ScalingType.DYNAMIC scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE - static_scale: Optional[torch.Tensor] = None target_dtype: Optional[torch.dtype] = None def short_str(self): @@ -98,10 +91,6 @@ def short_str(self): return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}_{dtype}" def __post_init__(self): - if self.scaling_type is ScalingType.STATIC: - assert ( - self.static_scale is not None - ), "static_scale must be specified for static scaling" if self.scaling_granularity is ScalingGranularity.AXISWISE: assert ( self.scaling_type is ScalingType.DYNAMIC @@ -111,30 +100,6 @@ def __post_init__(self): ), "must specify a 8-bit floating-point dtype" -@dataclass(frozen=True) -class DelayedScalingConfig: - """ - Configuration for delayed scaling. - - Note: for now, `history_len` values must be the same for all layers in the - model using delayed scaling. - - TODO(future): serialization for recipes - """ - - # Controls the history length of amax buffers - history_len: int = 16 - - # Controls the way to calculate current scale from amax history - # TODO(future): add other functions as needed, hardcoded or user defined - scale_fn_name: str = "max" - - def __post_init__(self): - assert ( - self.scale_fn_name == "max" - ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now." - - @dataclass(frozen=True) class Float8GemmConfig: """ @@ -215,14 +180,6 @@ class Float8LinearConfig: # Per-linear configuration # - # This configuration option is deprecated and no longer has an effect. It may - # be removed in a future release. - enable_amax_init: bool = True - - # This configuration option is deprecated and no longer has an effect. It may - # be removed in a future release. - enable_pre_and_post_forward: bool = True - # If True, then uses a tensor subclass for the float8 linear module's weight that # implements pre/post-all-gather methods to do float8 all-gather with FSDP2. enable_fsdp_float8_all_gather: bool = False @@ -236,13 +193,6 @@ class Float8LinearConfig: # If True, emulation is used instead of hardware accelerated gemm emulate: bool = False - # Configuration for delayed scaling - # Note: this is actually applied per-tensor, but only using the same - # configuration for all tensors and layers in the model is currently - # supported. If in the future we add support for a more fine grained - # configuration, this field may move to per-tensor configs. - delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig() - # If the option is enabled, fp8_weight will always be re-computed in backward. # It's recommended to enable this flag when using FSDP. # Otherwise, the entire fp8_weight, instead of the sharded weight may be saved. @@ -336,16 +286,6 @@ def __post_init__(self): "When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd." ) - # Future deprecation warning for delayed scaling - if ( - self.cast_config_input.scaling_type != ScalingType.DYNAMIC - or self.cast_config_weight.scaling_type != ScalingType.DYNAMIC - or self.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC - ): - logger.warning( - "Note: delayed and static scaling will be deprecated in a future release of torchao. Please see https://github.com/pytorch/ao/issues/1680 for more details." - ) - @staticmethod def from_recipe_name( recipe_name: Union[Float8LinearRecipeName, str], diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index d822d33042..9d5cdd3242 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -64,8 +64,6 @@ class matmul_with_hp_or_float8_args(torch.autograd.Function): * if the arguments are in high precision, they are cast to float8 according to the specified config * if the arguments are in float8, we assume the cast honored the config - - Only supports dynamic scaling, does not support delayed/static scaling. """ @staticmethod @@ -259,8 +257,7 @@ class Float8Linear(torch.nn.Linear): inside of this repository. Please file an issue if you would benefit from this being a public API. - A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks - scales in way friendly to delayed scaling. + A wrapper around a `torch.nn.Linear` module which does fp8 compute. """ def __init__(self, *args, **kwargs): @@ -411,6 +408,7 @@ def from_float( # 1. weight needs to be on the correct device to create the buffers # 2. buffers need to be already created for the delayed scaling version # of the weight wrapper to be initialized + # TODO(future PR): see if we can simplify ^ now that delayed scaling is deleted if config.enable_fsdp_float8_all_gather: assert config.cast_config_weight.scaling_type is ScalingType.DYNAMIC new_mod.weight = torch.nn.Parameter( diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index 3649b741cc..db9889567f 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -6,56 +6,15 @@ import logging from typing import Callable, Optional -import torch -import torch.distributed as dist import torch.nn as nn -from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce -from torchao.float8.config import Float8LinearConfig, ScalingType +from torchao.float8.config import Float8LinearConfig from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_utils import ( - amax_history_to_scale_stack, - config_has_stateful_scaling, -) -from torchao.float8.stateful_float8_linear import StatefulFloat8Linear log = logging.getLogger(__name__) log.addHandler(logging.NullHandler()) -def linear_requires_sync(config: Float8LinearConfig): - """Returns whether the given linear_type requires sync before forward.""" - return any( - [ - config.cast_config_input.scaling_type is ScalingType.DELAYED, - config.cast_config_weight.scaling_type is ScalingType.DELAYED, - config.cast_config_grad_output.scaling_type is ScalingType.DELAYED, - ] - ) - - -def _update_history_stack( - new_amax: torch.Tensor, amax_history_stack: torch.Tensor -) -> torch.Tensor: - """ - Updates `amax_history` (the last N cur_amax values) inplace with the value - of `new_amax`. - - Args: - new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1) - amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length) - """ - assert ( - amax_history_stack.dim() == 2 - ), f"Expected amat_history_stack to be 2D, got {amax_history_stack.shape()}" - assert ( - new_amax.size(0) == amax_history_stack.size(0) - ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}" - new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1) - new_amax_history_stack[:, 0] = new_amax.squeeze(-1) - amax_history_stack.copy_(new_amax_history_stack) - - def swap_linear_layers( module: nn.Module, from_float_func: Callable[[nn.Linear], nn.Linear], @@ -144,196 +103,13 @@ def convert_to_float8_training( if config is None: config = Float8LinearConfig() - if config_has_stateful_scaling(config): - from_float = lambda m: StatefulFloat8Linear.from_float( - m, - config=config, - ) - else: - from_float = lambda m: Float8Linear.from_float( - m, - config=config, - ) + from_float = lambda m: Float8Linear.from_float( + m, + config=config, + ) return swap_linear_layers( module, from_float, module_filter_fn=module_filter_fn, ) - - -def get_float8_layers(model: torch.nn.Module): - """Iterates through the model and returns all the Float8Linear layers. - Args: - model (torch.nn.Module): The model to look for Float8Linear layers in. - """ - - # Get all fp8 layers and tensors - fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)] - if not torch.compiler.is_compiling(): - for layer in fp8_layers: - for buf in layer.buffers(): - torch._dynamo.mark_static_address(buf, guard=True) - return fp8_layers - - -@torch.no_grad() -def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None: - """ - Manages the float8 amax and scale bookkeeping. In detail, it does the - following: - 1. in distributed contexts, syncs amax values across workers for activations and gradients - 2. adds the `amax` values to history - 3. calculates the scales to be used for next iteration - 4. sets the `amax_and_scale_synced` flag on the Float8Linear modules - to signal that they have been synced - - TODO(future): design the UX for this (context manager, etc) - - PERFORMANCE NOTE: - When you can, it is much more efficient to call get_float8_layers once at - the beginning of the training loop and pass the result to this function. - Because of how this interacts with torch.compile - - Args: - model (torch.nn.Module): The model to track amaxes for - fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored, - and we loop over all fp8_layers to sync and update amax scale histories. - Users can use get_float8_layers to get all fp8 layers. - """ - # TODO(future): consider adding a flag to control setting the `is_amax_initialized` - # flag only on the first iteration. - - if fp8_layers is None: - fp8_layers = get_float8_layers(model) - - if len(fp8_layers) == 0: - log.warn( - "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers" - ) - return - - def inner_func(): - """Why do we have this inner_function? - - There are two portions of the outer sync_function that cause graph_breaks: - 1. The `get_float8_layers` call can cause graph breaks if the user did not pass - in the fp8_layers. - 2. At the end of syncing all the amaxes and scales we set the attr on the module - signaling that we have synced the amaxes and scales and the next forward can be run. - # TODO Maybe we should remove this safety check to remove the graph break? - - By having this inner function, we can ensure that although the outer function may cause graph breaks - the inner function will not. - """ - # Loop over all fp8 layers and grab the needed tensors - fp8_amax_input_tensor_list = [None] * len(fp8_layers) - fp8_amax_weight_tensor_list = [None] * len(fp8_layers) - fp8_amax_grad_output_tensor_list = [None] * len(fp8_layers) - - fp8_input_amax_history_stack = [None] * len(fp8_layers) - fp8_weight_amax_history_stack = [None] * len(fp8_layers) - fp8_grad_output_amax_history_stack = [None] * len(fp8_layers) - - input_dtypes = set() - weight_dtypes = set() - grad_output_dtypes = set() - scale_fn_recipes = set() - - for idx, child in enumerate(fp8_layers): - fp8_amax_input_tensor_list[idx] = child.fp8_amax_input - fp8_amax_weight_tensor_list[idx] = child.fp8_amax_weight - fp8_amax_grad_output_tensor_list[idx] = child.fp8_amax_grad_output - - fp8_input_amax_history_stack[idx] = child.fp8_amax_history_input - fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight - fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output - - input_dtypes.add(child.config.cast_config_input.target_dtype) - weight_dtypes.add(child.config.cast_config_weight.target_dtype) - grad_output_dtypes.add(child.config.cast_config_grad_output.target_dtype) - scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name) - - (input_dtype,) = input_dtypes - (weight_dtype,) = weight_dtypes - (grad_output_dtype,) = grad_output_dtypes - - if len(scale_fn_recipes) != 1: - raise ValueError( - f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}" - ) - scale_fn_recipe = next(iter(scale_fn_recipes)) - - assert ( - len(fp8_amax_input_tensor_list) - == len(fp8_amax_weight_tensor_list) - == len(fp8_amax_grad_output_tensor_list) - ), "Mismatched lengths of amax tensors." - - if dist.is_initialized(): - all_amax_tensors = torch.cat( - fp8_amax_input_tensor_list - + fp8_amax_weight_tensor_list - + fp8_amax_grad_output_tensor_list - ) - all_reduced_amax_tensor = all_reduce( - all_amax_tensors, "MAX", list(range(dist.get_world_size())) - ) - if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor): - all_reduced_amax_tensor = all_reduced_amax_tensor.wait() - - ( - reduced_fp8_amax_input_tensor, - reduced_fp8_amax_weight_tensor, - reduced_fp8_amax_grad_output_tensor, - ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_input_tensor_list)) - - for idx, child in enumerate(fp8_layers): - child.fp8_amax_input.copy_(reduced_fp8_amax_input_tensor[idx]) - child.fp8_amax_weight.copy_(reduced_fp8_amax_weight_tensor[idx]) - child.fp8_amax_grad_output.copy_( - reduced_fp8_amax_grad_output_tensor[idx] - ) - - # We create two stacked tensor groups, one for the amax history and one for the current scales - fp8_amax_input_tensors = torch.vstack(fp8_amax_input_tensor_list) - fp8_amax_weight_tensors = torch.vstack(fp8_amax_weight_tensor_list) - fp8_amax_grad_output_tensors = torch.vstack(fp8_amax_grad_output_tensor_list) - - fp8_input_amax_history_stack = torch.vstack(fp8_input_amax_history_stack) - fp8_weight_amax_history_stack = torch.vstack(fp8_weight_amax_history_stack) - fp8_grad_output_amax_history_stack = torch.vstack( - fp8_grad_output_amax_history_stack - ) - - # Update the history stacks with the new amax values - _update_history_stack(fp8_amax_input_tensors, fp8_input_amax_history_stack) - _update_history_stack(fp8_amax_weight_tensors, fp8_weight_amax_history_stack) - _update_history_stack( - fp8_amax_grad_output_tensors, fp8_grad_output_amax_history_stack - ) - - # Calculate the new scales from the updated history stacks - new_input_scales = amax_history_to_scale_stack( - fp8_input_amax_history_stack, input_dtype, scale_fn_recipe - ) - new_weight_scales = amax_history_to_scale_stack( - fp8_weight_amax_history_stack, weight_dtype, scale_fn_recipe - ) - new_grad_output_scales = amax_history_to_scale_stack( - fp8_grad_output_amax_history_stack, grad_output_dtype, scale_fn_recipe - ) - - # Iterate through the layers and update the scales - for idx, child in enumerate(fp8_layers): - child.fp8_scale_input.copy_(new_input_scales[idx]) - child.fp8_scale_weight.copy_(new_weight_scales[idx]) - child.fp8_scale_grad_output.copy_(new_grad_output_scales[idx]) - - # This allows for the compile to succeed on the inner func and fail on the graph breaks - # at the beginning and and of syncing - inner_func() - - for child in fp8_layers: - # Set a flag to signal that initialization is done - child.is_amax_initialized = True diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index b96c7a9b58..31f2db6b4e 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -21,8 +21,6 @@ hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import ( - amax_history_to_scale, - tensor_to_amax, tensor_to_scale, ) @@ -74,72 +72,6 @@ def hp_tensor_to_float8_dynamic( ) -def hp_tensor_to_float8_delayed( - hp_tensor: torch.Tensor, - s: torch.Tensor, - float8_dtype: torch.dtype, - amax_buffer: torch.Tensor, - linear_mm_config: Optional[LinearMMConfig] = None, - gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, -) -> Float8Tensor: - """ - Given a high precision tensor `hp_tensor` and relevant metadata, scales it using - delayed scaling and returns a `Float8Tensor` of the result. Specifically: - 1. calculates max(abs(hp_tensor)) and stores the result in `amax_buffer`, inplace - 2. scales `hp_tensor` by `s` and returns the result wrapped in Float8Tensor - - Args: - hp_tensor: the tensor to convert - s: the scale to use to convert the tensor - float8_dtype: the float8 dtype to use - amax_buffer: the buffer to modify inplace with max(abs(hp_tensor)) - linear_mm_config: Defines the configuration for the scaled_mm for - the 3 fwd/bwd gemms of linear - gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in - the 3 fwd/bwd gemms of linear - """ - amax_buffer.fill_(tensor_to_amax(hp_tensor)) - return hp_tensor_and_scale_to_float8( - hp_tensor, - s, - float8_dtype, - linear_mm_config, - gemm_input_role, - ) - - -def hp_tensor_to_float8_static( - hp_tensor: torch.Tensor, - scale: torch.Tensor, - float8_dtype: torch.dtype, - linear_mm_config: LinearMMConfig, - gemm_input_role: GemmInputRole = GemmInputRole.INPUT, -) -> Float8Tensor: - """ - Given a high precision tensor `hp_tensor` and a scale, - scales `hp_tensor` returns a `Float8Tensor` of the result. - - Args: - hp_tensor: the tensor to convert - scale: the scale to use - float8_dtype: the float8 dtype to use - linear_mm_config: Defines the configuration for the scaled_mm for - the 3 fwd/bwd gemms of linear - gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in - the 3 fwd/bwd gemms of linear - """ - if tensor_already_casted_to_fp8(hp_tensor): - return hp_tensor - - return hp_tensor_and_scale_to_float8( - hp_tensor, - scale, - float8_dtype, - linear_mm_config, - gemm_input_role, - ) - - def get_maybe_axiswise_dim( axiswise_dim: int, scaling_granularity: ScalingGranularity, @@ -155,95 +87,6 @@ def get_maybe_axiswise_dim( return None -def _maybe_initialize_amaxes_scales_for_float8_cast( - x, - cur_amax, - amax_history, - scale, - scale_fn_name, - float8_dtype, - is_initialized, - reduce_amax, -): - """ - If x is about to be cast to `float8` and the amax buffers are not initialized, - initializes them inplace. - """ - if is_initialized: - return - with torch.no_grad(): - # Note: we need to enable distributed reduction here in order - # to match numerics between single GPU and multi GPU code for - # activations and gradients - new_amax = tensor_to_amax(x, reduce_amax=reduce_amax) - cur_amax.fill_(new_amax) - amax_history[0] = new_amax - new_scale = amax_history_to_scale(amax_history, float8_dtype, scale_fn_name) - scale.copy_(new_scale) - - -@torch._dynamo.allow_in_graph -class NoopFwToFloat8BwDelayed(torch.autograd.Function): - """ - Forward: no-op - Backward: convert to float8_e5m2 with delayed scaling, initialize if needed - """ - - @staticmethod - def forward( - ctx, - tensor, - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - scale_fn_name, - is_amax_initialized, - linear_mm_config: LinearMMConfig, - target_dtype: torch.dtype, - ): - ctx.save_for_backward( - fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output - ) - ctx.scale_fn_name = scale_fn_name - ctx.is_amax_initialized = is_amax_initialized - ctx.linear_mm_config = linear_mm_config - ctx.target_dtype = target_dtype - return tensor - - @staticmethod - def backward(ctx, go): - ( - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - ) = ctx.saved_tensors - scale_fn_name = ctx.scale_fn_name - is_amax_initialized = ctx.is_amax_initialized - - _maybe_initialize_amaxes_scales_for_float8_cast( - go, - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - scale_fn_name, - ctx.target_dtype, - is_amax_initialized, - reduce_amax=True, - ) - - fp8_amax_grad_output.fill_(tensor_to_amax(go)) - - res = hp_tensor_and_scale_to_float8( - go, - fp8_scale_grad_output, - ctx.target_dtype, - ctx.linear_mm_config, - GemmInputRole.GRAD_OUTPUT, - ) - empty_grads = None, None, None, None, None, None, None - return res, *empty_grads - - @torch._dynamo.allow_in_graph class NoopFwToFloat8BwDynamic(torch.autograd.Function): """ @@ -275,38 +118,3 @@ def backward(ctx, gradY): GemmInputRole.GRAD_OUTPUT, ) return fp8_tensor, None, None - - -@torch._dynamo.allow_in_graph -class NoopFwToFloat8BwStatic(torch.autograd.Function): - """ - Forward: no-op - Backward: convert to float8_e5m2 with static scaling - """ - - @staticmethod - def forward( - ctx, - tensor, - scale, - linear_mm_config: LinearMMConfig, - target_dtype: torch.dtype, - ): - ctx.save_for_backward(scale) - ctx.linear_mm_config = linear_mm_config - ctx.target_dtype = target_dtype - return tensor - - @staticmethod - def backward(ctx, gradY): - if tensor_already_casted_to_fp8(gradY): - return gradY, None, None, None - (gradY_scale,) = ctx.saved_tensors - fp8_tensor = hp_tensor_and_scale_to_float8( - gradY, - gradY_scale, - ctx.target_dtype, - ctx.linear_mm_config, - GemmInputRole.GRAD_OUTPUT, - ) - return fp8_tensor, None, None, None diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index a52b38b6bf..abc74e3ff6 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -27,8 +27,7 @@ def _float8_linear_supports_float8_allgather(m): - # TODO(future): add support for delayed scaling for activations - # and gradients + # TODO(future PR): also gate this by granularity return ( m.scaling_type_input == ScalingType.DYNAMIC and m.scaling_type_grad_output == ScalingType.DYNAMIC diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 926b97edb8..625fb29235 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -4,13 +4,13 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Literal, Optional, Tuple, Union +from typing import Iterable, Optional, Tuple, Union import torch import torch.distributed as dist from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce -from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType +from torchao.float8.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -53,44 +53,6 @@ def amax_to_scale( return res -@torch.no_grad() -def amax_history_to_scale( - amax_history: torch.Tensor, - float8_dtype: torch.Tensor, - history_to_scale_fn_type: Literal["max"], -): - """Takes in a history of amax values and returns a scale tensor. - Args: - amax_history: A tensor containing the history of amax values. - float8_dtype: The float8 dtype. - history_to_scale_fn_type: The type of function to use to convert the history to a scale. - """ - if history_to_scale_fn_type == "max": - amax = torch.max(amax_history) - return amax_to_scale(amax, float8_dtype) - raise NotImplementedError() - - -@torch.no_grad() -def amax_history_to_scale_stack( - amax_history: torch.Tensor, - float8_dtype: torch.dtype, - history_to_scale_fn_type: Literal["max"], -) -> torch.Tensor: - """Takes in a stack of amax_history tensors and returns a scale tensor. - Args: - amax_history: A 2D tensor containing a stack of amax histories. - float8_dtype: The float8 dtype. - history_to_scale_fn_type: The type of function to use to convert the history to a scale. - """ - if history_to_scale_fn_type == "max": - amax_stack = torch.max(amax_history, dim=1).values - return amax_to_scale(amax_stack, float8_dtype) - raise NotImplementedError( - f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}" - ) - - @torch.no_grad() def tensor_to_amax( x: torch.Tensor, @@ -274,17 +236,6 @@ def pad_tensor_for_matmul( return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1)) -def config_has_stateful_scaling(config: Float8LinearConfig) -> bool: - """ - Returns True if `config` has any delayed or static scaling, and False otherwise. - """ - return ( - config.cast_config_input.scaling_type != ScalingType.DYNAMIC - or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC - or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC - ) - - def _round_scale_down_to_power_of_2(scale: torch.Tensor): assert scale.dtype == torch.float32, "scale must be float32 tensor" return torch.exp2(torch.floor(torch.log2(scale))) diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index f246879a7c..7b24dc2b53 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -13,8 +13,6 @@ from torch._prims_common import suggest_memory_format from torchao.float8.float8_scaling_utils import ( - _maybe_initialize_amaxes_scales_for_float8_cast, - hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( @@ -39,14 +37,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: """ from torch.distributed._tensor import DTensor - from torchao.float8.config import ScalingType from torchao.float8.float8_linear import Float8Linear - if any( - isinstance(m, Float8Linear) and m.scaling_type_weight is ScalingType.DELAYED - for m in module.modules() - ): - raise NotImplementedError("Only supports dynamic scaling") float8_linears: List[Float8Linear] = [ m for m in module.modules() @@ -274,331 +266,3 @@ def fsdp_post_all_gather( self._linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, ), (data,) - - -class WeightWithDelayedFloat8CastTensor(torch.Tensor): - @staticmethod - def __new__( - cls, - tensor: torch.Tensor, - amax_buffer: torch.Tensor, - amax_history_buffer: torch.Tensor, - scale_buffer: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - is_amax_initialized: bool, - ): - return torch.Tensor._make_wrapper_subclass( - cls, - tensor.size(), - strides=tensor.stride(), - storage_offset=tensor.storage_offset(), - memory_format=suggest_memory_format(tensor), - dtype=tensor.dtype, - layout=tensor.layout, - device=tensor.device, - pin_memory=tensor.is_pinned(), - requires_grad=tensor.requires_grad, - ) - - def __init__( - self, - tensor: torch.Tensor, - amax_buffer: torch.Tensor, - amax_history_buffer: torch.Tensor, - scale_buffer: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - is_amax_initialized: bool, - ): - self._tensor = tensor - self._amax_buffer = amax_buffer - self._amax_history_buffer = amax_history_buffer - self._scale_buffer = scale_buffer - self._linear_mm_config = linear_mm_config - self._dtype = dtype - - # Note: is_amax_initialized is not a buffer to avoid data dependent - # control flow visible to dynamo - # TODO(future PR): add serialization for this flag - self.is_amax_initialized = is_amax_initialized - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - if func == torch.ops.aten.detach.default: - return WeightWithDelayedFloat8CastTensor( - args[0]._tensor, - args[0]._amax_buffer, - args[0]._amax_history_buffer, - args[0]._scale_buffer, - args[0]._linear_mm_config, - args[0]._dtype, - args[0].is_amax_initialized, - ) - mm_config: Optional[LinearMMConfig] = None - dtype: Optional[torch.dtype] = None - amax_buffer: Optional[torch.Tensor] = None - amax_history_buffer: Optional[torch.Tensor] = None - scale_buffer: Optional[torch.Tensor] = None - is_amax_initialized: Optional[bool] = None - - def unwrap(t): - nonlocal mm_config - if mm_config is None: - mm_config = t._linear_mm_config - else: - assert t._linear_mm_config == mm_config - nonlocal dtype - if dtype is None: - dtype = t._dtype - else: - assert t._dtype == dtype - nonlocal amax_buffer - if amax_buffer is None: - amax_buffer = t._amax_buffer - nonlocal amax_history_buffer - if amax_history_buffer is None: - amax_history_buffer = t._amax_history_buffer - nonlocal scale_buffer - if scale_buffer is None: - scale_buffer = t._scale_buffer - nonlocal is_amax_initialized - if is_amax_initialized is None: - is_amax_initialized = t.is_amax_initialized - return t._tensor - - args, kwargs = pytree.tree_map_only( - WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {}) - ) - out = func(*args, **kwargs) - if func not in _ops_to_preserve_subclass: - return out - return pytree.tree_map_only( - torch.Tensor, - lambda x: WeightWithDelayedFloat8CastTensor( - x, - amax_buffer, - amax_history_buffer, - scale_buffer, - mm_config, - dtype, - is_amax_initialized, - ), - out, - ) - - def __tensor_flatten__(self): - return ( - [ - "_tensor", - "_amax_buffer", - "_amax_history_buffer", - "_scale_buffer", - ], - { - "mm_config": self._linear_mm_config, - "dtype": self._dtype, - "is_amax_initialized": self.is_amax_initialized, - }, - ) - - @staticmethod - def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): - return WeightWithDelayedFloat8CastTensor( - inner_tensors["_tensor"], - inner_tensors["_amax_buffer"], - inner_tensors["_amax_history_buffer"], - inner_tensors["_scale_buffer"], - metadata["mm_config"], - metadata["dtype"], - metadata["is_amax_initialized"], - ) - - def __repr__(self): - return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config}, dtype={self._dtype})" - - def fsdp_pre_all_gather(self, mesh): - # initialize if needed - # TODO(before land): ensure settings are consistent between Float8Linear and here - if not self.is_amax_initialized: - _maybe_initialize_amaxes_scales_for_float8_cast( - self._tensor, - self._amax_buffer, - self._amax_history_buffer, - self._scale_buffer, - "max", # TODO(before land): read this from parent - self._dtype, - self.is_amax_initialized, - reduce_amax=True, - ) - self.is_amax_initialized = True - - float8_tensor = hp_tensor_to_float8_delayed( - self._tensor, - self._scale_buffer, - self._dtype, - self._amax_buffer, - self._linear_mm_config, - GemmInputRole.WEIGHT, - ) - return (float8_tensor._data,), (float8_tensor._scale,) - - def fsdp_post_all_gather( - self, - all_gather_outputs: Tuple[torch.Tensor, ...], - metadata: Any, - param_dtype: torch.dtype, - *, - out: Optional[torch.Tensor] = None, - ): - (data,) = all_gather_outputs - (scale,) = metadata - if out is not None: - assert isinstance(out, Float8Tensor), f"{type(out)}" - out._scale = scale - return - return Float8Tensor( - data, - scale, - param_dtype, - self._linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - ), (data,) - - -class WeightWithStaticFloat8CastTensor(torch.Tensor): - @staticmethod - def __new__( - cls, - tensor: torch.Tensor, - static_scale: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - ): - return torch.Tensor._make_wrapper_subclass( - cls, - tensor.size(), - strides=tensor.stride(), - storage_offset=tensor.storage_offset(), - memory_format=suggest_memory_format(tensor), - dtype=tensor.dtype, - layout=tensor.layout, - device=tensor.device, - pin_memory=tensor.is_pinned(), - requires_grad=tensor.requires_grad, - ) - - def __init__( - self, - tensor: torch.Tensor, - static_scale: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - ): - self._tensor = tensor - self._static_scale = static_scale - self._linear_mm_config = linear_mm_config - self._dtype = dtype - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - if func == torch.ops.aten.detach.default: - return WeightWithStaticFloat8CastTensor( - args[0]._tensor, - args[0]._static_scale, - args[0]._linear_mm_config, - args[0]._dtype, - ) - static_scale: Optional[torch.Tensor] = None - mm_config: Optional[LinearMMConfig] = None - dtype: Optional[torch.dtype] = None - - def unwrap(t): - nonlocal static_scale - if static_scale is None: - static_scale = t._static_scale - nonlocal mm_config - if mm_config is None: - mm_config = t._linear_mm_config - else: - assert t._linear_mm_config == mm_config - nonlocal dtype - if dtype is None: - dtype = t._dtype - else: - assert t._dtype == dtype - return t._tensor - - args, kwargs = pytree.tree_map_only( - WeightWithStaticFloat8CastTensor, unwrap, (args, kwargs or {}) - ) - out = func(*args, **kwargs) - if func not in _ops_to_preserve_subclass: - return out - return pytree.tree_map_only( - torch.Tensor, - lambda x: WeightWithStaticFloat8CastTensor( - x, static_scale, mm_config, dtype - ), - out, - ) - - def __tensor_flatten__(self): - return ["_tensor", "_static_scale"], { - "mm_config": self._linear_mm_config, - "dtype": self._dtype, - } - - @staticmethod - def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - return WeightWithStaticFloat8CastTensor( - inner_tensors["_tensor"], - inner_tensors["_static_scale"], - flatten_spec["mm_config"], - flatten_spec["dtype"], - ) - - def __repr__(self): - return f"WeightWithStaticFloat8CastTensor(tensor={self._tensor}, static_scale={self._static_scale}, linear_mm_config={self._linear_mm_config}, dtype={self.dtype})" - - def fsdp_pre_all_gather(self, mesh): - float8_tensor = hp_tensor_and_scale_to_float8( - self._tensor, - self._static_scale, - self._dtype, - self._linear_mm_config, - GemmInputRole.WEIGHT, - ) - return (float8_tensor._data,), (float8_tensor._scale,) - - def fsdp_post_all_gather( - self, - all_gather_outputs: Tuple[torch.Tensor, ...], - metadata: Any, - param_dtype: torch.dtype, - *, - out: Optional[torch.Tensor] = None, - ): - (data,) = all_gather_outputs - (scale,) = metadata - if out is not None: - from torch.distributed._tensor import DTensor - - if isinstance(out, Float8Tensor): - out._scale = scale - elif isinstance(out, DTensor) and isinstance( - out._local_tensor, Float8Tensor - ): - out._local_tensor._scale = scale - else: - raise RuntimeError( - f"out must be a Float8Tensor or DTensor(_local_tensor=Float8Tensor), but got {out}" - ) - return - return Float8Tensor( - data, - scale, - param_dtype, - self._linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - ), (data,) diff --git a/torchao/float8/inductor_utils.py b/torchao/float8/inductor_utils.py deleted file mode 100644 index 3e86202536..0000000000 --- a/torchao/float8/inductor_utils.py +++ /dev/null @@ -1,126 +0,0 @@ -import functools -import inspect -import traceback -from collections import deque - -import torch - - -def amax_with_scaling_pattern(tensor_x_inp, scale_x, fp8_dtype, fp8_max): - tensor_x = tensor_x_inp.to(torch.float32) * scale_x - tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max) - tensor_x = tensor_x.to(fp8_dtype) - amax = torch.max(torch.abs(tensor_x_inp)) - return (tensor_x, amax) - - -def amax_with_scaling_tiled_replacement(tensor_x_inp, scale_x, fp8_dtype, fp8_max): - tensor_x = tensor_x_inp.to(torch.float32) * scale_x - tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max) - tensor_x = tensor_x.to(fp8_dtype) - amax_1 = torch.max(torch.abs(tensor_x_inp), dim=-1).values - amax = torch.max(amax_1) - return (tensor_x, amax) - - -# The amax_with_scaling_pattern will also match dynamic scaling cases, we want to avoid that. -# `scale_x` of delayed scaling comes from the previous iteration, instead of from `tensor_x_inp`. -# We check that `scale_x` is not a dependency of `tensor_x_inp` -def fp8_delayed_scaling_extra_check(match): - scale_x_inputs = deque([match.kwargs["scale_x"]]) - max_num_node_to_check = 20 # Don't traverse too many nodes - current_num_node = 0 - while len(scale_x_inputs) > 0 and current_num_node < max_num_node_to_check: - current_node = scale_x_inputs.popleft() - for n in current_node.all_input_nodes: - if n == match.kwargs["tensor_x_inp"]: - return False - scale_x_inputs.append(n) - current_num_node += 1 - return True - - -def partialize_and_update_signature(func, **kwargs): - """ - Equivalent to functools.partial but also updates the signature on returned function - """ - original_sig = inspect.signature(func) - parameters = original_sig.parameters - - new_parameters = { - key: value for key, value in parameters.items() if key not in kwargs - } - new_sig = inspect.Signature(parameters=list(new_parameters.values())) - - partial_func = functools.partial(func, **kwargs) - - def wrapper(*args, **kwargs): - return partial_func(*args, **kwargs) - - wrapper.__signature__ = new_sig # type: ignore[attr-defined] - wrapper.__name__ = func.__name__ - - return wrapper - - -def register_fp8_delayed_scaling_patterns_inner(): - from torch._inductor.fx_passes.post_grad import ( - pass_patterns as post_grad_patterns_all, - ) - from torch._inductor.pattern_matcher import fwd_only, register_replacement - - post_grad_patterns = post_grad_patterns_all[1] # medium priority - - if torch.cuda.is_available(): - for fp8_dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, - torch.float8_e4m3fnuz, - torch.float8_e5m2fnuz, - ]: - # torch.float16 has the same pattern as torch.bfloat16, because they both needs `tensor_x_inp.to(torch.float32)` - for dtype in [torch.float32, torch.bfloat16]: - device = "cuda" - register_replacement( - partialize_and_update_signature( - amax_with_scaling_pattern, - fp8_dtype=fp8_dtype, - fp8_max=torch.finfo(fp8_dtype).max, - ), - partialize_and_update_signature( - amax_with_scaling_tiled_replacement, - fp8_dtype=fp8_dtype, - fp8_max=torch.finfo(fp8_dtype).max, - ), - [ - torch.tensor((16, 16), device=device, dtype=dtype), - torch.tensor(2.0, device=device, dtype=torch.float32), - ], - fwd_only, - post_grad_patterns, - extra_check=fp8_delayed_scaling_extra_check, - ) - - -""" -This a short-term workaround of the delayed scaling performance issue. -It explicitly replaces `max(x)` with `max(max(x, dim=-1))`, enabling the fusion of amax scaling factor calculation and fp8 casting. - -Usage: - To use this solution, add the following line at the beginning of your user code: - torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes() -""" - - -def _prototype_register_float8_delayed_scaling_inductor_passes() -> None: - # To make the fp8 delayed scaling pattern work, we need a fix pr from inductor, https://github.com/pytorch/pytorch/pull/139321 - # Will throw the error if the pattern registration did not work, up to user to decide what to do with it - try: - register_fp8_delayed_scaling_patterns_inner() - except AssertionError as e: - if "assert pattern_repr not in _seen_patterns" in traceback.format_exc(): - print( - f"Caught duplicated patterns in register_fp8_delayed_scaling_patterns: {traceback.format_exc()}", - "\nPlease update your pytorch dependency to the latest main branch to fix it.\n", - ) - raise e diff --git a/torchao/float8/roofline_utils.py b/torchao/float8/roofline_utils.py index 16cf847fe2..58c84c5fa6 100644 --- a/torchao/float8/roofline_utils.py +++ b/torchao/float8/roofline_utils.py @@ -38,78 +38,30 @@ def get_tensor_memory_traffic_bytes( # assumes input bf16, output f8 numel = dim0 * dim1 - if scaling_type == "dynamic": - # x_bf16 = ... - # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp - # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs - # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8 - - if fuse_with_prev: - kernel_1_rw = 0 - else: - # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) - kernel_1_rw = BYTES_PER_EL_BF16 * numel - - # kernel 3: read in bf16, write twice in float8 (row-major and col-major) - kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel - - if model_torch_compile_limitations: - # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) - # has an extra memory read of the input in fp8 - # context: https://github.com/pytorch/pytorch/issues/130015 - tc_adjustment = numel * BYTES_PER_EL_FLOAT8 - else: - tc_adjustment = 0 - - return kernel_1_rw + kernel_3_rw + tc_adjustment + assert scaling_type == "dynamic", "unsupported" + # x_bf16 = ... + # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp + # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs + # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8 + + if fuse_with_prev: + kernel_1_rw = 0 + else: + # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) + kernel_1_rw = BYTES_PER_EL_BF16 * numel + + # kernel 3: read in bf16, write twice in float8 (row-major and col-major) + kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel + if model_torch_compile_limitations: + # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) + # has an extra memory read of the input in fp8 + # context: https://github.com/pytorch/pytorch/issues/130015 + tc_adjustment = numel * BYTES_PER_EL_FLOAT8 else: - assert scaling_type == "delayed", "unsupported" - # x_bf16 = ... - # kernel 1: x_bf16 -> max_abs_stage_1_and_to_float8 -> x_float8, tmp - # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs - # kernel 3 (not modeled): scale -> reciprocal -> inv_scale - - if fuse_with_prev: - kernel_1_r = 0 - else: - kernel_1_r = numel * BYTES_PER_EL_BF16 - # write twice: once in row major, once in col-major - kernel_1_w = numel * BYTES_PER_EL_FLOAT8 * 2 - - if model_torch_compile_limitations: - # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) - # has an extra memory read of the input in fp8 - # context: https://github.com/pytorch/pytorch/issues/130015 - tc_adjustment = numel * BYTES_PER_EL_FLOAT8 - - # https://github.com/pytorch/pytorch/issues/128063 - # instead of - # kernel 1: x_bf16 -> max(abs(x)), x_fp8 - # kernel 2: not modeled - # kernel 3: not modeled - # we get - # kernel 1: x_bf16 -> max(abs(x)) - # reads: same as before - # writes: 0 - # ... - # kernel 4: x_bf16, scale -> x_fp8 - # reads: numel * BYTES_PER_EL_BF16 - # writes: 2 * numel * BYTES_PER_EL_FLOAT8 - # Note that assuming worst case, this issue brings the memory - # traffic for delayed scaling to be equal to that of dynamic scaling. - tc_adjustment += ( - # subtract writes from kernel 1 - -1 * 2 * numel * BYTES_PER_EL_FLOAT8 - # add reads for kernel 4 - + numel * BYTES_PER_EL_BF16 - # add writes for kernel 4 - + 2 * numel * BYTES_PER_EL_FLOAT8 - ) - else: - tc_adjustment = 0 - - return kernel_1_r + kernel_1_w + tc_adjustment + tc_adjustment = 0 + + return kernel_1_rw + kernel_3_rw + tc_adjustment def get_gemm_time_sympy(M, K, N, dtype): @@ -131,9 +83,9 @@ def get_float8_mem_sympy( scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", ): - assert scaling_type_input in ("dynamic", "delayed"), "unsupported" - assert scaling_type_weight in ("dynamic", "delayed"), "unsupported" - assert scaling_type_grad_output in ("dynamic", "delayed"), "unsupported" + assert scaling_type_input in ("dynamic",), "unsupported" + assert scaling_type_weight in ("dynamic",), "unsupported" + assert scaling_type_grad_output in ("dynamic",), "unsupported" # there are three gemms in the fwd/bwd of a linear: # @@ -207,27 +159,12 @@ def get_float8_mem_sympy( if scaling_type_input == "dynamic": # second stage of max-abs reduction num_extra_kernels += 1 - elif scaling_type_input == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 if scaling_type_weight == "dynamic": # second stage of max-abs reduction num_extra_kernels += 1 - elif scaling_type_weight == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 if scaling_type_grad_output == "dynamic": # second stage of max-abs reduction num_extra_kernels += 1 - elif scaling_type_grad_output == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC diff --git a/torchao/float8/stateful_float8_linear.py b/torchao/float8/stateful_float8_linear.py deleted file mode 100644 index ac01803e0b..0000000000 --- a/torchao/float8/stateful_float8_linear.py +++ /dev/null @@ -1,439 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -""" -Stateful version of Float8Linear, created to keep Float8Linear simple and -only require code readers to read the stateful code if they care about delayed -or static scaling. -""" - -from typing import Optional - -import torch -import torch.utils.checkpoint as checkpoint - -from torchao.float8.config import Float8LinearConfig, ScalingType -from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 -from torchao.float8.float8_linear import ( - Float8Linear, -) -from torchao.float8.float8_scaling_utils import ( - NoopFwToFloat8BwDelayed, - NoopFwToFloat8BwDynamic, - NoopFwToFloat8BwStatic, - _maybe_initialize_amaxes_scales_for_float8_cast, - hp_tensor_to_float8_delayed, - hp_tensor_to_float8_dynamic, - hp_tensor_to_float8_static, -) -from torchao.float8.float8_tensor import ( - GemmInputRole, - hp_tensor_and_scale_to_float8, -) -from torchao.float8.float8_utils import ( - tensor_to_amax, - tensor_to_scale, -) -from torchao.float8.fsdp_utils import ( - WeightWithDelayedFloat8CastTensor, - WeightWithDynamicFloat8CastTensor, - WeightWithStaticFloat8CastTensor, -) - - -@torch._dynamo.allow_in_graph -class manual_float8_matmul_with_args_in_float8(torch.autograd.Function): - """ - Like torch.matmul, but with the arguments in float8 - - Note: this function requires all arguments to already be Float8Tensor objects, - which only supports tensorwise scaling granularity. The reason we didn't just make this - function support axiswise scaling granularity is because that would need very - careful testing of delayed scaling, as delayed scaling modifies buffers inplace. - - In the future we'll probably have to unify, just postponing that until a future PR. - """ - - @staticmethod - def forward( - ctx, - input_fp8, - weight_fp8_t, - ): - ctx.save_for_backward(input_fp8, weight_fp8_t) - # the reshapes are needed in order to make the shapes compatible with - # torch.mm - orig_shape = input_fp8.shape - input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1]) - res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) - res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) - return res_bits - - @staticmethod - def backward(ctx, grad_output_fp8): - input_fp8, weight_fp8_t = ctx.saved_tensors - - # the reshapes are needed in order to make the shapes compatible with - # torch.mm - grad_output_fp8_orig_shape = grad_output_fp8.shape - grad_output_fp8_reshaped = grad_output_fp8.reshape( - -1, grad_output_fp8_orig_shape[-1] - ) - - # calculate grad_input - grad_input = torch.mm( - grad_output_fp8_reshaped, - weight_fp8_t.t(), - ) - grad_input = grad_input.reshape( - *grad_output_fp8_orig_shape[:-1], grad_input.shape[-1] - ) - - input_fp8_orig_shape = input_fp8.shape - input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1]) - - # calculate grad_weight - # Note: the variant below is slightly faster on LLaMa 3 8B pretraining - # compared to than calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped` - grad_weight = torch.mm( - grad_output_fp8_reshaped.t(), - input_fp8_reshaped, - ) - - return grad_input, grad_weight.t() - - -class StatefulFloat8Linear(Float8Linear): - def __init__(self, *args, **kwargs): - # Amax scales should always be kept as float32. - self.always_float32_buffers = set() - - super().__init__(*args, **kwargs) - - # Convenience flag to skip code related to delayed scaling - self.has_any_delayed_scaling = ( - self.scaling_type_input is ScalingType.DELAYED - or self.scaling_type_weight is ScalingType.DELAYED - or self.scaling_type_grad_output is ScalingType.DELAYED - ) - - self.create_buffers() - - # Note: is_amax_initialized is not a buffer to avoid data dependent - # control flow visible to dynamo - # TODO(future PR): add serialization for this flag - self.is_amax_initialized = not self.config.enable_amax_init - - # pre_forward and post_forward are currently broken with FSDP - # and torch.compile, this option can disable them - # Note that when using `self.config.enable_pre_and_post_forward = False`, - # it's recommended to also set `self.config.enable_amax_init = False`. - # Otherwise, the amax buffer would never be marked as initialized and - # would be initialized in every iteration. - self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward - - def create_buffers(self): - # Default values for history buffers, see above TODO - history_len = self.config.delayed_scaling_config.history_len - device = self.weight.device - default_input = torch.finfo(self.config.cast_config_input.target_dtype).max - default_weight = torch.finfo(self.config.cast_config_weight.target_dtype).max - default_grad_output = torch.finfo( - self.config.cast_config_grad_output.target_dtype - ).max - - # Note: for now, create all the buffers if any are needed, to postpone - # the work to make the scale and amax syncing and history calculation - # handle a heterogeneous setup. We can do that work later if benchmarks - # show it is worth doing. - if self.has_any_delayed_scaling: - self.register_always_float32_buffer( - "fp8_amax_input", torch.tensor([default_input], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_history_input", torch.zeros(history_len, device=device) - ) - self.register_always_float32_buffer( - "fp8_scale_input", torch.tensor([1.0], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_weight", torch.tensor([default_weight], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_history_weight", torch.zeros(history_len, device=device) - ) - self.register_always_float32_buffer( - "fp8_scale_weight", torch.tensor([1.0], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_grad_output", - torch.tensor([default_grad_output], device=device), - ) - self.register_always_float32_buffer( - "fp8_amax_history_grad_output", torch.zeros(history_len, device=device) - ) - self.register_always_float32_buffer( - "fp8_scale_grad_output", torch.tensor([1.0], device=device) - ) - - if self.config.cast_config_input.static_scale is not None: - self.register_always_float32_buffer( - "fp8_static_scale_input", - self.config.cast_config_input.static_scale.to(device), - ) - if self.config.cast_config_weight.static_scale is not None: - self.register_always_float32_buffer( - "fp8_static_scale_weight", - self.config.cast_config_weight.static_scale.to(device), - ) - if self.config.cast_config_grad_output.static_scale is not None: - self.register_always_float32_buffer( - "fp8_static_scale_grad_output", - self.config.cast_config_grad_output.static_scale.to(device), - ) - - def register_always_float32_buffer( - self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True - ) -> None: - self.register_buffer(name=name, tensor=tensor, persistent=persistent) - self.always_float32_buffers.add(name) - - def _apply(self, fn, recurse=True): - ret = super()._apply(fn, recurse) - self.convert_amax_buffer_to_float32() - return ret - - def convert_amax_buffer_to_float32(self): - for key in self.always_float32_buffers: - if self._buffers[key] is not None: - self._buffers[key] = self._buffers[key].to(torch.float32) - - def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor: - is_amax_initialized = self.is_amax_initialized - # Duplicate the autocast logic for F.linear, so that the output - # of our module has the right original precision - if torch.is_autocast_enabled(): - # For now, hardcode to GPU's autocast dtype - # if we need CPU support in the future, we can add it - autocast_dtype = torch.get_autocast_gpu_dtype() - input = input.to(autocast_dtype) - - if tensor_already_casted_to_fp8(input): - input_fp8 = input - elif self.scaling_type_input is ScalingType.DELAYED: - scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - _maybe_initialize_amaxes_scales_for_float8_cast( - input, - self.fp8_amax_input, - self.fp8_amax_history_input, - self.fp8_scale_input, - scale_fn_name, - self.config.cast_config_input.target_dtype, - is_amax_initialized, - reduce_amax=True, - ) - input_fp8 = hp_tensor_to_float8_delayed( - input, - self.fp8_scale_input, - self.config.cast_config_input.target_dtype, - self.fp8_amax_input, - linear_mm_config=self.linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - ) - elif self.scaling_type_input is ScalingType.DYNAMIC: - input_fp8 = hp_tensor_to_float8_dynamic( - input, - self.config.cast_config_input.target_dtype, - self.linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - ) - else: - assert self.scaling_type_input is ScalingType.STATIC - input_fp8 = hp_tensor_to_float8_static( - input, - self.fp8_static_scale_input, - self.config.cast_config_input.target_dtype, - self.linear_mm_config, - ) - - return input_fp8 - - def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]: - if tensor_already_casted_to_fp8(weight): - return None - if self.scaling_type_weight is ScalingType.DELAYED: - scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - _maybe_initialize_amaxes_scales_for_float8_cast( - weight, - self.fp8_amax_weight, - self.fp8_amax_history_weight, - self.fp8_scale_weight, - scale_fn_name, - self.config.cast_config_weight.target_dtype, - self.is_amax_initialized, - reduce_amax=True, - ) - self.fp8_amax_weight.fill_(tensor_to_amax(weight)) - return self.fp8_scale_weight - elif self.scaling_type_weight is ScalingType.DYNAMIC: - return tensor_to_scale(weight, self.config.cast_config_weight.target_dtype) - else: - assert self.scaling_type_weight is ScalingType.STATIC - return self.fp8_static_scale_weight - - def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: - if self.scaling_type_grad_output is ScalingType.DELAYED: - scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - output = NoopFwToFloat8BwDelayed.apply( - output, - self.fp8_amax_grad_output, - self.fp8_amax_history_grad_output, - self.fp8_scale_grad_output, - scale_fn_name, - self.is_amax_initialized, - self.linear_mm_config, - self.config.cast_config_grad_output.target_dtype, - ) - elif self.scaling_type_grad_output is ScalingType.DYNAMIC: - output = NoopFwToFloat8BwDynamic.apply( - output, - self.linear_mm_config, - self.config.cast_config_grad_output.target_dtype, - ) - else: - assert self.scaling_type_grad_output is ScalingType.STATIC - output = NoopFwToFloat8BwStatic.apply( - output, - self.fp8_static_scale_grad_output, - self.linear_mm_config, - self.config.cast_config_grad_output.target_dtype, - ) - return output - - def cast_weight_to_float8_t( - self, - weight: torch.Tensor, - weight_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if tensor_already_casted_to_fp8(weight): - return weight.t() - weight_fp8 = hp_tensor_and_scale_to_float8( - weight, - weight_scale, - self.config.cast_config_weight.target_dtype, - self.linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - ) - return weight_fp8.t() - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.has_any_delayed_scaling: - self.float8_pre_forward(input) - - input_fp8 = self.cast_input_to_float8(input) - # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight, - # weight_scale should be saved. - weight_scale = self.get_weight_scale(self.weight) - - if self.config.force_recompute_fp8_weight_in_bwd: - weight_fp8_t = checkpoint.checkpoint( - self.cast_weight_to_float8_t, - self.weight, - weight_scale, - ) - else: - weight_fp8_t = self.cast_weight_to_float8_t(self.weight, weight_scale) - - output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t) - - # Cast grad_output to float8_e5m2 during backward - output = self.cast_output_to_float8_in_bw(output) - - if self.bias is not None: - output = output + self.bias.to(output.dtype) - - if self.has_any_delayed_scaling: - self.float8_post_forward() - return output - - def float8_pre_forward(self, input): - # TODO(future PR): deprecate these functions and the corresponding - # config setting - if not self.enable_pre_and_post_forward: - return - - def float8_post_forward(self): - # TODO(future PR): deprecate these functions and the corresponding - # config setting - if not self.enable_pre_and_post_forward: - return - - @classmethod - def from_float( - cls, - mod, - config: Optional[Float8LinearConfig] = None, - ): - """ - Create an nn.Linear with fp8 compute from a regular nn.Linear - - Args: - mod (torch.nn.Linear): nn.Linear to convert - config (Optional[Float8LinearConfig]): configuration for conversion to float8 - """ - if config is None: - config = Float8LinearConfig() - with torch.device("meta"): - new_mod = cls( - mod.in_features, - mod.out_features, - bias=False, - config=config, - ) - new_mod.weight = mod.weight - new_mod.bias = mod.bias - # need to create buffers again when moving from meta device to - # real device - new_mod.create_buffers() - - # If FSDP float8 all-gather is on, wrap the weight in a float8-aware - # tensor subclass. This must happen last because: - # 1. weight needs to be on the correct device to create the buffers - # 2. buffers need to be already created for the delayed scaling version - # of the weight wrapper to be initialized - if config.enable_fsdp_float8_all_gather: - if config.cast_config_weight.scaling_type is ScalingType.DYNAMIC: - new_mod.weight = torch.nn.Parameter( - WeightWithDynamicFloat8CastTensor( - new_mod.weight, - new_mod.linear_mm_config, - new_mod.config.cast_config_weight.target_dtype, - ) - ) - elif config.cast_config_weight.scaling_type is ScalingType.DELAYED: - new_mod.weight = torch.nn.Parameter( - WeightWithDelayedFloat8CastTensor( - new_mod.weight, - new_mod.fp8_amax_weight, - new_mod.fp8_amax_history_weight, - new_mod.fp8_scale_weight, - new_mod.linear_mm_config, - new_mod.config.cast_config_weight.target_dtype, - new_mod.is_amax_initialized, - ) - ) - else: - assert config.cast_config_weight.scaling_type is ScalingType.STATIC - new_mod.weight = torch.nn.Parameter( - WeightWithStaticFloat8CastTensor( - new_mod.weight, - new_mod.fp8_static_scale_weight, - new_mod.linear_mm_config, - new_mod.config.cast_config_weight.target_dtype, - ) - ) - - return new_mod diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/float8/fsdp2_utils.py index a059b4d2a9..31a5cf8db0 100644 --- a/torchao/testing/float8/fsdp2_utils.py +++ b/torchao/testing/float8/fsdp2_utils.py @@ -8,10 +8,6 @@ Float8LinearConfig, ScalingType, ) -from torchao.float8.float8_linear_utils import ( - linear_requires_sync, - sync_float8_amax_and_scale_history, -) from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp @@ -38,9 +34,6 @@ def check_parity_no_mp( dist.all_reduce(param.grad) param.grad.div_(dist.get_world_size()) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(model) - optim.step() if ( model is fsdp_model @@ -82,7 +75,6 @@ def check_parity_bf16_mp( param_bf16.grad.div_(dist.get_world_size()) param_fp32.grad = param_bf16.grad.float() param_bf16.grad = None - # TODO(future): add amax syncing once delayed scaling is supported optim.step() for param_fp32, param_bf16 in zip( ref_model.parameters(), ref_model_bf16.parameters() diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/float8/test_utils.py index 7b8ac121b6..2da34f53ed 100644 --- a/torchao/testing/float8/test_utils.py +++ b/torchao/testing/float8/test_utils.py @@ -1,9 +1,6 @@ -import torch - from torchao.float8.config import ( CastConfig, Float8LinearConfig, - ScalingType, ) @@ -13,32 +10,14 @@ def get_test_float8_linear_config( scaling_type_grad_output, emulate: bool, ): - static_scale_one = torch.tensor([1.0], device="cuda") - - if scaling_type_input is ScalingType.STATIC: - static_scale_input = static_scale_one - else: - static_scale_input = None - if scaling_type_weight is ScalingType.STATIC: - static_scale_weight = static_scale_one - else: - static_scale_weight = None - if scaling_type_grad_output is ScalingType.STATIC: - static_scale_grad_output = static_scale_one - else: - static_scale_grad_output = None - cast_config_input = CastConfig( scaling_type=scaling_type_input, - static_scale=static_scale_input, ) cast_config_weight = CastConfig( scaling_type=scaling_type_weight, - static_scale=static_scale_weight, ) cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, - static_scale=static_scale_grad_output, ) config = Float8LinearConfig(