Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions test/prototype/moe_training/test_scaled_grouped_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,25 +230,26 @@ def compute_reference_forward(
@pytest.mark.parametrize("num_experts", (1, 8, 16))
def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
offs = generate_jagged_offs(num_experts, M)
x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
x_ref, w_ref, offs_ref = x.clone(), w.clone(), offs.clone()

# Quantize inputs to mxpf8 for emulated mxfp8 scaled grouped mm
block_size = 32
x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
x_scale, x_fp8 = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

# To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
w_scale, w_mx = to_mx(
w_t.transpose(-2, -1).contiguous(),
w_scale, w_fp8 = to_mx(
w,
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
)
w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)

ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
ref_out = torch._grouped_mm(
x_ref, w_ref.transpose(-2, -1), offs=offs_ref, out_dtype=torch.bfloat16
)
out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
x_fp8, x_scale, w_fp8, w_scale, offs=offs, out_dtype=torch.bfloat16
)

sqnr = compute_error(ref_out, out)
Expand Down Expand Up @@ -305,18 +306,25 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@pytest.mark.parametrize("num_experts", (1, 8, 16))
@pytest.mark.parametrize(
"M,K,N", [(1024, 5120, 8192), (2048, 5120, 8192), (16640, 5120, 8192)]
)
@pytest.mark.parametrize("num_experts", (2, 4, 8, 16))
def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
from torchao.prototype.moe_training.scaled_grouped_mm import (
_MXFP8GroupedMM,
)

block_size = 32
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
w_t = torch.randn(
num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
w = torch.randn(
num_experts,
N,
K,
dtype=torch.bfloat16,
device="cuda",
)
w_t = w.transpose(-2, -1).requires_grad_(True)
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
x_ref, w_t_ref, offs_ref = (
x.clone().detach().requires_grad_(True),
Expand Down
133 changes: 29 additions & 104 deletions test/prototype/moe_training/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,109 +40,38 @@
],
)
@pytest.mark.parametrize("compile", [False, True])
def test_moe_float8_training(target_fqns: list[str], compile: bool):
# Set token group alignment size to 16. This is required so that
# each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
# has the contraction dim be divisible by 16. 16 byte alignment is required
# for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
set_token_group_alignment_size_m(16)
model_args = MoEArgs(
num_experts=8,
)
init_std = 0.02
device = torch.device("cuda")

# reference bf16 MoE
dim, hidden_dim = 5120, 8192
ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
torch.manual_seed(42)
ref_model.init_weights(init_std, device)

# target MoE for testing conversion
model = copy.deepcopy(ref_model)

# assert starting params are identical for both models
for param1, param2 in zip(model.parameters(), ref_model.parameters()):
assert torch.equal(param1, param2)

# convert MoE to float8 training
def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
for target_fqn in target_fqns:
if target_fqn in cur_fqn:
return True
return False

# quantize test model
config = MoETrainingConfig()
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# validate that only the experts were converted
_validate_model_conversion(
model,
target_fqns=target_fqns,
)
if compile:
# TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
model = torch.compile(model, fullgraph=False)
ref_model = torch.compile(ref_model, fullgraph=False)

# inputs
batch, seq = 8, 2048
ref_x = torch.randn(
batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
)
x = ref_x.detach().clone().requires_grad_(True)

# forward pass
ref_out = ref_model(ref_x)
out = model(x)

# validate output
out_sqnr = compute_error(out, ref_out)
min_out_sqnr = 29.0
assert out_sqnr.item() >= min_out_sqnr, (
f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
)

# compute loss
labels = torch.ones_like(ref_out)
ref_loss = F.mse_loss(ref_out, labels)
out_loss = F.mse_loss(out, labels)

# backward pass
ref_loss.backward()
out_loss.backward()

# validate input gradient
input_grad_sqnr = compute_error(x.grad, ref_x.grad)
min_input_grad_sqnr = 29.0
assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
)

# validate param gradients
min_param_grad_sqnr = 23.0
for param1, param2 in zip(model.parameters(), ref_model.parameters()):
param_grad_sqnr = compute_error(param1.grad, param2.grad)
assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
)


@pytest.mark.parametrize(
"target_fqns",
"recipe_config",
[
["experts"],
["does.not.exist"],
# {"recipe": MoEScalingType.FP8_ROWWISE, "group_alignment_size": 16, "min_out_sqnr": 29.0, "min_input_grad_sqnr": 29.0, "min_param_grad_sqnr": 23.0},
{
"recipe": MoEScalingType.MXFP8,
"group_alignment_size": 32,
"min_out_sqnr": 28.0,
"min_input_grad_sqnr": 29.0,
"min_param_grad_sqnr": 21.0,
},
],
)
@pytest.mark.parametrize("compile", [False, True])
def test_moe_mxfp8_training(target_fqns: list[str], compile: bool):
block_size = 32

# Token groups must be divisible by 32 for mxfp8
set_token_group_alignment_size_m(block_size)

def test_moe_training(target_fqns: list[str], compile: bool, recipe_config: dict):
(
recipe,
group_alignment_size,
min_out_sqnr,
min_input_grad_sqnr,
min_param_grad_sqnr,
) = (
recipe_config["recipe"],
recipe_config["group_alignment_size"],
recipe_config["min_out_sqnr"],
recipe_config["min_input_grad_sqnr"],
recipe_config["min_param_grad_sqnr"],
)
# Set token group alignment size. This is required so that
# each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
# has the contraction dim be divisible by 16. 16 byte alignment is required
# for the slowest moving dim (stride 1).
set_token_group_alignment_size_m(group_alignment_size)
model_args = MoEArgs(
num_experts=8,
)
Expand Down Expand Up @@ -170,15 +99,14 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
return False

# quantize test model
config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
config = MoETrainingConfig(scaling_type=recipe)
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# validate that only the experts were converted
_validate_model_conversion(
model,
target_fqns=target_fqns,
)

if compile:
# TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
model = torch.compile(model, fullgraph=False)
Expand All @@ -197,7 +125,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:

# validate output
out_sqnr = compute_error(out, ref_out)
min_out_sqnr = 28.0
assert out_sqnr.item() >= min_out_sqnr, (
f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
)
Expand All @@ -213,13 +140,11 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:

# validate input gradient
input_grad_sqnr = compute_error(x.grad, ref_x.grad)
min_input_grad_sqnr = 30.0
assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
)

# validate param gradients
min_param_grad_sqnr = 21.0
for param1, param2 in zip(model.parameters(), ref_model.parameters()):
param_grad_sqnr = compute_error(param1.grad, param2.grad)
assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
Expand Down
3 changes: 3 additions & 0 deletions torchao/prototype/moe_training/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
)
from torchao.prototype.moe_training.kernels.mxfp8 import (
fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
)
135 changes: 135 additions & 0 deletions torchao/prototype/moe_training/kernels/mxfp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import logging

import torch

from torchao.prototype.mx_formats.utils import (
to_blocked_per_group_2d,
to_blocked_per_group_3d,
)

logger: logging.Logger = logging.getLogger(__name__)

try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
except Exception as e:
logging.warning(
f"fbgemm_gpu_genai package is required for this feature but import failed with exception: {e}"
"Please install nightly builds of pytorch and fbgemm_gpu_genai build using this command and try again: "
"pip3 install --force-reinstall --pre torch fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu129"
"If errors persist, please file a bug report."
)


@torch.library.custom_op("torchao::fbgemm_mxfp8_grouped_mm_2d_3d", mutates_args={})
def fbgemm_mxfp8_grouped_mm_2d_3d(
A_fp8: torch.Tensor,
A_scales: torch.Tensor,
B_fp8: torch.Tensor,
B_scales: torch.Tensor,
offs: torch.Tensor,
block_size: int = 32,
out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
assert A_fp8.ndim == 2, "A_fp8 tensor must be 2D"
assert B_fp8.ndim == 3, "B_fp8 tensor must be 3D"
assert block_size == 32, "Only block_size=32 is supported"
assert out_dtype == torch.bfloat16, "Only out_dtype=bfloat16 is supported"
assert A_fp8.shape[-1] == B_fp8.shape[-1], "A_fp8 and B_fp8 must have same last dim"

# Convert scales for each group to blocked format.
Mg, K = A_fp8.shape
A_scales_blocked, starting_row_after_padding = to_blocked_per_group_2d(
A_scales, offs, Mg, K
)
B_scales_blocked = to_blocked_per_group_3d(B_scales)

# From this, we compute `group_sizes` and `starting_row_after_padding`:
# group_sizes = [32, 32, 64]
# starting_row_after_padding = [0, 32, 64, 128]
zero = torch.tensor([0], dtype=offs.dtype, device=offs.device)
group_sizes = torch.diff(offs, prepend=zero).to(torch.int64)

# TODO: remove debug logging once prototype is more mature.
_log_inputs(
A_fp8,
B_fp8,
A_scales,
A_scales_blocked,
B_scales,
B_scales_blocked,
offs,
group_sizes,
starting_row_after_padding,
)

out = torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
A_fp8,
B_fp8,
A_scales_blocked,
B_scales_blocked,
group_sizes,
starting_row_after_padding=starting_row_after_padding,
)
return out


@fbgemm_mxfp8_grouped_mm_2d_3d.register_fake
def _fbgemm_mxfp8_grouped_mm_2d_3d_fake(
A_fp8: torch.Tensor,
A_scales: torch.Tensor,
B_fp8: torch.Tensor,
B_scales: torch.Tensor,
offs: torch.Tensor,
block_size: int = 32,
out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
assert A_fp8.ndim == 2, "A_fp8 tensor must be 2D"
assert B_fp8.ndim == 3, "B_fp8 tensor must be 3D"
assert out_dtype == torch.bfloat16, "Only out_dtype=bfloat16 is supported"
assert A_fp8.shape[-1] == B_fp8.shape[-1], "A_fp8 and B_fp8 must have same last dim"
mg, k = A_fp8.shape
e, n, k = B_fp8.shape
n_groups = offs.numel()
assert n_groups == e, (
"Size of `offs` (number of groups) must match first dim of `B_fp8`"
)
output = torch.empty((mg, n), dtype=torch.bfloat16, device=A_fp8.device)
return output


def _log_inputs(
A_fp8: torch.Tensor,
B_fp8: torch.Tensor,
A_scales: torch.Tensor,
A_scales_blocked: torch.Tensor,
B_scales: torch.Tensor,
B_scales_blocked: torch.Tensor,
offs: torch.Tensor,
group_sizes: torch.Tensor,
starting_row_after_padding: torch.Tensor,
):
logger.info(f"offs: {offs}, dtype: {offs.dtype}")
logger.info(
f"A_fp8.shape: {A_fp8.shape}, stride: {A_fp8.stride()}, dtype: {A_fp8.dtype}"
)
logger.info(
f"B_fp8.shape: {B_fp8.shape}, stride: {B_fp8.stride()}, dtype: {B_fp8.dtype}"
)
logger.info(
f"A_scales (non-blocked) shape: {A_scales.shape}, stride: {A_scales.stride()}, dtype: {A_scales.dtype}"
)
logger.info(
f"A_scales_blocked.shape: {A_scales_blocked.shape}, stride: {A_scales_blocked.stride()}, dtype: {A_scales_blocked.dtype}"
)
logger.info(
f"B_scales (non-blocked) shape: {B_scales.shape}, stride: {B_scales.stride()}, dtype: {B_scales.dtype}"
)
logger.info(
f"B_scales_blocked.shape: {B_scales_blocked.shape}, stride: {B_scales_blocked.stride()}, dtype: {B_scales_blocked.dtype}"
)
logger.info(
f"group_sizes: {group_sizes}, stride: {group_sizes.stride()}, dtype: {group_sizes.dtype}"
)
logger.info(
f"starting_row_after_padding: {starting_row_after_padding}, stride: {starting_row_after_padding.stride()}, dtype: {starting_row_after_padding.dtype}"
)
Loading
Loading