
Commit 998b595 ("Add comment"), 1 parent: b883291

File tree: 1 file changed, +7 −0 lines


tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 7 additions & 0 deletions

@@ -107,6 +107,13 @@ def shard(
     device_mesh: DeviceMesh,
     placements: Sequence[Placement],
 ) -> DTensor:
+    """
+    Add a shard function to simplify both colwise_shard and rowwise_shard. The
+    shard function accepts a full tensor, and returns a DTensor based on
+    indicated placements. Goal is to move the shard function as a static method
+    of DTensor, e.g.
+    dtensor = DTensor.shard(full_tensor, device_mesh, placement)
+    """
     from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 
     shape, offset = compute_local_shape_and_global_offset(
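For context, the sharding arithmetic that a helper like `compute_local_shape_and_global_offset` performs can be sketched without any distributed setup. The following is a minimal, hypothetical stand-in (the function name `local_shape_and_offset` and its even-split assumption are mine, not from the source): for a full tensor shape, a 1-D device mesh of `mesh_size` ranks, a shard dimension, and a rank, it returns the local shard's shape and its offset into the full tensor.

```python
def local_shape_and_offset(full_shape, mesh_size, shard_dim, rank):
    # Sketch only: assumes a 1-D mesh and an evenly divisible shard dim;
    # the real torch.distributed utility also handles uneven splits and
    # multi-dimensional meshes.
    dim_size = full_shape[shard_dim]
    assert dim_size % mesh_size == 0, "sketch assumes an even split"
    local = dim_size // mesh_size
    shape = list(full_shape)
    shape[shard_dim] = local          # each rank owns a 1/mesh_size slice
    offset = [0] * len(full_shape)
    offset[shard_dim] = rank * local  # where that slice starts globally
    return tuple(shape), tuple(offset)

# Column-wise sharding of an (8, 16) weight across 4 ranks:
# rank 2 owns columns 8..11.
print(local_shape_and_offset((8, 16), 4, shard_dim=1, rank=2))
# -> ((8, 4), (0, 8))
```

Row-wise sharding is the same computation with `shard_dim=0`, which is why a single `shard` function can back both `colwise_shard` and `rowwise_shard` in the tutorial.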

0 commit comments