|
| 1 | +import pytest |
| 2 | + |
| 3 | +pytest.importorskip("triton") |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.distributed as dist |
| 7 | +import torch.distributed._symmetric_memory as symm_mem |
| 8 | +from torch.distributed._functional_collectives import ( |
| 9 | + all_to_all_single_autograd, |
| 10 | +) |
| 11 | +from torch.nn import functional as F |
| 12 | +from torch.testing._internal.common_distributed import ( |
| 13 | + MultiProcessTestCase, |
| 14 | +) |
| 15 | +from torch.testing._internal.common_utils import ( |
| 16 | + instantiate_parametrized_tests, |
| 17 | + run_tests, |
| 18 | +) |
| 19 | + |
| 20 | +from torchao.float8.float8_utils import ( |
| 21 | + compute_error, |
| 22 | +) |
| 23 | +from torchao.prototype.moe_training.kernels.mxfp8.comms import ( |
| 24 | + mxfp8_on_device_all_to_all_v, |
| 25 | +) |
| 26 | + |
| 27 | + |
| 28 | +@instantiate_parametrized_tests |
| 29 | +class TritonAllReduceTest(MultiProcessTestCase): |
| 30 | + def setUp(self) -> None: |
| 31 | + super().setUp() |
| 32 | + self._spawn_processes() |
| 33 | + |
| 34 | + @property |
| 35 | + def world_size(self) -> int: |
| 36 | + return 2 |
| 37 | + |
| 38 | + @property |
| 39 | + def device(self) -> torch.device: |
| 40 | + return torch.device(f"cuda:{self.rank}") |
| 41 | + |
| 42 | + def _init_process(self): |
| 43 | + torch.cuda.set_device(self.device) |
| 44 | + store = dist.FileStore(self.file_name, self.world_size) |
| 45 | + dist.init_process_group( |
| 46 | + backend="nccl", |
| 47 | + world_size=self.world_size, |
| 48 | + rank=self.rank, |
| 49 | + store=store, |
| 50 | + ) |
| 51 | + torch.manual_seed(42 + self.rank) |
| 52 | + |
| 53 | + def _init_device(self): |
| 54 | + symm_mem.set_backend("NVSHMEM") |
| 55 | + |
| 56 | + def test_a2a_fwd_bwd(self): |
| 57 | + self._init_process() |
| 58 | + try: |
| 59 | + torch.manual_seed(42 + self.rank) |
| 60 | + self._init_device() |
| 61 | + |
| 62 | + group_name = dist.group.WORLD.group_name |
| 63 | + symm_mem.enable_symm_mem_for_group(group_name) |
| 64 | + |
| 65 | + tokens_per_ep_rank = 8192 |
| 66 | + dim = 2048 |
| 67 | + input_tensor = torch.randn( |
| 68 | + tokens_per_ep_rank, |
| 69 | + dim, |
| 70 | + device=self.device, |
| 71 | + dtype=torch.float32, |
| 72 | + requires_grad=True, |
| 73 | + ) |
| 74 | + ref_input_tensor = input_tensor.detach().clone().requires_grad_(True) |
| 75 | + |
| 76 | + # Generate random input splits that sum to tokens_per_ep_rank |
| 77 | + num_splits = self.world_size |
| 78 | + input_splits = generate_split_sizes( |
| 79 | + num_splits, tokens_per_ep_rank, self.device |
| 80 | + ) |
| 81 | + |
| 82 | + # Max output tokens per rank is worst case where one rank receives all tokens |
| 83 | + max_output_tokens_per_rank = tokens_per_ep_rank * self.world_size |
| 84 | + |
| 85 | + # Test forward |
| 86 | + output, output_splits = mxfp8_on_device_all_to_all_v( |
| 87 | + input_tensor, |
| 88 | + input_splits, |
| 89 | + max_output_tokens_per_rank, |
| 90 | + group_name, |
| 91 | + ) |
| 92 | + |
| 93 | + # Reference torch.all_to_all_single to compare against |
| 94 | + output_splits_ref = torch.empty_like(output_splits) |
| 95 | + |
| 96 | + # Compute output splits from input splits |
| 97 | + dist.all_to_all_single(output_splits_ref, input_splits) |
| 98 | + |
| 99 | + # Pre-allocate output buffer for reference a2a |
| 100 | + total_tokens_on_rank_after_a2a = output_splits_ref.sum() |
| 101 | + ref_output = torch.empty( |
| 102 | + total_tokens_on_rank_after_a2a, |
| 103 | + dim, |
| 104 | + device=self.device, |
| 105 | + dtype=torch.float32, |
| 106 | + ) |
| 107 | + |
| 108 | + # Do the actual all_to_all_single |
| 109 | + ref_output = all_to_all_single_autograd( |
| 110 | + ref_input_tensor, |
| 111 | + output_splits_ref.tolist(), |
| 112 | + input_splits.tolist(), |
| 113 | + dist.group.WORLD, |
| 114 | + ) |
| 115 | + |
| 116 | + # Compare output |
| 117 | + assert torch.equal(output_splits, output_splits_ref), ( |
| 118 | + "output_splits mismatch" |
| 119 | + ) |
| 120 | + out_no_padding = output[:total_tokens_on_rank_after_a2a] |
| 121 | + sqnr = compute_error(ref_output, out_no_padding) |
| 122 | + min_sqnr = 30.0 |
| 123 | + assert sqnr > min_sqnr, f"sqnr={sqnr} is less than min_sqnr={min_sqnr}" |
| 124 | + |
| 125 | + # Test backwards |
| 126 | + labels = torch.ones_like(out_no_padding) |
| 127 | + loss = F.mse_loss(out_no_padding, labels) |
| 128 | + ref_loss = F.mse_loss(ref_output, labels) |
| 129 | + loss.backward() |
| 130 | + ref_loss.backward() |
| 131 | + |
| 132 | + # Compare grads |
| 133 | + grad_sqnr = compute_error(ref_input_tensor.grad, input_tensor.grad) |
| 134 | + min_grad_sqnr = 28.0 |
| 135 | + assert grad_sqnr > min_grad_sqnr, ( |
| 136 | + f"grad_sqnr={grad_sqnr} is less than min_grad_sqnr={min_grad_sqnr}" |
| 137 | + ) |
| 138 | + |
| 139 | + finally: |
| 140 | + dist.destroy_process_group() |
| 141 | + |
| 142 | + |
| 143 | +def generate_split_sizes(K: int, N: int, device: str = "cpu") -> torch.Tensor: |
| 144 | + """ |
| 145 | + Generates a tensor of K random non-negative integers that sum to N. |
| 146 | + """ |
| 147 | + if K <= 0: |
| 148 | + raise ValueError("K must be a positive integer.") |
| 149 | + if N < 0: |
| 150 | + raise ValueError("N must be a non-negative integer.") |
| 151 | + |
| 152 | + if K == 1: |
| 153 | + return torch.tensor([N], dtype=torch.long, device=device) |
| 154 | + |
| 155 | + # Generate K-1 random "dividers" in the range [0, N]. |
| 156 | + dividers = torch.randint(0, N + 1, (K - 1,), device=device) |
| 157 | + |
| 158 | + # Add 0 and N to the set of dividers to form the boundaries. |
| 159 | + boundaries = torch.cat( |
| 160 | + [torch.tensor([0], device=device), dividers, torch.tensor([N], device=device)] |
| 161 | + ) |
| 162 | + |
| 163 | + # Sort the boundaries to ensure they are in order |
| 164 | + sorted_boundaries = torch.sort(boundaries).values |
| 165 | + |
| 166 | + # The K integers are the differences between consecutive boundaries (will sum to N) |
| 167 | + result = sorted_boundaries[1:] - sorted_boundaries[:-1] |
| 168 | + |
| 169 | + return result.to(dtype=torch.int64) |
| 170 | + |
| 171 | + |
| 172 | +if __name__ == "__main__": |
| 173 | + run_tests() |
0 commit comments