 import os
 import torch
 import torch.distributed as dist
+from typing import Sequence
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
 
@@ -101,18 +102,33 @@ def quantize(m: torch.nn.Module) -> torch.nn.Module:
     )
     return m
 
+def shard(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[Placement],
+) -> DTensor:
+    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
+
+    shape, offset = compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [  # one slice per dim, selecting this rank's block of the full tensor
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[tuple(slices)]
+    return DTensor.from_local(
+        local_tensor, device_mesh, placements
+    )
+
 def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     """
     Shard linear layer of the model in column-wise fashion
     """
     # Column-wise is wrt to A^T, so for A it is row-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_rows = orig_weight.size(0) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+    dtensor = shard(orig_weight, mesh, [Shard(0)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
@@ -124,13 +140,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     Shard linear layer of the model in row-wise fashion
     """
     # Row-wise is wrt to A^T, so for A it is column-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_cols = orig_weight.size(1) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+    dtensor = shard(orig_weight, mesh, [Shard(1)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
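
For reference, a minimal sketch of how the new shard helper could be exercised end to end. This is not part of the commit: it assumes two processes launched via torchrun (e.g. torchrun --nproc_per_node=2 demo.py), a plain fp32 weight standing in for the tutorial's quantized tensor subclass, and that the shard function above is in scope; the file name demo.py is illustrative only.

# Hypothetical demo (assumption): run with `torchrun --nproc_per_node=2 demo.py`
import torch
import torch.distributed as dist
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Shard

# `shard` is the helper added in this diff; paste it here or import it
# from the tutorial file before running.

dist.init_process_group(backend="gloo")  # "gloo" so the sketch also runs on CPU
mesh = DeviceMesh("cpu", torch.arange(dist.get_world_size()))

# Plain fp32 weight as a stand-in for the tutorial's quantized weight
full_weight = torch.randn(128, 64)
dtensor = shard(full_weight, mesh, [Shard(0)])  # row-wise shard of the full tensor
print(f"rank {dist.get_rank()}: local shard {dtensor.to_local().shape}")
# With 2 ranks and Shard(0), each rank holds torch.Size([64, 64])

dist.destroy_process_group()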