 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any

 import torch
-from fairscale.nn.model_parallel.initialize import get_model_parallel_group
+import torch.distributed._functional_collectives as funcol
+from torch.distributed._tensor import DTensor

-# from float8_tensor import Float8Tensor
 from torchao.float8.float8_tensor import Float8Tensor

-# additional differentiable distributed primitives for SP which are not in
-# the Fairscale codebase

-
-def _gather_along_first_dim(input_: torch.Tensor):
-    # same as https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/model_parallel/mappings.py#L67,
-    # but gather along first dim instead of last dim
-    group = get_model_parallel_group()
-
-    # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
-        return input_
-
-    # Size and dimension.
-    first_dim = 0
-    rank = torch.distributed.get_rank(group=group)
-    world_size = torch.distributed.get_world_size(group=group)
-
-    # If the input is a float8 tensor, we need to do the transformation on the
-    # inner tensor and then return a new wrapper.
-    def _transform(t):
-        # tensors must be contiguous for all_gather to work
-        input_contig = t.contiguous()
-
-        tensor_list = [torch.empty_like(input_contig) for _ in range(world_size)]
-        tensor_list[rank] = input_contig
-        torch.distributed.all_gather(tensor_list, input_contig, group=group)
-
-        # Note: torch.cat already creates a contiguous tensor.
-        output = torch.cat(tensor_list, dim=first_dim).contiguous()
-        return output
-
-    if isinstance(input_, Float8Tensor):
-        new_data = input_._data
-        new_data = new_data.view(torch.int8)
-        new_data = _transform(new_data)
-        new_data = new_data.view(input_._data.dtype)
-        output = Float8Tensor(new_data, input_._scale, input_._orig_dtype)
-    else:
-        output = _transform(input_)
-
-    return output
-
-
-def _reduce_scatter(ctx: Any, input_: torch.Tensor):
-    group = get_model_parallel_group()
-    world_size = torch.distributed.get_world_size(group)
-
-    assert input_.shape[0] % world_size == 0
-    output_shape = (input_.shape[0] // world_size, *input_.shape[1:])
-    output = torch.empty(*output_shape, device=input_.device, dtype=input_.dtype)
-
-    torch.distributed.reduce_scatter_tensor(output, input_, group=group)
-    return output
-
-
-def _split_along_first_dim(input_: torch.Tensor):
-    # this is needed for testing
-
-    # like fairscale.nn.model_parallel.mappings._split, but
-    # along the first dim instead of last dim
-
-    group = get_model_parallel_group()
-    local_rank = torch.distributed.get_rank(group)
-    world_size = torch.distributed.get_world_size(group)
-
-    assert input_.shape[0] % world_size == 0
-    input_list = torch.split(input_, input_.shape[0] // world_size)
-    return input_list[local_rank]
-
-
-class _AllGatherFloat8FwReduceScatterBw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _gather_along_first_dim(input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _reduce_scatter(ctx, grad_output)
-
-
-class _ReduceScatterFwAllGatherFloat8Bw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _reduce_scatter(ctx, input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _gather_along_first_dim(grad_output)
-
-
-class _AllGatherFwSplitBw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _gather_along_first_dim(input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _split_along_first_dim(grad_output)
+def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
+    """
+    Check whether the tensor has already been cast to fp8; also works when
+    the local tensor is wrapped in a DTensor.
+    """
+    if isinstance(tensor, Float8Tensor):
+        return True
+    elif isinstance(tensor, DTensor):
+        # TODO: shall we stick to public API and directly use tensor.to_local() here?
+        return tensor_already_casted_to_fp8(tensor._local_tensor)
+    elif isinstance(tensor, funcol.AsyncCollectiveTensor):
+        return tensor_already_casted_to_fp8(tensor.elem)
+
+    return False
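For orientation (not part of the diff itself), here is a minimal usage sketch of the new helper. The import path `torchao.float8.distributed_utils` and the `maybe_cast_to_fp8` / `cast_fn` names are assumptions for illustration only; the point is that the guard lets a caller skip a redundant cast when the input is already fp8, even if it is wrapped in a DTensor or an AsyncCollectiveTensor.

import torch

# Assumed location of the helper added in this diff.
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8


def maybe_cast_to_fp8(tensor: torch.Tensor, cast_fn) -> torch.Tensor:
    # `cast_fn` is a placeholder for whatever high-precision -> fp8 cast the
    # caller uses; it is not an API defined by this diff.
    if tensor_already_casted_to_fp8(tensor):
        # Already a Float8Tensor (possibly wrapped in a DTensor or an async
        # collective result), so casting again would be redundant.
        return tensor
    return cast_fn(tensor)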