53 changes: 13 additions & 40 deletions benchmarks/float8/bench_linear_float8.py
@@ -23,10 +23,6 @@
ScalingType,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_tensor import ScaledMMConfig

# estimating TOPs for matmuls in fp32, fp16, fp8
@@ -122,39 +118,18 @@ def main(
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
@@ -185,7 +160,7 @@ def main(
copy.deepcopy(linear_ref),
config=config,
)
scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"
scaling_repr = linear_float8.extra_repr()

if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -196,8 +171,6 @@ def main(
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

def float8_forw_backward():
if linear_requires_sync(config):
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

def n_times(n, fn, *args, **kwargs):
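With the STATIC branches removed above, all three cast configs in this benchmark are built the same way and only dynamic scaling remains. A minimal sketch of that construction path follows; it is not the benchmark itself, and the import paths, enum string values, and the Float8Linear.from_float entry point are assumptions inferred from the names used in this diff.

# Hedged sketch: unified dynamic-scaling config construction, mirroring the
# simplified code above. Import paths and shapes are illustrative assumptions.
import copy

import torch
from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingGranularity,
    ScalingType,
)
from torchao.float8.float8_linear import Float8Linear

# string -> enum, as done at the top of main() in the benchmark
scaling_type = ScalingType("dynamic")
scaling_granularity = ScalingGranularity("tensorwise")

# with the STATIC branches gone, every tensor role uses the same CastConfig shape
cast_config = CastConfig(
    scaling_type=scaling_type,
    scaling_granularity=scaling_granularity,
)
config = Float8LinearConfig(
    cast_config_input=cast_config,
    cast_config_weight=cast_config,
    cast_config_grad_output=cast_config,
)

# swap a bf16 linear for its float8 counterpart (assumes Float8Linear.from_float)
linear_ref = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
linear_float8 = Float8Linear.from_float(copy.deepcopy(linear_ref), config=config)

# one forward+backward step, mirroring float8_forw_backward() above, which no
# longer needs an amax/scale-history sync now that delayed scaling is gone
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
linear_float8(x).sum().backward()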
180 changes: 0 additions & 180 deletions benchmarks/float8/bench_multi_gpu.py

This file was deleted.

47 changes: 0 additions & 47 deletions benchmarks/float8/float8_roofline.py
@@ -58,9 +58,7 @@
)

from torchao.float8 import (
CastConfig,
Float8LinearConfig,
ScalingType,
convert_to_float8_training,
)
from torchao.float8.roofline_utils import (
@@ -219,24 +217,6 @@ def run(
scaling_type_weight="dynamic",
scaling_type_grad_output="dynamic",
)
fp8_mem_time_sympy_del_limit = get_float8_mem_sympy(
M,
K,
N,
model_torch_compile_limitations=True,
scaling_type_input="delayed",
scaling_type_weight="delayed",
scaling_type_grad_output="delayed",
)
fp8_mem_time_sympy_del_nolimit = get_float8_mem_sympy(
M,
K,
N,
model_torch_compile_limitations=False,
scaling_type_input="delayed",
scaling_type_weight="delayed",
scaling_type_grad_output="delayed",
)

if gemm_time_strategy == "roofline":
bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
@@ -258,16 +238,12 @@ def run(
# roofline memory overhead estimates
"fp8_oh_dyn_limit",
"fp8_oh_dyn_nolimit",
"fp8_oh_del_limit",
"fp8_oh_del_nolimit",
# actual e2e measurements
"bf16_s",
"fp8_dyn_s",
"fp8_del_s",
"fp8_dyn_axs_s",
# 'fp8_lw_s',
"fp8_dyn_sp",
"fp8_del_sp",
"fp8_dyn_axs_sp",
# 'fp8_lw_sp',
]
@@ -309,12 +285,6 @@ def run(
fp8_mem_time_dyn_nolimit_s = (
fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)
fp8_mem_time_del_limit_s = (
fp8_mem_time_sympy_del_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)
fp8_mem_time_del_nolimit_s = (
fp8_mem_time_sympy_del_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)

# create the model
m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
@@ -333,19 +303,6 @@ def run(
m_fp8_dyn = torch.compile(m_fp8_dyn)
fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)

# get the float8 delayed scaling gpu kernel time
torch._dynamo.reset()
config = Float8LinearConfig(
enable_amax_init=False,
enable_pre_and_post_forward=False,
cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)
m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_del = torch.compile(m_fp8_del)
fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)

# get the float8 dynamic axiswise scaling gpu kernel time
torch._dynamo.reset()
config = Float8LinearConfig.from_recipe_name("rowwise")
@@ -374,16 +331,12 @@ def run(
# roofline overhead estimates
fp8_mem_time_dyn_limit_s,
fp8_mem_time_dyn_nolimit_s,
fp8_mem_time_del_limit_s,
fp8_mem_time_del_nolimit_s,
# e2e numbers
bf16_time_actual_s,
fp8_dyn_time_actual_s,
fp8_del_time_actual_s,
fp8_dyn_axs_time_actual_s,
# fp8_lw_time_actual_s,
bf16_time_actual_s / fp8_dyn_time_actual_s,
bf16_time_actual_s / fp8_del_time_actual_s,
bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
# bf16_time_actual_s / fp8_lw_time_actual_s,
]
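After this change the roofline script compares only two float8 paths against the bf16 baseline: tensorwise dynamic scaling (the default config) and rowwise/axiswise scaling built from the "rowwise" recipe; the delayed-scaling columns and measurements are gone. A short hedged sketch of those two conversions follows, with a stand-in module in place of the script's LNLinearSigmoid and illustrative shapes; the script's own kernel-timing helper is omitted.

# Hedged sketch: the two float8 conversions the roofline script still measures.
# The module below is a placeholder, not the script's LNLinearSigmoid.
import copy

import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training


class ToyLNLinearSigmoid(nn.Module):
    def __init__(self, k, n):
        super().__init__()
        self.ln = nn.LayerNorm(k)
        self.fc = nn.Linear(k, n, bias=False)

    def forward(self, x):
        return torch.sigmoid(self.fc(self.ln(x)))


m_orig = ToyLNLinearSigmoid(4096, 4096).cuda().bfloat16()
x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)

# 1) dynamic tensorwise scaling: the default Float8LinearConfig
m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
m_fp8_dyn = torch.compile(m_fp8_dyn)
m_fp8_dyn(x).sum().backward()

# 2) dynamic rowwise (axiswise) scaling, built from a named recipe
torch._dynamo.reset()
config = Float8LinearConfig.from_recipe_name("rowwise")
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
m_fp8_dyn_axs(x).sum().backward()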