@@ -17,10 +17,12 @@
 import torch.nn as nn
 from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
 from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import DTensor, init_device_mesh
+from torchao.float8.float8_tensor import GemmInputRole
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
@@ -293,6 +295,34 @@ def _get_curr_active_memory_mb(self) -> int:
         return round(mem_stats["active_bytes.all.current"] / 1e6)


+class Test2DParallelMultiThread(FSDPTestMultiThread, TestFloat8Common):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    def test_amax_allreduce_device_mesh(self):
+        dp_size = 2
+        pp_size = self.world_size // dp_size
+        global_mesh = init_device_mesh("cuda", (pp_size, dp_size), mesh_dim_names=("pp", "dp"))
+        dp_mesh = global_mesh["dp"]
+        pp_mesh = global_mesh["pp"]
+
+        if self.rank in [0, 1]:
+            # ranks 0 and 1 are the 1st stage in the pipeline
+            # ranks 2 and 3 are doing nothing but waiting for the 1st stage
+            torch.manual_seed(42 + self.rank)
+            hp_tensor = torch.randn(768, 32, device="cuda")
+            float8_tensor = hp_tensor_to_float8_dynamic(
+                hp_tensor,
+                torch.float8_e4m3fn,
+                Float8LinearConfig(
+                    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
+                ),
+                gemm_input_role=GemmInputRole.WEIGHT,
+                reduce_amax=True,
+                device_mesh=dp_mesh
+            )
+
 class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
     @property
     def world_size(self) -> int:
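
For context: the new test calls hp_tensor_to_float8_dynamic with an explicit device_mesh so that the amax all-reduce used by dynamic scaling runs over the data-parallel submesh rather than the full world group, and only the first pipeline stage (ranks 0 and 1) issues the cast, presumably to verify the scoped reduction completes without involving the idle ranks. Below is a minimal standalone sketch of the same call pattern outside the test harness. It assumes 4 GPUs launched via torchrun and the torchao APIs imported in the diff above; the script structure and variable names are illustrative and not part of the PR.

# Hypothetical standalone repro of the call pattern exercised by the new test.
# Launch with: torchrun --nproc_per_node=4 repro.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import GemmInputRole

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# 2x2 mesh: outer dim is pipeline-parallel ("pp"), inner dim is data-parallel ("dp").
global_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))
dp_mesh = global_mesh["dp"]  # the amax all-reduce is scoped to this submesh

torch.manual_seed(42 + rank)
weight = torch.randn(768, 32, device="cuda")
float8_weight = hp_tensor_to_float8_dynamic(
    weight,
    torch.float8_e4m3fn,
    Float8LinearConfig(cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC)),
    gemm_input_role=GemmInputRole.WEIGHT,
    reduce_amax=True,
    device_mesh=dp_mesh,
)

dist.destroy_process_group()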