2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -15,6 +15,8 @@ l0_dgx_b200:
      backend: pytorch
  tests:
  - unittest/_torch/multi_gpu_modeling -k "deepseek"
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP]
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
185 changes: 185 additions & 0 deletions tests/unittest/_torch/modules/test_fused_moe.py
@@ -288,6 +288,191 @@ def per_rank_test_fused_moe_alltoall(job_id):
assert r is None


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="needs 4 GPUs to run this test")
@pytest.mark.parametrize("alltoall_method_type", [
    AlltoallMethodType.MNNVL, AlltoallMethodType.DeepEP,
    AlltoallMethodType.DeepEPLowLatency
],
                         ids=lambda s: s.name)
def test_fused_moe_alltoall_fp4(alltoall_method_type):

    world_size = 4
    dtype = torch.bfloat16
    HIDDEN_SIZE = 2560
    INTERMEDIATE_SIZE = 1536
    NUM_EXPERTS = 72
    TOP_K = 6
    MAX_NUM_TOKENS = 2048

    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

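    # Pre-generate per-rank activation tensors and NVFP4-quantized expert
    # weights, one copy per GPU, so every MPI worker sees identical data.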
    x_list_world = []
    weights_world = []

    for i in range(world_size):
        x_list = []
        m = MAX_NUM_TOKENS
        while m >= 1:
            x = torch.randn((m, HIDDEN_SIZE), dtype=dtype, device="cuda")
            x_list.append(x.cuda(i))
            m //= 2

        x_abs_max = torch.cat([x.flatten() for x in x_list]).abs().max().float()
        x_sf_global = (448 * 6) / x_abs_max

        weights = {}
        for expert_id in range(NUM_EXPERTS):

            w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                    dtype=dtype,
                                    device="cuda")
            w1_sf_global = (448 * 6) / w1_weight.abs().max().float()

            w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE),
                                    dtype=dtype,
                                    device="cuda")
            w2_sf_global = (448 * 6) / w2_weight.abs().max().float()

            w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                    dtype=dtype,
                                    device="cuda")
            w3_sf_global = (448 * 6) / w3_weight.abs().max().float()

            w3_w1_global = min(
                w1_sf_global,
                w3_sf_global)  # w3 global and w1 global must be the same

            SCALING_VECTOR_SIZE = 16

            w1_weight_nvfp4, w1_sf_block = torch.ops.trtllm.fp4_quantize(
                w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
            w1_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
                w1_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))

            w2_weight_nvfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize(
                w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False)
            w2_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
                w2_sf_block.cpu().view(HIDDEN_SIZE, -1))

            w3_weight_nvfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize(
                w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
            w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
                w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))

            w1_input_scale = x_sf_global.cuda(i)
            w2_input_scale = x_sf_global.cuda(i)
            w3_input_scale = x_sf_global.cuda(i)

            weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4.cuda(i)
            weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4.cuda(i)
            weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4.cuda(i)
            weights[
                f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled.cuda(i)
            weights[
                f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled.cuda(i)
            weights[
                f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled.cuda(i)

            weights[f"{expert_id}.w1.input_scale"] = 1.0 / w1_input_scale.cuda(
                i)
            weights[f"{expert_id}.w2.input_scale"] = 1.0 / w2_input_scale.cuda(
                i)
            weights[f"{expert_id}.w3.input_scale"] = 1.0 / w3_input_scale.cuda(
                i)
            weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global.cuda(
                i)
            weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global.cuda(
                i)
            weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global.cuda(
                i)

        x_list_world.append(x_list)
        weights_world.append(weights)

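    # Per-rank body executed by every MPI worker: build the alltoall model and
    # a reference model from the same weights, then compare their outputs.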
    def per_rank_test_fused_moe_alltoall(job_id):
        routing_method = DefaultMoeRoutingMethod(top_k=TOP_K)
        mapping = Mapping(world_size=world_size,
                          rank=mpi_rank(),
                          tp_size=world_size,
                          moe_ep_size=world_size,
                          moe_tp_size=1,
                          enable_attention_dp=True)
        torch.cuda.set_device(mapping.rank)
        torch.manual_seed(mapping.rank)

        x_list = x_list_world[mapping.rank]
        weights = weights_world[mapping.rank]

        quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4)
        with mock.patch.object(WideEPMoE,
                               "select_alltoall_method_type",
                               return_value=alltoall_method_type):
            alltoall_model = WideEPMoE(
                num_experts=NUM_EXPERTS,
                routing_method=routing_method,
                hidden_size=HIDDEN_SIZE,
                intermediate_size=INTERMEDIATE_SIZE,
                dtype=dtype,
                reduce_results=True,
                model_config=ModelConfig(mapping=mapping,
                                         max_num_tokens=MAX_NUM_TOKENS,
                                         quant_config=quant_config),
            )
            alltoall_model.to("cuda")
            alltoall_model.load_weights([weights])

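        # Reference model: a plain CutlassFusedMoE loaded with the same NVFP4
        # weights, used as the numerical baseline.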
        ref_model = CutlassFusedMoE(
            num_experts=NUM_EXPERTS,
            routing_method=routing_method,
            hidden_size=HIDDEN_SIZE,
            intermediate_size=INTERMEDIATE_SIZE,
            dtype=dtype,
            reduce_results=True,
            model_config=ModelConfig(mapping=mapping,
                                     max_num_tokens=MAX_NUM_TOKENS,
                                     quant_config=quant_config),
        )
        ref_model.to("cuda")
        ref_model.load_weights([weights])

        # Evaluate the outputs over varying sequence lengths to verify the robustness of the alltoall methods
        m = MAX_NUM_TOKENS
        i = 0
        while m >= 1:
            x = x_list[i]
            i += 1
            router_logits = torch.randn((m, NUM_EXPERTS),
                                        dtype=dtype,
                                        device="cuda")
            all_rank_num_tokens = [m] * mapping.world_size

            with torch.inference_mode():
                output = alltoall_model.forward(
                    x,
                    router_logits,
                    all_rank_num_tokens=all_rank_num_tokens,
                    all_rank_max_num_tokens=m,
                    use_dp_padding=False)
                ref_output = ref_model.forward(
                    x,
                    router_logits,
                    all_rank_num_tokens=all_rank_num_tokens,
                    all_rank_max_num_tokens=m,
                    use_dp_padding=False)

            # Evaluate outputs
            torch.testing.assert_close(output, ref_output, rtol=0.05, atol=0.5)
            m //= 2

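    # Spawn one MPI worker per rank; each worker runs the per-rank check and
    # returns None on success.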
    with MPIPoolExecutor(max_workers=world_size) as executor:
        results = executor.map(per_rank_test_fused_moe_alltoall,
                               range(world_size))
        for r in results:
            assert r is None


@skip_pre_hopper
@pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRITON"])
@pytest.mark.parametrize("routing_cls",