From f70cc907604c03850008b4adff00b5b2d5625214 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 21 Aug 2025 16:49:07 -0700 Subject: [PATCH] [mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm stack-info: PR: https://github.com/pytorch/ao/pull/2848, branch: danielvegamyhre/stack/55 --- .../moe_training/test_scaled_grouped_mm.py | 32 +++-- test/prototype/moe_training/test_training.py | 133 ++++------------- .../moe_training/kernels/__init__.py | 3 + .../prototype/moe_training/kernels/mxfp8.py | 135 ++++++++++++++++++ .../moe_training/scaled_grouped_mm.py | 116 ++++++++------- torchao/prototype/mx_formats/utils.py | 73 ++++++++++ 6 files changed, 324 insertions(+), 168 deletions(-) create mode 100644 torchao/prototype/moe_training/kernels/mxfp8.py diff --git a/test/prototype/moe_training/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py index 9b340a900f..1fd39451ce 100644 --- a/test/prototype/moe_training/test_scaled_grouped_mm.py +++ b/test/prototype/moe_training/test_scaled_grouped_mm.py @@ -230,25 +230,26 @@ def compute_reference_forward( @pytest.mark.parametrize("num_experts", (1, 8, 16)) def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts): x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") - w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda") + w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda") offs = generate_jagged_offs(num_experts, M) - x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone() + x_ref, w_ref, offs_ref = x.clone(), w.clone(), offs.clone() # Quantize inputs to mxpf8 for emulated mxfp8 scaled grouped mm block_size = 32 - x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size) + x_scale, x_fp8 = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size) # To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose. 
- w_scale, w_mx = to_mx( - w_t.transpose(-2, -1).contiguous(), + w_scale, w_fp8 = to_mx( + w, elem_dtype=torch.float8_e4m3fn, block_size=block_size, ) - w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1) - ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16) + ref_out = torch._grouped_mm( + x_ref, w_ref.transpose(-2, -1), offs=offs_ref, out_dtype=torch.bfloat16 + ) out = _emulated_mxfp8_scaled_grouped_mm_2d_3d( - x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16 + x_fp8, x_scale, w_fp8, w_scale, offs=offs, out_dtype=torch.bfloat16 ) sqnr = compute_error(ref_out, out) @@ -305,8 +306,10 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts): @skip_if_rocm("ROCm not supported") -@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)]) -@pytest.mark.parametrize("num_experts", (1, 8, 16)) +@pytest.mark.parametrize( + "M,K,N", [(1024, 5120, 8192), (2048, 5120, 8192), (16640, 5120, 8192)] +) +@pytest.mark.parametrize("num_experts", (2, 4, 8, 16)) def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts): from torchao.prototype.moe_training.scaled_grouped_mm import ( _MXFP8GroupedMM, @@ -314,9 +317,14 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts): block_size = 32 x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True) - w_t = torch.randn( - num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True + w = torch.randn( + num_experts, + N, + K, + dtype=torch.bfloat16, + device="cuda", ) + w_t = w.transpose(-2, -1).requires_grad_(True) offs = generate_jagged_offs(num_experts, M, multiple_of=block_size) x_ref, w_t_ref, offs_ref = ( x.clone().detach().requires_grad_(True), diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py index 0aae474ae4..26c9c279d9 100644 --- a/test/prototype/moe_training/test_training.py +++ b/test/prototype/moe_training/test_training.py @@ -40,109 +40,38 @@ ], ) @pytest.mark.parametrize("compile", [False, True]) -def test_moe_float8_training(target_fqns: list[str], compile: bool): - # Set token group alignment size to 16. This is required so that - # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input` - # has the contraction dim be divisible by 16. 16 byte alignment is required - # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements. 
- set_token_group_alignment_size_m(16) - model_args = MoEArgs( - num_experts=8, - ) - init_std = 0.02 - device = torch.device("cuda") - - # reference bf16 MoE - dim, hidden_dim = 5120, 8192 - ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() - torch.manual_seed(42) - ref_model.init_weights(init_std, device) - - # target MoE for testing conversion - model = copy.deepcopy(ref_model) - - # assert starting params are identical for both models - for param1, param2 in zip(model.parameters(), ref_model.parameters()): - assert torch.equal(param1, param2) - - # convert MoE to float8 training - def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: - for target_fqn in target_fqns: - if target_fqn in cur_fqn: - return True - return False - - # quantize test model - config = MoETrainingConfig() - quantize_(model, config=config, filter_fn=moe_module_filter_fn) - - # validate that only the experts were converted - _validate_model_conversion( - model, - target_fqns=target_fqns, - ) - if compile: - # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it - model = torch.compile(model, fullgraph=False) - ref_model = torch.compile(ref_model, fullgraph=False) - - # inputs - batch, seq = 8, 2048 - ref_x = torch.randn( - batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device - ) - x = ref_x.detach().clone().requires_grad_(True) - - # forward pass - ref_out = ref_model(ref_x) - out = model(x) - - # validate output - out_sqnr = compute_error(out, ref_out) - min_out_sqnr = 29.0 - assert out_sqnr.item() >= min_out_sqnr, ( - f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}." - ) - - # compute loss - labels = torch.ones_like(ref_out) - ref_loss = F.mse_loss(ref_out, labels) - out_loss = F.mse_loss(out, labels) - - # backward pass - ref_loss.backward() - out_loss.backward() - - # validate input gradient - input_grad_sqnr = compute_error(x.grad, ref_x.grad) - min_input_grad_sqnr = 29.0 - assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( - f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." - ) - - # validate param gradients - min_param_grad_sqnr = 23.0 - for param1, param2 in zip(model.parameters(), ref_model.parameters()): - param_grad_sqnr = compute_error(param1.grad, param2.grad) - assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( - f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}." - ) - - @pytest.mark.parametrize( - "target_fqns", + "recipe_config", [ - ["experts"], - ["does.not.exist"], + # {"recipe": MoEScalingType.FP8_ROWWISE, "group_alignment_size": 16, "min_out_sqnr": 29.0, "min_input_grad_sqnr": 29.0, "min_param_grad_sqnr": 23.0}, + { + "recipe": MoEScalingType.MXFP8, + "group_alignment_size": 32, + "min_out_sqnr": 28.0, + "min_input_grad_sqnr": 29.0, + "min_param_grad_sqnr": 21.0, + }, ], ) -@pytest.mark.parametrize("compile", [False, True]) -def test_moe_mxfp8_training(target_fqns: list[str], compile: bool): - block_size = 32 - - # Token groups must be divisible by 32 for mxfp8 - set_token_group_alignment_size_m(block_size) - +def test_moe_training(target_fqns: list[str], compile: bool, recipe_config: dict): + ( + recipe, + group_alignment_size, + min_out_sqnr, + min_input_grad_sqnr, + min_param_grad_sqnr, + ) = ( + recipe_config["recipe"], + recipe_config["group_alignment_size"], + recipe_config["min_out_sqnr"], + recipe_config["min_input_grad_sqnr"], + recipe_config["min_param_grad_sqnr"], + ) + # Set token group alignment size. 
This is required so that + # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input` + # has the contraction dim be divisible by 16. 16 byte alignment is required + # for the slowest moving dim (stride 1). + set_token_group_alignment_size_m(group_alignment_size) model_args = MoEArgs( num_experts=8, ) @@ -170,7 +99,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return False # quantize test model - config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8) + config = MoETrainingConfig(scaling_type=recipe) quantize_(model, config=config, filter_fn=moe_module_filter_fn) # validate that only the experts were converted @@ -178,7 +107,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: model, target_fqns=target_fqns, ) - if compile: # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it model = torch.compile(model, fullgraph=False) @@ -197,7 +125,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate output out_sqnr = compute_error(out, ref_out) - min_out_sqnr = 28.0 assert out_sqnr.item() >= min_out_sqnr, ( f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}." ) @@ -213,13 +140,11 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate input gradient input_grad_sqnr = compute_error(x.grad, ref_x.grad) - min_input_grad_sqnr = 30.0 assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." ) # validate param gradients - min_param_grad_sqnr = 21.0 for param1, param2 in zip(model.parameters(), ref_model.parameters()): param_grad_sqnr = compute_error(param1.grad, param2.grad) assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( diff --git a/torchao/prototype/moe_training/kernels/__init__.py b/torchao/prototype/moe_training/kernels/__init__.py index 0b88cc08a2..93531f7922 100644 --- a/torchao/prototype/moe_training/kernels/__init__.py +++ b/torchao/prototype/moe_training/kernels/__init__.py @@ -7,3 +7,6 @@ from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales, ) +from torchao.prototype.moe_training.kernels.mxfp8 import ( + fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d, +) diff --git a/torchao/prototype/moe_training/kernels/mxfp8.py b/torchao/prototype/moe_training/kernels/mxfp8.py new file mode 100644 index 0000000000..c3683cf853 --- /dev/null +++ b/torchao/prototype/moe_training/kernels/mxfp8.py @@ -0,0 +1,135 @@ +import logging + +import torch + +from torchao.prototype.mx_formats.utils import ( + to_blocked_per_group_2d, + to_blocked_per_group_3d, +) + +logger: logging.Logger = logging.getLogger(__name__) + +try: + import fbgemm_gpu.experimental.gen_ai # noqa: F401 +except Exception as e: + logging.warning( + f"fbgemm_gpu_genai package is required for this feature but import failed with exception: {e}" + "Please install nightly builds of pytorch and fbgemm_gpu_genai build using this command and try again: " + "pip3 install --force-reinstall --pre torch fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu129" + "If errors persist, please file a bug report." 
+ ) + + +@torch.library.custom_op("torchao::fbgemm_mxfp8_grouped_mm_2d_3d", mutates_args={}) +def fbgemm_mxfp8_grouped_mm_2d_3d( + A_fp8: torch.Tensor, + A_scales: torch.Tensor, + B_fp8: torch.Tensor, + B_scales: torch.Tensor, + offs: torch.Tensor, + block_size: int = 32, + out_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + assert A_fp8.ndim == 2, "A_fp8 tensor must be 2D" + assert B_fp8.ndim == 3, "B_fp8 tensor must be 3D" + assert block_size == 32, "Only block_size=32 is supported" + assert out_dtype == torch.bfloat16, "Only out_dtype=bfloat16 is supported" + assert A_fp8.shape[-1] == B_fp8.shape[-1], "A_fp8 and B_fp8 must have same last dim" + + # Convert scales for each group to blocked format. + Mg, K = A_fp8.shape + A_scales_blocked, starting_row_after_padding = to_blocked_per_group_2d( + A_scales, offs, Mg, K + ) + B_scales_blocked = to_blocked_per_group_3d(B_scales) + + # From this, we compute `group_sizes` and `starting_row_after_padding`: + # group_sizes = [32, 32, 64] + # starting_row_after_padding = [0, 32, 64, 128] + zero = torch.tensor([0], dtype=offs.dtype, device=offs.device) + group_sizes = torch.diff(offs, prepend=zero).to(torch.int64) + + # TODO: remove debug logging once prototype is more mature. + _log_inputs( + A_fp8, + B_fp8, + A_scales, + A_scales_blocked, + B_scales, + B_scales_blocked, + offs, + group_sizes, + starting_row_after_padding, + ) + + out = torch.ops.fbgemm.mx8mx8bf16_grouped_stacked( + A_fp8, + B_fp8, + A_scales_blocked, + B_scales_blocked, + group_sizes, + starting_row_after_padding=starting_row_after_padding, + ) + return out + + +@fbgemm_mxfp8_grouped_mm_2d_3d.register_fake +def _fbgemm_mxfp8_grouped_mm_2d_3d_fake( + A_fp8: torch.Tensor, + A_scales: torch.Tensor, + B_fp8: torch.Tensor, + B_scales: torch.Tensor, + offs: torch.Tensor, + block_size: int = 32, + out_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + assert A_fp8.ndim == 2, "A_fp8 tensor must be 2D" + assert B_fp8.ndim == 3, "B_fp8 tensor must be 3D" + assert out_dtype == torch.bfloat16, "Only out_dtype=bfloat16 is supported" + assert A_fp8.shape[-1] == B_fp8.shape[-1], "A_fp8 and B_fp8 must have same last dim" + mg, k = A_fp8.shape + e, n, k = B_fp8.shape + n_groups = offs.numel() + assert n_groups == e, ( + "Size of `offs` (number of groups) must match first dim of `B_fp8`" + ) + output = torch.empty((mg, n), dtype=torch.bfloat16, device=A_fp8.device) + return output + + +def _log_inputs( + A_fp8: torch.Tensor, + B_fp8: torch.Tensor, + A_scales: torch.Tensor, + A_scales_blocked: torch.Tensor, + B_scales: torch.Tensor, + B_scales_blocked: torch.Tensor, + offs: torch.Tensor, + group_sizes: torch.Tensor, + starting_row_after_padding: torch.Tensor, +): + logger.info(f"offs: {offs}, dtype: {offs.dtype}") + logger.info( + f"A_fp8.shape: {A_fp8.shape}, stride: {A_fp8.stride()}, dtype: {A_fp8.dtype}" + ) + logger.info( + f"B_fp8.shape: {B_fp8.shape}, stride: {B_fp8.stride()}, dtype: {B_fp8.dtype}" + ) + logger.info( + f"A_scales (non-blocked) shape: {A_scales.shape}, stride: {A_scales.stride()}, dtype: {A_scales.dtype}" + ) + logger.info( + f"A_scales_blocked.shape: {A_scales_blocked.shape}, stride: {A_scales_blocked.stride()}, dtype: {A_scales_blocked.dtype}" + ) + logger.info( + f"B_scales (non-blocked) shape: {B_scales.shape}, stride: {B_scales.stride()}, dtype: {B_scales.dtype}" + ) + logger.info( + f"B_scales_blocked.shape: {B_scales_blocked.shape}, stride: {B_scales_blocked.stride()}, dtype: {B_scales_blocked.dtype}" + ) + logger.info( + f"group_sizes: {group_sizes}, 
stride: {group_sizes.stride()}, dtype: {group_sizes.dtype}" + ) + logger.info( + f"starting_row_after_padding: {starting_row_after_padding}, stride: {starting_row_after_padding.stride()}, dtype: {starting_row_after_padding.dtype}" + ) diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index a966e528c9..1a9f762773 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -13,6 +13,7 @@ from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated from torchao.prototype.moe_training.conversion_utils import MoEScalingType from torchao.prototype.moe_training.kernels import ( + fbgemm_mxfp8_grouped_mm_2d_3d, triton_fp8_per_group_colwise_scales, triton_fp8_per_group_rowwise_scales, triton_fp8_rowwise_3d_transpose_rhs, @@ -277,52 +278,43 @@ def forward( offs: Optional[torch.Tensor] = None, block_size: int = 32, out_dtype: Optional[torch.dtype] = torch.bfloat16, - emulated: bool = True, + emulated: bool = False, ) -> torch.Tensor: # torchao _scaled_grouped_mm only supports A=2D and B=3D. assert A.ndim == 2, "A must be 2D" assert B_t.ndim == 3, "B must be 3D" assert block_size == 32, "Only block_size=32 is supported" - assert emulated, "Only emulated mxfp8 grouped gemm is supported" - # Cast to mxpf8 across dim -1. + # Store what we need for backward. + ctx.save_for_backward(A, B_t, offs) + ctx.block_size = block_size + ctx.out_dtype = out_dtype + ctx.emulated = emulated + # A_mx shape: (M, K) # A_scale shape: (M, K//block_size) A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size) - # Cast B_t per-expert to mxfp8 across dim1. - # B_t_mx shape: (E, K, N) - # B_t_scale shape: (E, K//block_size, N) - - # To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose. # B_mx shape: (E, N, K) # B_scale shape: (E, N, K//block_size) - B_scales_dim2, B_mx_dim2 = to_mx( - B_t.transpose(-2, -1), # (E,K,N) -> (E,N,K) + B_scales, B_mx = to_mx( + B_t.transpose(-2, -1), elem_dtype=torch.float8_e4m3fn, block_size=block_size, ) - # B_t_mx shape: (E, K, N) - # B_t_scale shape: (E, K//block_size, N) - B_t_scales_dim1 = B_scales_dim2.transpose( - -2, -1 - ) # (E,N,K//block_size) -> (E,K//block_size,N) - B_t_mx_dim1 = B_mx_dim2.transpose(-2, -1) # (E,N,K) -> (E,K,N) - - # Store what we need for backward. - ctx.save_for_backward(A, B_t, offs) - ctx.block_size = block_size - ctx.out_dtype = out_dtype - - # Perform scaled grouped GEMM and return result. # output = input @ weight.T # output shape: (M, N) - out = _emulated_mxfp8_scaled_grouped_mm_2d_3d( + mxfp8_2d_3d_grouped_mm = ( + _emulated_mxfp8_scaled_grouped_mm_2d_3d + if emulated + else fbgemm_mxfp8_grouped_mm_2d_3d + ) + out = mxfp8_2d_3d_grouped_mm( A_mx, A_scale, - B_t_mx_dim1, - B_t_scales_dim1, + B_mx, + B_scales, offs=offs, block_size=block_size, out_dtype=out_dtype, @@ -334,6 +326,7 @@ def backward(ctx, grad_out: torch.Tensor): A, B_t, offs = ctx.saved_tensors block_size = ctx.block_size out_dtype = ctx.out_dtype + emulated = ctx.emulated # grad_out_mx shape: (M, N) # grad_out_scale shape: (M, N//block_size) @@ -343,23 +336,24 @@ def backward(ctx, grad_out: torch.Tensor): # B_mx shape: (E, K, N) # B_scale shape: (E, K, N//block_size) - B_t_scale_dim2, B_t_mx_dim2 = to_mx( + B_scales, B_mx = to_mx( + # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency? 
B_t.contiguous(), elem_dtype=torch.float8_e4m3fn, block_size=block_size, ) - B_scale_dim1 = B_t_scale_dim2.transpose( - -2, -1 - ) # (E,K,N//block_size) -> (E,N//block_size,K) - B_mx_dim1 = B_t_mx_dim2.transpose(-2, -1) # (E,K,N) -> (E,N,K) - # Compute grad_A. # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K) - grad_A = _emulated_mxfp8_scaled_grouped_mm_2d_3d( + mxfp8_2d_3d_grouped_mm = ( + _emulated_mxfp8_scaled_grouped_mm_2d_3d + if emulated + else fbgemm_mxfp8_grouped_mm_2d_3d + ) + grad_A = mxfp8_2d_3d_grouped_mm( grad_out_mx, grad_out_scale, - B_mx_dim1, - B_scale_dim1, + B_mx, + B_scales, offs=offs, out_dtype=out_dtype, ) @@ -367,25 +361,28 @@ def backward(ctx, grad_out: torch.Tensor): # grad_out_t_mx shape: (N, M) # grad_out_t_scales shape: (N, M//block_size) grad_out_t_scales, grad_out_t_mx = to_mx( - grad_out.transpose(-2, -1).contiguous(), # (M,N) -> (N,M) + # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency? + grad_out.transpose(-2, -1).contiguous(), elem_dtype=torch.float8_e4m3fn, block_size=block_size, ) + # Transpose A so we can scale along the M dimension, then un-transpose. + # A_t_mx shape: (K, M) + # A_t_scales shape: (K, M//block_size) A_t_scales, A_t_mx = to_mx( - A.transpose(-2, -1).contiguous(), # (M,K) -> (K,M) + A.transpose(-2, -1).contiguous(), elem_dtype=torch.float8_e4m3fn, block_size=block_size, ) - A_scales = A_t_scales.transpose( - -2, -1 - ) # (K,M//block_size) -> (M//block_size,K) - A_mx = A_t_mx.transpose(-2, -1) # (K,M) -> (M,K) - # Compute grad_B = grad_output_t @ A - # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K) - # grad_B = grad_B_t.transpose(-2, -1) = (E,K,N) + # A_mx shape = (M, K) + A_mx = A_t_mx.transpose(-2, -1) + + # A_scales shape = (M//block_size, K) + A_scales = A_t_scales.transpose(-2, -1) + # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K) grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d( grad_out_t_mx, grad_out_t_scales, @@ -393,7 +390,8 @@ def backward(ctx, grad_out: torch.Tensor): A_scales, offs=offs, ) - # In forward we receive pre-transposed weights B_t as input + + # grad_B shape = (E,K,N) grad_B_t = grad_B.transpose(-2, -1) return grad_A, grad_B_t, None, None, None @@ -402,12 +400,30 @@ def backward(ctx, grad_out: torch.Tensor): def _emulated_mxfp8_scaled_grouped_mm_2d_3d( A_mx: torch.Tensor, A_scale: torch.Tensor, - B_t_mx: torch.Tensor, - B_t_scale: torch.Tensor, + B_mx: torch.Tensor, + B_scale: torch.Tensor, offs: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = torch.bfloat16, block_size: int = 32, ) -> torch.Tensor: + assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}" + assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}" + assert A_scale.shape[0] == A_mx.shape[0], ( + f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}" + ) + assert A_scale.shape[1] == A_mx.shape[1] // block_size, ( + f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}" + ) + assert B_scale.shape[0] == B_mx.shape[0], ( + f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}" + ) + assert B_scale.shape[1] == B_mx.shape[1], ( + f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}" + ) + assert B_scale.shape[2] == B_mx.shape[2] // block_size, ( + f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}" + ) + # Dequantize input # A_mx shape: (M, K) # A_scale shape: (M, K//block_size) @@ -427,14 +443,10 
@@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d( A = A.reshape(A_orig_shape) # Dequantize weights - # B_t_mx shape: (E, K, N) - # B_t_scale shape: (E, K//block_size, N) - E, K, N = B_t_mx.shape - # Tranpose to get block_size on rightmost dim # B_mx shape: (E, N, K) # B_scale shape: (E, N, K//block_size) - B_mx, B_scale = B_t_mx.transpose(-2, -1), B_t_scale.transpose(-2, -1) + E, N, K = B_mx.shape # Reshape to be able to do per-scaling group multiplication # B_mx shape: (E, N, K//block_size, block_size) diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 2aaf13b868..0c5f6b8cbd 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -99,3 +99,76 @@ def _to_blocked_single(scales: Tensor) -> Tensor: assert scales.shape == (128, 4) scales_tiled = scales.view(4, 32, 4) # view as 4 - (32, 4) tiles return scales_tiled.transpose(0, 1).reshape(32, 16) # Interleave tiles + + +def to_blocked_per_group_2d( + x_scales: Tensor, group_offs: Tensor, Mg: int, K: int, block_size: int = 32 +) -> Tensor: + """ + Convert scales to blocked format for a 2D tensor (input activations / token groups) + + Args: + x_scales: Tensor with per group scales in blocked format concatenated into one tensor. + group_offs: Tensor of shape (num_groups,) which contains the end index of each group along the Mg dimension. + Mg: total size of all groups summed together + K: K dim size + + Returns: + blocked_scales: Tensor + start_row_after_padding: Tensor of shape (num_groups,) which contains the start row after padding for each group. + """ + from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked + + assert x_scales.ndim == 2, "x_scales must be 2D" + assert block_size == 32, "Only block_size=32 is supported for now" + blocked_scales_list = [] + start_row_after_padding_list = [0] + group_start_idx = 0 + for i, group_end_idx in enumerate(group_offs.tolist()): + group_size = group_end_idx - group_start_idx + prev_start_row_after_padding = start_row_after_padding_list[i] + if group_size == 0: + start_row_after_padding_list.append(prev_start_row_after_padding) + continue + + # Convert group scales to blocked format + group_scales = x_scales[group_start_idx:group_end_idx] + group_scales_blocked = _to_blocked(group_scales) + blocked_scales_list.append(group_scales_blocked) + + # Calculate the start row after padding + scaling_groups_per_row = K // block_size + rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row + new_start_row = prev_start_row_after_padding + rows_for_group + start_row_after_padding_list.append(new_start_row) + + # Update next group start index + group_start_idx = group_end_idx + + blocked_scales = torch.cat(blocked_scales_list, dim=0).contiguous() + blocked_scales = blocked_scales.reshape(-1, K // 32) + start_row_after_padding = torch.tensor( + start_row_after_padding_list, device=x_scales.device, dtype=torch.int64 + ) + return blocked_scales, start_row_after_padding + + +def to_blocked_per_group_3d(weight_scales: Tensor) -> Tensor: + """ + Convert scales to blocked format for each group for a 3D tensor (expert weights) + + Args: + scales: Tensor of shape (E, N, K//block_size) + group_offs: Tensor of shape (num_groups,) which contains the end index of each group along the + """ + from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked + + blocked_scales_list = [] + num_groups = weight_scales.shape[0] + for i in range(num_groups): + group_scales = weight_scales[i] + 
group_scales_blocked = _to_blocked(group_scales) + blocked_scales_list.append(group_scales_blocked) + weight_scales_blocked = torch.stack(blocked_scales_list, dim=0).contiguous() + weight_scales_blocked = weight_scales_blocked.reshape(num_groups, -1) + return weight_scales_blocked
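
Usage sketch for the new fbgemm-backed path (illustrative, not part of the patch): the snippet below mirrors the 2d-3d grouped GEMM test above, quantizing stacked tokens of shape (M, K) and expert weights of shape (E, N, K) to MXFP8 with to_mx and dispatching to fbgemm_mxfp8_grouped_mm_2d_3d. It assumes a CUDA device with MXFP8 hardware support and a nightly fbgemm_gpu_genai install, and that to_mx is importable from torchao.prototype.mx_formats.mx_tensor; the shapes and the evenly split offs are made up for illustration.

# Hypothetical end-to-end example; not part of this PR's test suite.
import torch

from torchao.prototype.moe_training.kernels import fbgemm_mxfp8_grouped_mm_2d_3d
from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed import path for to_mx

M, K, N, num_experts, block_size = 1024, 5120, 8192, 4, 32

# Stacked token activations (M, K) and expert weights (E, N, K), matching the
# layouts used by test_mxfp8_grouped_gemm_with_dq_fwd_bwd above.
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")

# End offset of each token group along M; groups are split evenly here so every
# group size is a multiple of block_size (32), as the mxfp8 path requires.
offs = torch.arange(1, num_experts + 1, device="cuda", dtype=torch.int32) * (M // num_experts)

# to_mx returns (e8m0 scales, fp8 data); both operands are cast along their last dim (K).
x_scale, x_fp8 = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
w_scale, w_fp8 = to_mx(w, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

# Raw per-group scales go in; the op converts them to the blocked layout internally.
out = fbgemm_mxfp8_grouped_mm_2d_3d(
    x_fp8, x_scale, w_fp8, w_scale, offs, block_size=block_size, out_dtype=torch.bfloat16
)

# High-precision reference, as in the tests.
ref = torch._grouped_mm(x, w.transpose(-2, -1), offs=offs, out_dtype=torch.bfloat16)
print(out.shape, (out - ref).abs().mean())

Internally, the custom op converts the raw (M, K//block_size) and (E, N, K//block_size) e8m0 scales to the blocked layout per token group via to_blocked_per_group_2d and to_blocked_per_group_3d, derives group_sizes as the diff of offs, and then calls torch.ops.fbgemm.mx8mx8bf16_grouped_stacked with the padded start rows returned by the 2D conversion.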