 | 1 | +"""  | 
 | 2 | +This file is necessary until new version of torch.distributed is released with  | 
 | 3 | +https://github.com/pytorch/pytorch/commit/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc  | 
 | 4 | +"""  | 
 | 5 | +import torch  | 
 | 6 | +import torch.distributed as dist  | 
 | 7 | +from torch.distributed.distributed_c10d import (_get_pg_default_device,  | 
 | 8 | +                                                _object_to_tensor,  | 
 | 9 | +                                                _tensor_to_object)  | 
 | 10 | + | 
 | 11 | + | 
 | 12 | +def send_object_list(object_list, dst, group=None, device=None):  | 
 | 13 | +    """  | 
 | 14 | +    Sends picklable objects in ``object_list`` synchronously.  | 
 | 15 | +
  | 
 | 16 | +    Similar to :func:`send`, but Python objects can be passed in.  | 
 | 17 | +    Note that all objects in ``object_list`` must be picklable in order to be  | 
 | 18 | +    sent.  | 
 | 19 | +
  | 
 | 20 | +    Args:  | 
 | 21 | +        object_list (List[Any]): List of input objects to sent.  | 
 | 22 | +            Each object must be picklable. Receiver must provide lists of  | 
 | 23 | +            equal sizes.  | 
 | 24 | +        dst (int): Destination rank to send ``object_list`` to.  | 
 | 25 | +            Destination rank is based on global process group  | 
 | 26 | +            (regardless of ``group`` argument)  | 
 | 27 | +        group: (ProcessGroup, optional): The process group to work on. If None,  | 
 | 28 | +            the default process group will be used. Default is ``None``.  | 
 | 29 | +        device (``torch.device``, optional): If not None, the objects are  | 
 | 30 | +            serialized and converted to tensors which are moved to the  | 
 | 31 | +            ``device`` before sending. Default is ``None``.  | 
 | 32 | +
  | 
 | 33 | +    Returns:  | 
 | 34 | +        ``None``.  | 
 | 35 | +    """  | 
    if dist.get_rank() == dst:
        raise ValueError(
            "Invalid destination rank: destination rank should not be the "
            "same as the rank of the current process.")

    # Current device selection.
    # To preserve backwards compatibility, ``device`` defaults to ``None``,
    # in which case we run the current device-selection logic, i.e.
    # ``current_device`` is CUDA if the backend is NCCL, otherwise the CPU
    # device. If it is not ``None``, we move the size and object tensors to
    # be sent to this device.
    current_device = device or _get_pg_default_device(group)
    # Serialize object_list elements to tensors on the sending rank.
    tensor_list, size_list = zip(
        *
        [_object_to_tensor(obj, current_device, group) for obj in object_list])
    object_sizes_tensor = torch.cat(size_list)

    # Send object sizes
    dist.send(object_sizes_tensor, dst=dst, group=group)

    # Concatenate and send serialized object tensors.
    # Note: torch.cat will do an extra memory copy to the current device;
    # if tensor_list has only one element, we can skip the copy.
    if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
        object_tensor = tensor_list[0]
    else:
        object_tensor = torch.cat(tensor_list)

    dist.send(object_tensor, dst=dst, group=group)

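# Minimal sender-side sketch (not part of the original module): assumes an
# initialized two-rank process group and a matching ``recv_object_list`` call
# posted on rank 1; the payload below is purely illustrative.
#
#     if dist.get_rank() == 0:
#         send_object_list(["status", {"epoch": 3}], dst=1)

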
def recv_object_list(object_list, src=None, group=None, device=None):
    """
    Receives picklable objects in ``object_list`` synchronously.

    Similar to :func:`recv`, but can receive Python objects.

    Args:
        object_list (List[Any]): List of objects to receive into.
            Must be a list of the same size as the list being sent; its
            contents are overwritten with the received objects.
        src (int, optional): Source rank from which to recv ``object_list``.
            Source rank is based on the global process group
            (regardless of the ``group`` argument).
            Will receive from any rank if set to None. Default is ``None``.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.
        device (``torch.device``, optional): If not None, receives on
            this device. Default is ``None``.

    Returns:
        Sender rank. -1 if the rank is not part of the group. If the rank is
        part of the group, ``object_list`` will contain the sent objects from
        the ``src`` rank.
    """

    # Current device selection.
    # To preserve backwards compatibility, ``device`` defaults to ``None``,
    # in which case we run the current device-selection logic, i.e.
    # ``current_device`` is CUDA if the backend is NCCL, otherwise the CPU
    # device. If it is not ``None``, we move the size and object tensors to
    # be received to this device.
    current_device = device or _get_pg_default_device(group)
    object_sizes_tensor = torch.empty(len(object_list),
                                      dtype=torch.long,
                                      device=current_device)

    # Receive object sizes
    rank_sizes = dist.recv(object_sizes_tensor, src=src, group=group)

    # Tensor to receive serialized objects into.
    object_tensor = torch.empty(  # type: ignore[call-overload]
        torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
        dtype=torch.uint8,
        device=current_device)

    rank_objects = dist.recv(object_tensor, src=src, group=group)
    assert (rank_sizes == rank_objects
            ), "Mismatch in return ranks for object sizes and objects."
    # Deserialize objects using their stored sizes.
    offset = 0
    for i, obj_size in enumerate(object_sizes_tensor):
        obj_view = object_tensor[offset:offset + obj_size]
        obj_view = obj_view.type(torch.uint8)
        offset += obj_size
        object_list[i] = _tensor_to_object(obj_view, obj_size, group)
    return rank_objects
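

# The block below is an illustrative sketch, not part of the original module:
# it shows how the two helpers are meant to pair up. It assumes exactly two
# processes launched with env-var initialization (e.g.
# ``torchrun --nproc_per_node=2``) and uses the gloo backend so it runs on
# CPU-only machines.
if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    if dist.get_rank() == 0:
        # Rank 0 sends two arbitrary picklable objects to rank 1.
        send_object_list(["hello", {"step": 1}], dst=1)
    else:
        # The receiver must pre-allocate a list of the same length; its
        # contents are overwritten with the deserialized objects.
        received = [None, None]
        sender = recv_object_list(received, src=0)
        print(f"rank {dist.get_rank()} received {received} from rank {sender}")
    dist.destroy_process_group()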