We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3799ac9 commit 7b90ea5Copy full SHA for 7b90ea5
examples/pytorch/language-modeling/run_clm.py
@@ -488,7 +488,7 @@ def main():
488
max_dim = np.argmax(param.shape)
489
shape = [1] * len(param.shape)
490
shape[max_dim] = num_devices
491
- mesh = xs.Mesh(device_ids, tuple(shape))
+ mesh = xs.HybridMesh(ici_mesh_shape=tuple(shape))
492
xs.mark_sharding(param, mesh, range(len(param.shape)))
493
494
0 commit comments