 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any
 
 import torch
-from fairscale.nn.model_parallel.initialize import get_model_parallel_group
 
-# from float8_tensor import Float8Tensor
+from torch.distributed._tensor import DTensor
+import torch.distributed._functional_collectives as funcol
 from torchao.float8.float8_tensor import Float8Tensor
 
-# additional differentiable distributed primitives for SP which are not in
-# the Fairscale codebase
 
-
-def _gather_along_first_dim(input_: torch.Tensor):
-    # same as https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/model_parallel/mappings.py#L67,
-    # but gather along first dim instead of last dim
-    group = get_model_parallel_group()
-
-    # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
-        return input_
-
-    # Size and dimension.
-    first_dim = 0
-    rank = torch.distributed.get_rank(group=group)
-    world_size = torch.distributed.get_world_size(group=group)
-
-    # If the input is a float8 tensor, we need to do the transformation on the
-    # inner tensor and then return a new wrapper.
-    def _transform(t):
-        # tensors must be contiguous for all_gather to work
-        input_contig = t.contiguous()
-
-        tensor_list = [torch.empty_like(input_contig) for _ in range(world_size)]
-        tensor_list[rank] = input_contig
-        torch.distributed.all_gather(tensor_list, input_contig, group=group)
-
-        # Note: torch.cat already creates a contiguous tensor.
-        output = torch.cat(tensor_list, dim=first_dim).contiguous()
-        return output
-
-    if isinstance(input_, Float8Tensor):
-        new_data = input_._data
-        new_data = new_data.view(torch.int8)
-        new_data = _transform(new_data)
-        new_data = new_data.view(input_._data.dtype)
-        output = Float8Tensor(new_data, input_._scale, input_._orig_dtype)
-    else:
-        output = _transform(input_)
-
-    return output
-
-
-def _reduce_scatter(ctx: Any, input_: torch.Tensor):
-    group = get_model_parallel_group()
-    world_size = torch.distributed.get_world_size(group)
-
-    assert input_.shape[0] % world_size == 0
-    output_shape = (input_.shape[0] // world_size, *input_.shape[1:])
-    output = torch.empty(*output_shape, device=input_.device, dtype=input_.dtype)
-
-    torch.distributed.reduce_scatter_tensor(output, input_, group=group)
-    return output
-
-
-def _split_along_first_dim(input_: torch.Tensor):
-    # this is needed for testing
-
-    # like fairscale.nn.model_parallel.mappings._split, but
-    # along the first dim instead of last dim
-
-    group = get_model_parallel_group()
-    local_rank = torch.distributed.get_rank(group)
-    world_size = torch.distributed.get_world_size(group)
-
-    assert input_.shape[0] % world_size == 0
-    input_list = torch.split(input_, input_.shape[0] // world_size)
-    return input_list[local_rank]
-
-
-class _AllGatherFloat8FwReduceScatterBw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _gather_along_first_dim(input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _reduce_scatter(ctx, grad_output)
-
-
-class _ReduceScatterFwAllGatherFloat8Bw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _reduce_scatter(ctx, input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _gather_along_first_dim(grad_output)
-
-
-class _AllGatherFwSplitBw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _gather_along_first_dim(input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _split_along_first_dim(grad_output)
+def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
+    """
+    Check if the tensor is already cast to fp8; also works if the local
+    tensor is wrapped in DTensor.
+    """
+    if isinstance(tensor, Float8Tensor):
+        return True
+    elif isinstance(tensor, DTensor):
+        # TODO: shall we stick to public API and directly use tensor.to_local() here?
+        return tensor_already_casted_to_fp8(tensor._local_tensor)
+    elif isinstance(tensor, funcol.AsyncCollectiveTensor):
+        return tensor_already_casted_to_fp8(tensor.elem)
+
+    return False
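
For reference, a minimal sketch of how the new helper might be used as a guard before casting, assuming `tensor_already_casted_to_fp8` from the diff above is in scope. `maybe_cast_to_float8` and its `cast_fn` argument are hypothetical stand-ins for whatever cast routine the caller already has; they are not torchao APIs.

import torch

def maybe_cast_to_float8(tensor: torch.Tensor, cast_fn) -> torch.Tensor:
    # Skip the cast if `tensor` (or its local shard, when wrapped in a DTensor
    # or AsyncCollectiveTensor) is already a Float8Tensor.
    if tensor_already_casted_to_fp8(tensor):
        return tensor
    return cast_fn(tensor)

The recursive unwrapping in the helper is what makes this guard work for tensor-parallel inputs: per the docstring, the fp8 cast may already have happened on the local shard before the DTensor wrapper is inspected.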