Skip to content

Commit 7b90ea5

Browse files
committed
Use HybridMesh
1 parent 3799ac9 commit 7b90ea5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/pytorch/language-modeling/run_clm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def main():
488488
max_dim = np.argmax(param.shape)
489489
shape = [1] * len(param.shape)
490490
shape[max_dim] = num_devices
491-
mesh = xs.Mesh(device_ids, tuple(shape))
491+
mesh = xs.HybridMesh(ici_mesh_shape=tuple(shape))
492492
xs.mark_sharding(param, mesh, range(len(param.shape)))
493493

494494

0 commit comments

Comments
 (0)