diff --git a/benchmarks/prototype/moe_training/benchmark_kernels.py b/benchmarks/prototype/moe_training/benchmark_kernels.py
index 7068fe5b58..d9e79c6cf3 100644
--- a/benchmarks/prototype/moe_training/benchmark_kernels.py
+++ b/benchmarks/prototype/moe_training/benchmark_kernels.py
@@ -19,8 +19,8 @@
     triton_fp8_row_major_jagged_rowwise_scales,
 )
 from torchao.prototype.moe_training.utils import (
-    _to_2d_jagged_float8_tensor_colwise,
-    _to_2d_jagged_float8_tensor_rowwise,
+    torch_to_float8_per_group_colwise,
+    torch_to_float8_per_group_rowwise,
 )
 
 device = torch.device("cuda")
@@ -98,13 +98,13 @@ def warmup(func, *args, **kwargs):
 def run_torch(
     input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
 ):
-    _ = _to_2d_jagged_float8_tensor_rowwise(
+    _ = torch_to_float8_per_group_rowwise(
         input_row_major,
         offs,
         target_dtype=torch.float8_e4m3fn,
         round_scales_to_power_of_2=True,
     )
-    _ = _to_2d_jagged_float8_tensor_colwise(
+    _ = torch_to_float8_per_group_colwise(
         input_col_major,
         offs,
         target_dtype=torch.float8_e4m3fn,
diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py
index ed68e8fa23..b24b61be8c 100644
--- a/test/prototype/moe_training/test_kernels.py
+++ b/test/prototype/moe_training/test_kernels.py
@@ -19,14 +19,18 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+from torchao.prototype.moe_training.kernels.float8_rowwise import (
+    triton_fp8_rowwise_3d_transpose_rhs,
+)
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
-    _to_2d_jagged_float8_tensor_colwise,
-    _to_2d_jagged_float8_tensor_rowwise,
+    torch_to_3d_rowwise_float8_transpose_rhs,
+    torch_to_float8_per_group_colwise,
+    torch_to_float8_per_group_rowwise,
 )
 from torchao.testing.utils import skip_if_rocm
 
@@ -42,7 +46,7 @@ def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
     colwise_offs = torch.arange(k, k * n_groups + 1, k, device=device)
 
     # compute reference with torch impl
-    ref_fp8_data, ref_scales = _to_2d_jagged_float8_tensor_rowwise(
+    ref_fp8_data, ref_scales = torch_to_float8_per_group_rowwise(
         x,
         colwise_offs,
         target_dtype=torch.float8_e4m3fn,
@@ -70,7 +74,7 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: boo
     rowwise_offs = torch.arange(m, m * n_groups + 1, m, device=device)
 
     # compute reference with torch impl
-    ref_fp8_data, ref_scales = _to_2d_jagged_float8_tensor_colwise(
+    ref_fp8_data, ref_scales = torch_to_float8_per_group_colwise(
         x,
         rowwise_offs,
         target_dtype=torch.float8_e4m3fn,
@@ -85,3 +89,38 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: boo
     assert torch.eq(ref_fp8_data, kernel_fp8_data).all(), "fp8 data not equal"
     assert torch.eq(ref_scales, kernel_scales).all(), "scales not equal"
     assert _is_column_major(kernel_fp8_data), "fp8 data is not column major"
+
+
+@skip_if_rocm("ROCm not supported")
+@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
+def test_fp8_rowwise_3d_transpose_rhs(round_scales_to_power_of_2: bool):
+    device = "cuda"
+    experts, n, k = 8, 4 * 5120, 5120
+
+    # Example expert weights, transposed as they arrive in the forward pass
+    torch.manual_seed(0)
+    x = torch.randn((experts, n, k), dtype=torch.bfloat16, device=device).transpose(
+        -2, -1
+    )
+
+    # Compute reference with torch impl
+    ref_fp8, ref_scales = torch_to_3d_rowwise_float8_transpose_rhs(
+        x,
+        target_dtype=torch.float8_e4m3fn,
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
+    )
+    # Torch impl keeps the empty scaled dim, so squeeze it out to match the triton impl
+    ref_scales = ref_scales.squeeze(1)
+
+    triton_fp8, triton_scales = triton_fp8_rowwise_3d_transpose_rhs(
+        x,
+        output_dtype=torch.float8_e4m3fn,
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
+    )
+    assert ref_scales.shape == triton_scales.shape, "scale shapes not equal"
+    assert ref_scales.stride() == triton_scales.stride(), "scale strides not equal"
+    assert torch.allclose(ref_scales, triton_scales, rtol=0, atol=0), "scales not equal"
+
+    assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal"
+    assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal"
+    assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"
diff --git a/torchao/prototype/moe_training/kernels/float8_rowwise.py b/torchao/prototype/moe_training/kernels/float8_rowwise.py
new file mode 100644
index 0000000000..2e75a0cc95
--- /dev/null
+++ b/torchao/prototype/moe_training/kernels/float8_rowwise.py
@@ -0,0 +1,251 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+EPS = 1e-12
+
+FP8_DTYPE_MAP = {
+    torch.int8: tl.int8,
+    torch.int16: tl.int16,
+    torch.int32: tl.int32,
+    torch.int64: tl.int64,
+    torch.float8_e4m3fn: tl.float8e4nv,
+    torch.float8_e5m2: tl.float8e5,
+    torch.float16: tl.float16,
+    torch.bfloat16: tl.bfloat16,
+    torch.float32: tl.float32,
+    torch.float64: tl.float64,
+}
+
+block_sizes = [16]
+num_warps = [4]
+num_stages = [2]
+kernel_configs_2D = [
+    triton.Config(
+        {"BLOCK_SIZE_N": block_size, "BLOCK_SIZE_K": block_size * 2},
+        num_warps=warps,
+        num_stages=stages,
+    )
+    for block_size in block_sizes
+    for warps in num_warps
+    for stages in num_stages
+]
+
+from torch.library import triton_op, wrap_triton
+
+
+@triton_op("torchao::triton_fp8_rowwise_transpose_rhs", mutates_args={})
+def triton_fp8_rowwise_3d_transpose_rhs(
+    hp_tensor: torch.Tensor,  # (E, K, N)
+    output_dtype: torch.dtype = torch.float8_e4m3fn,
+    round_scales_to_power_of_2: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert hp_tensor.ndim == 3, "input tensor must be 3D"
+
+    num_elements = hp_tensor.numel()
+    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
+    tl_output_dtype = FP8_DTYPE_MAP[output_dtype]
+
+    fp8_dtype_min = torch.finfo(output_dtype).min
+    fp8_dtype_max = torch.finfo(output_dtype).max
+
+    e, k, n = hp_tensor.shape
+
+    # allocate on-device buffers for output and scales
+    # output shape = input.transpose(-2, -1).shape = (E, N, K) in column major layout
+    output_buffer = torch.empty((e, k, n), dtype=output_dtype, device=hp_tensor.device)
+    output_buffer = output_buffer.transpose(-2, -1)
+    scales_buffer = torch.full(
+        (e, k), float("inf"), dtype=torch.float32, device=hp_tensor.device
+    )
+
+    # parallelize across experts, and for each expert, parallelize across rows and cols
+    grid = lambda meta: (
+        e,
+        triton.cdiv(k, meta["BLOCK_SIZE_K"]),
+        triton.cdiv(n, meta["BLOCK_SIZE_N"]),
+    )
+
+    # compute scales
+    wrap_triton(_triton_fp8_rowwise_3d_transpose_scales_rhs_kernel)[grid](
+        hp_tensor,
+        hp_tensor.stride(0),
+        hp_tensor.stride(1),
+        hp_tensor.stride(2),
+        scales_buffer,
+        scales_buffer.stride(0),
+        scales_buffer.stride(1),
+        e,
+        n,
+        k,
+        num_elements,
+        fp8_dtype_min,
+        fp8_dtype_max,
+        tl_input_dtype,
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
+        EPS=EPS,
+    )
+
+    # perform casting
+    wrap_triton(_triton_fp8_rowwise_3d_transpose_cast_rhs_kernel)[grid](
+        hp_tensor,
+        hp_tensor.stride(0),
+        hp_tensor.stride(1),
+        hp_tensor.stride(2),
+        output_buffer,
+        output_buffer.stride(0),
+        output_buffer.stride(1),
+        output_buffer.stride(2),
+        scales_buffer,
+        scales_buffer.stride(0),
+        scales_buffer.stride(1),
+        e,
+        n,
+        k,
+        num_elements,
+        fp8_dtype_min,
+        fp8_dtype_max,
+        tl_input_dtype,
+        tl_output_dtype,
+    )
+    return output_buffer, scales_buffer
+
+
+@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+@triton.jit
+def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
+    input_ptr,
+    stride_input_dim0: int,
+    stride_input_dim1: int,
+    stride_input_dim2: int,
+    scales_ptr,
+    stride_scales_dim0: int,
+    stride_scales_dim1: int,
+    E: int,
+    N: int,
+    K: int,
+    num_elements: int,
+    fp8_dtype_min: tl.constexpr,
+    fp8_dtype_max: tl.constexpr,
+    input_dtype: tl.constexpr,
+    round_scales_to_power_of_2: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    # parallelize across experts, rows, and cols
+    expert_idx = tl.program_id(0)
+    k_block_idx = tl.program_id(1)
+    n_block_idx = tl.program_id(2)
+
+    # compute offsets for each dimension
+    k_offs = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    n_offs = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+
+    # load block of input data, shape (K, N)
+    input_offs = (
+        expert_idx * stride_input_dim0
+        + k_offs[:, None] * stride_input_dim1
+        + (n_offs[None, :] * stride_input_dim2)
+    )
+    input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N)
+    input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(
+        input_dtype
+    )
+
+    # compute local amaxes along axis=1 (reducing over N) of this non-transposed
+    # (K, N) block; this is equivalent to reducing along axis=0 of the transposed
+    # (N, K) output, giving one scale per column of K.
+    amaxes = tl.max(tl.abs(input_data), axis=1).to(tl.float64)  # (K,)
+    scales = (fp8_dtype_max / tl.clamp(amaxes, min=EPS, max=float("inf"))).to(
+        tl.float32
+    )
+    if round_scales_to_power_of_2:
+        scales = tl.exp2(tl.floor(tl.log2(scales)))
+
+    # compute global scales using atomics with local scales - shape (1, K)
+    scales_offs = (
+        expert_idx[:, None] * stride_scales_dim0 + k_offs[None, :] * stride_scales_dim1
+    )
+    scales_mask = k_offs[None, :] < K
+    tl.atomic_min(scales_ptr + scales_offs, scales[None, :], mask=scales_mask)
+
+
+@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+@triton.jit
+def _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel(
+    input_ptr,
+    stride_input_dim0: int,
+    stride_input_dim1: int,
+    stride_input_dim2: int,
+    output_ptr,
+    stride_output_dim0: int,
+    stride_output_dim1: int,
+    stride_output_dim2: int,
+    scales_ptr,
+    stride_scales_dim0: int,
+    stride_scales_dim1: int,
+    E: int,
+    N: int,
+    K: int,
+    num_elements: int,
+    fp8_dtype_min: tl.constexpr,
+    fp8_dtype_max: tl.constexpr,
+    input_dtype: tl.constexpr,
+    output_dtype: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    # parallelize across experts, rows, and cols
+    expert_idx = tl.program_id(0)
+    k_block_idx = tl.program_id(1)
+    n_block_idx = tl.program_id(2)
+
+    # compute offsets for each dimension
+    k_offs = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    n_offs = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+
+    # load block of input data for this expert - shape (K, N)
+    input_offs = (
+        expert_idx * stride_input_dim0
+        + k_offs[:, None] * stride_input_dim1
+        + (n_offs[None, :] * stride_input_dim2)
+    )
+    input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N)
+    input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(
+        input_dtype
+    )
+    input_data = input_data.trans(1, 0)  # (K, N) -> (N, K)
+
+    # load global scales for this block of the given expert - shape (1, K)
+    scales_offs = (
+        expert_idx[:, None] * stride_scales_dim0 + k_offs[None, :] * stride_scales_dim1
+    )
+    scales_mask = k_offs[None, :] < K
+    scales = tl.load(scales_ptr + scales_offs, mask=scales_mask, other=0.0).to(
+        tl.float32
+    )
+
+    # apply scales to the transposed data - shape (N, K) * (1, K) = (N, K)
+    scaled_data = input_data * scales
+    output_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to(
+        output_dtype
+    )
+
+    # store the transposed, scaled output data - shape (N, K)
+    output_offs = (
+        expert_idx * stride_output_dim0
+        + n_offs[:, None] * stride_output_dim1
+        + (k_offs[None, :] * stride_output_dim2)
+    )
+    output_mask = (n_offs[:, None] < N) & (k_offs[None, :] < K)
+    tl.store(output_ptr + output_offs, output_data, mask=output_mask)
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index f4dca9f4e8..5604d1ecad 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -126,9 +126,9 @@ def forward(
         A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
 
         # Convert B to float8, column-major for right operand of grouped GEMM.
-        # B shape: (E, K, N)
-        # B scales must be computed rowwise keeping the outer/final dim, so:
-        # B_scales shape: (E, 1, N)
+        # B_t shape: (E, K, N)
+        # B_t scales must be computed rowwise keeping the outer/final dim, so:
+        # B_t_scales shape: (E, 1, N)
         B_t_scales = tensor_to_scale(
             B_t,
             torch.float8_e4m3fn,
@@ -144,9 +144,9 @@ def forward(
         # In the backward this is needed for grad_A: grad_output @ B.
         B = B_t.contiguous().transpose(-2, -1)
 
-        # - B shape: (E, K, N)
+        # - B shape: (E, N, K)
         # - B scales must be computed rowwise keeping the outer/final dim, so:
-        # - B_scale shape: (E, 1, N)
+        # - B_scale shape: (E, 1, K)
         B_scales = tensor_to_scale(
             B,
             torch.float8_e4m3fn,
diff --git a/torchao/prototype/moe_training/utils.py b/torchao/prototype/moe_training/utils.py
index 21f917ce03..ba02eafc7d 100644
--- a/torchao/prototype/moe_training/utils.py
+++ b/torchao/prototype/moe_training/utils.py
@@ -9,7 +9,7 @@
 
 
 # --- float8 rowwise scaling ---
-def _to_2d_jagged_float8_tensor_colwise(
+def torch_to_float8_per_group_colwise(
     A_col_major: torch.Tensor,
     offs: torch.Tensor,
     target_dtype: torch.dtype = torch.float8_e4m3fn,
@@ -78,7 +78,7 @@ def _to_2d_jagged_float8_tensor_colwise(
     return A_fp8_col_major, A_scales
 
 
-def _to_2d_jagged_float8_tensor_rowwise(
+def torch_to_float8_per_group_rowwise(
     x: torch.Tensor,
     offs: torch.Tensor,
     target_dtype: torch.dtype,
@@ -145,6 +145,41 @@ def _to_2d_jagged_float8_tensor_rowwise(
     return x_fp8, x_scales
 
 
+def torch_to_3d_rowwise_float8_transpose_rhs(
+    input_hp: torch.Tensor,  # (E, K, N)
+    target_dtype: torch.dtype = torch.float8_e4m3fn,
+    round_scales_to_power_of_2: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Converts the 3D input tensor to float8 on a per-expert basis, with scales computed along
+    the logical columns of each expert's transposed (N, K) matrix (one scale per column of K).
+
+    Args:
+        input_hp (torch.Tensor): The high-precision input tensor to convert to float8. Shape (E, K, N).
+
+    Returns:
+        A tuple containing the float8 tensor and the scales used for the conversion.
+        Output shape: (E, N, K)
+        Scales shape: (E, 1, K)
+    """
+    input_hp_t = input_hp.transpose(-2, -1)  # (E, N, K)
+    scales = tensor_to_scale(
+        input_hp_t,
+        target_dtype,
+        scaling_granularity=ScalingGranularity.AXISWISE,
+        axiswise_dim=-2,
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
+    )  # (E, 1, K)
+
+    # Apply scales to tensor and convert to float8.
+    tensor_scaled = input_hp_t.to(torch.float32) * scales
+    float8_tensor = to_fp8_saturated(tensor_scaled, target_dtype)
+
+    # Convert to column-major memory layout.
+    float8_tensor = float8_tensor.transpose(-2, -1).contiguous().transpose(-2, -1)
+    return float8_tensor, scales
+
+
 # --- mxfp8 scaling ---
 def _to_mxfp8_per_group_rowwise(
     x: torch.Tensor,
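
Usage sketch (not part of the patch): a minimal example of how the new Triton op and its torch reference from this diff line up, mirroring the new test above. It assumes a CUDA device and only the names introduced in this diff; the shapes are arbitrary.

import torch

from torchao.prototype.moe_training.kernels.float8_rowwise import (
    triton_fp8_rowwise_3d_transpose_rhs,
)
from torchao.prototype.moe_training.utils import (
    torch_to_3d_rowwise_float8_transpose_rhs,
)

experts, n, k = 4, 8192, 5120

# Expert weights arrive transposed: logically (E, K, N), backed by (E, N, K) storage.
w_t = torch.randn(experts, n, k, dtype=torch.bfloat16, device="cuda").transpose(-2, -1)

# Triton path: returns the (E, N, K) float8 tensor (column-major) and (E, K) scales.
fp8_triton, scales_triton = triton_fp8_rowwise_3d_transpose_rhs(
    w_t, output_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True
)

# Torch reference path: scales keep the reduced dim, shape (E, 1, K).
fp8_ref, scales_ref = torch_to_3d_rowwise_float8_transpose_rhs(
    w_t, target_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True
)

# Both paths should match bitwise, as asserted in the test.
assert torch.equal(fp8_ref, fp8_triton)
assert torch.equal(scales_ref.squeeze(1), scales_triton)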