1 parent b883291 · commit 998b595
tutorials/developer_api_guide/tensor_parallel.py
@@ -107,6 +107,13 @@ def shard(
     device_mesh: DeviceMesh,
     placements: Sequence[Placement],
 ) -> DTensor:
+    """
+    Add a shard function to simplify both colwise_shard and rowwise_shard. The
+    shard function accepts a full tensor, and returns a DTensor based on
+    indicated placements. Goal is to move the shard function as a static method
+    of DTensor, e.g.
+    dtensor = DTensor.shard(full_tensor, device_mesh, placement)
+    """
     from torch.distributed.tensor._utils import compute_local_shape_and_global_offset

     shape, offset = compute_local_shape_and_global_offset(
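For context, here is a minimal sketch of what a generic shard helper along these lines could look like. Only the signature, the docstring, and the first two statements appear in the hunk above; the slicing logic and the colwise_shard/rowwise_shard wrappers are illustrative assumptions, not code from this commit.

# Sketch of a generic shard helper in the spirit of the diff above.
# The slicing step and the colwise/rowwise wrappers are assumptions.
from typing import Sequence

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Placement, Shard


def shard(
    full_tensor: torch.Tensor,
    device_mesh: DeviceMesh,
    placements: Sequence[Placement],
) -> DTensor:
    """Shard `full_tensor` according to `placements` and return a DTensor."""
    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset

    # Compute the shape and global offset of this rank's local shard.
    shape, offset = compute_local_shape_and_global_offset(
        full_tensor.shape, device_mesh, placements
    )
    # Slice the local shard out of the full tensor (assumes even sharding).
    slices = tuple(
        slice(cur_offset, cur_offset + cur_shape)
        for cur_shape, cur_offset in zip(shape, offset)
    )
    local_tensor = full_tensor[slices]
    return DTensor.from_local(local_tensor, device_mesh, placements)


# With such a helper, the two sharding variants reduce to one call each
# (dim 0 for colwise, dim 1 for rowwise is the assumed convention here).
def colwise_shard(t: torch.Tensor, mesh: DeviceMesh) -> DTensor:
    return shard(t, mesh, [Shard(0)])


def rowwise_shard(t: torch.Tensor, mesh: DeviceMesh) -> DTensor:
    return shard(t, mesh, [Shard(1)])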