Commit b883291

[Distributed] Improve sharding example

1 parent 728d629

File tree

1 file changed: +23 −11 lines changed


tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 23 additions & 11 deletions
@@ -1,8 +1,9 @@
 import os
 import torch
 import torch.distributed as dist
+from typing import Sequence
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
 

@@ -101,18 +102,33 @@ def quantize(m: torch.nn.Module) -> torch.nn.Module:
     )
     return m
 
+def shard(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[Placement],
+) -> DTensor:
+    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
+
+    shape, offset = compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return DTensor.from_local(
+        local_tensor, device_mesh, placements
+    )
+
 def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     """
     Shard linear layer of the model in column-wise fashion
     """
     # Column-wise is wrt to A^T, so for A it is row-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_rows = orig_weight.size(0) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+    dtensor = shard(orig_weight, mesh, [Shard(0)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
@@ -124,13 +140,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     Shard linear layer of the model in row-wise fashion
     """
     # Row-wise is wrt to A^T, so for A it is column-wise.
-    # Number of rows per rank
    orig_weight = m.linear.weight
-    n_local_cols = orig_weight.size(1) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+    dtensor = shard(orig_weight, mesh, [Shard(1)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False

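For reference, the new shard() helper computes the same local slice that the removed per-rank indexing produced. The sketch below (an illustration added alongside this commit, not part of it) mimics that computation for a [Shard(0)] placement on a hypothetical 2-rank mesh, without needing an initialized process group; the tensor size, rank value, and variable names are assumptions.

import torch

full_weight = torch.randn(8, 4)   # stands in for m.linear.weight
world_size, rank = 2, 1           # pretend this process is rank 1 of 2

# For [Shard(0)], compute_local_shape_and_global_offset would report a local
# shape of (8 // world_size, 4) and a global offset of (rank * 4, 0).
shape = (full_weight.size(0) // world_size, full_weight.size(1))
offset = (rank * shape[0], 0)

# Same slice construction as the shard() helper added in this commit.
slices = tuple(
    slice(cur_offset, cur_offset + cur_shape)
    for cur_shape, cur_offset in zip(shape, offset)
)
local_tensor = full_weight[slices]

# Identical to the manual slicing this commit removes from colwise_shard().
n_local_rows = full_weight.size(0) // world_size
assert torch.equal(
    local_tensor,
    full_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :],
)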
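A hedged sketch of how colwise_shard/rowwise_shard could be exercised end to end is shown next; the ToyModel class, CPU/gloo mesh, and launch command are assumptions and not part of the tutorial shown in this diff. It would be launched with torchrun so the process-group environment variables are set (e.g. torchrun --nproc-per-node 2 example.py), with the functions from the diff defined in the same file.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


class ToyModel(torch.nn.Module):
    # Minimal stand-in for the tutorial's model; only a .linear attribute is needed.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16, bias=False)


def main():
    dist.init_process_group("gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    m = ToyModel()
    colwise_shard(m, mesh)   # weight becomes a DTensor placed as [Shard(0)]
    # rowwise_shard(m, mesh) would instead shard the weight along dim 1.

    print(f"rank {dist.get_rank()}: placements = {m.linear.weight.placements}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()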