
Commit 0fa6e31

hshahTT and sshonTT authored
Add Shardy support for Torch-XLA (#9541)
This PR implements the necessary changes to support the Shardy dialect within Torch-XLA (relevant issue: #9348):

1. Add support for V2 HLO sharding within the `OpSharding` and `XlaShardingSpec` classes (Shardy does not support the V1 shardings that are currently implemented).
2. Add the OpenXLA [`addStablehloImportPipeline()`](https://github.com/openxla/xla/blob/0cead9fb0f6b5a3effbfc90858640f2234e1c76c/xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.cc#L762) pass, which performs the StableHLO-to-Shardy conversion.
3. Gate the new behavior behind the `CONVERT_SHLO_TO_SHARDY` environment variable.

---------

Co-authored-by: Sungjoon Shon <[email protected]>
1 parent b1131d1 commit 0fa6e31
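
As a rough usage sketch (not part of this commit's diff): setting the environment variable before any shardings are created routes lowering through the new Shardy path, and `torch_xla._XLAC._get_xla_sharding_spec()` then reports the V2 (iota) annotation, as the updated tests below expect. The mesh shape and tensor size here are arbitrary, chosen only for illustration.

import os
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"  # enable the StableHLO -> Shardy path

import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
n_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n_devices), (1, n_devices))

t = torch.randn(1, 128).to(torch_xla.device())
xs.mark_sharding(t, mesh, (0, 1))

# With the flag set, the annotation uses the V2 iota form,
# e.g. "{devices=[1,4]<=[4]}" on four devices instead of "{devices=[1,4]0,1,2,3}".
print(torch_xla._XLAC._get_xla_sharding_spec(t))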

File tree: 11 files changed (+389, -61 lines)


test/spmd/test_spmd_debugging.py

Lines changed: 91 additions & 0 deletions
@@ -17,6 +17,7 @@
 import torch_xla.distributed.spmd as xs
 from torch_xla.distributed.spmd import XLAShardedTensor
 from torch_xla.distributed.spmd import Mesh
+from torch_xla.distributed.spmd.debugging import construct_v1_sharding_str

 import test_xla_sharding_base

@@ -822,6 +823,96 @@ def test_multi_host_replicated_cpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output

+
+class ConvertV2ShardingToV1Test(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"
+
+  def run_test(self):
+    mesh = self._get_mesh(self.device_mesh_shape)
+    t = torch.randn(self.tensor_shape).to(torch_xla.device())
+    xs.mark_sharding(t, mesh, self.partition_spec)
+    actual_str = construct_v1_sharding_str(t)
+    self.assertEqual(self.expected_str, actual_str)
+
+  def test_tiled_sharding(self):
+    self.device_mesh_shape = (1, self.n_devices)
+    self.tensor_shape = (1, 128)
+    self.partition_spec = (0, 1)
+    if self.n_devices == 1:
+      # Any tiled sharding on a single device should be treated as replicated.
+      self.expected_str = '{replicated}'
+    else:
+      self.expected_str = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
+          [str(i) for i in range(self.n_devices)]))
+    self.run_test()
+
+    self.partition_spec = (1, 0)
+    if self.n_devices == 1:
+      self.expected_str = '{replicated}'
+    else:
+      self.expected_str = '{devices=[%d,1]%s}' % (self.n_devices, ','.join(
+          [str(i) for i in range(self.n_devices)]))
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 2,
+                   f"Requires at least 2 devices.")
+  def test_tupled_tiled_sharding(self):
+    self.device_mesh_shape = (2, self.n_devices // 2)
+    self.tensor_shape = (16,)
+    self.partition_spec = ((0, 1),)
+    self.expected_str = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  def test_replicated_sharding(self):
+    self.device_mesh_shape = (1, self.n_devices)
+    self.tensor_shape = (4, 4)
+    self.partition_spec = (None, None)
+    self.expected_str = '{replicated}'
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 4,
+                   f"Requires at least 4 devices.")
+  def test_partial_replication_sharding(self):
+    self.device_mesh_shape = (2, self.n_devices // 2)
+    self.tensor_shape = (4, 4)
+    self.partition_spec = (0, None)
+    self.expected_str = '{devices=[2,1,%d]%s last_tile_dim_replicate}' % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+    self.partition_spec = (None, 0)
+    self.expected_str = '{devices=[1,2,%d]%s last_tile_dim_replicate}' % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 4,
+                   f"Requires at least 4 devices.")
+  def test_tupled_partial_replication_sharding(self):
+    self.device_mesh_shape = (1, 2, self.n_devices // 2)
+    self.tensor_shape = (16, 16)
+    self.partition_spec = ((0, 1), None)
+    self.expected_str = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 4,
+                   f"Requires at least 4 devices.")
+  def test_tupled_partial_replication_sharding_with_transpose(self):
+    self.device_mesh_shape = (1, 2, self.n_devices // 2)
+    self.tensor_shape = (16, 16)
+    self.partition_spec = (None, (2, 1))
+    device_order = self.device_ids.reshape(self.device_mesh_shape).transpose(
+        (2, 1, 0)).flatten()
+    self.expected_str = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in device_order))
+    self.run_test()
+
+
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)

test/spmd/test_xla_sharding.py

Lines changed: 63 additions & 36 deletions
@@ -31,6 +31,7 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
   @classmethod
   def setUpClass(cls):
     super().setUpClass()
+    cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY")

   def test_xla_sharded_tensor(self):
     partition_spec = (0, 1)

@@ -238,6 +239,8 @@ def test_custom_tile_assignment(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in reversed(range(self.n_devices))]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

   def test_mark_sharding_2d(self):

@@ -252,6 +255,8 @@ def test_mark_sharding_2d(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))

     actual = (xt1 + xt2).cpu()

@@ -271,6 +276,9 @@ def test_mark_sharding_4d(self):
       annotation = '{devices=[1,1,%d,%d]%s}' % (
           z_dim, self.n_devices // z_dim, ','.join(
               [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
+                                                      z_dim, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

     actual = (xt + xt).cpu()

@@ -403,9 +411,11 @@ def test_tupled_partition_spec(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16).to('xla')
     xs.mark_sharding(t, mesh, ((0, 1),))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                        "Multiple devices required for tupled partition spec")

@@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self):
     # Shard the first dimension on `r` and `b`, replicate the second dimension
     t = torch.randn(16, 16).to('xla')
     xs.mark_sharding(t, mesh, (('r', 'b'), None))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t),
-        "{devices=[2,1,%d]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

     # Replicate the first dimension, shard the second on `b` and `m`
     u = torch.randn(16, 16).to('xla')
     xs.mark_sharding(u, mesh, (None, ('b', 'm')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation)

     # Replicate the first dimension, shard the second on `r` and `m`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('r', 'm')))
     device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v),
-        "{devices=[1,%d,2]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

     # Replicate the first dimension, shard the second on `m` and `b`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('m', 'b')))
     device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices,
+                                                       self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'Multiple devices required for tupled partition spec')

@@ -452,19 +471,25 @@ def test_multiple_tuples_in_spec(self):
                           ('a', 'b', 'c', 'd'))
     t = torch.randn(2, 2).to('xla')
     xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
+                                               self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'At least 2 devices needed for 2D mesh')
   def test_3d_tensor_2d_mesh(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16, 16, 16).to('xla')
     xs.mark_sharding(t, mesh, (None, 0, 1))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2,
+                                                 self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   def test_partial_replication_addmm(self):
     device = torch_xla.device()

@@ -984,18 +1009,20 @@ def test_op_sharding_cache(self):

     t = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(t, mesh, (0, 1))
-    self.assertIn("CreateOpSharding", met.counter_names())
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding"
+    self.assertIn(counter_name, met.counter_names())
+    self.assertEqual(met.counter_value(counter_name), 1)

     # Sharding with the same partition spec should not result in another call
     u = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(u, mesh, (0, 1))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    self.assertEqual(met.counter_value(counter_name), 1)

-    # Changing the partition spec will result in another CreateOpSharding
+    # Changing the partition spec will result in another
+    # CreateOpSharding or CreateIotaOpSharding call
     v = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(v, mesh, (0, None))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 2)
+    self.assertEqual(met.counter_value(counter_name), 2)

   def test_from_cpu_shards_replicated(self):
     from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards

@@ -1398,10 +1425,10 @@ def test_data_loader_with_sharding(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,

@@ -1421,10 +1448,10 @@ def test_data_loader_with_non_batch_size(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,
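
A note on reading the expected strings above (illustrative, not part of the commit): the V1 annotation spells out every device id, e.g. `{devices=[1,4]0,1,2,3}`, while the V2 form encodes the same list as an iota that may be reshaped and transposed, e.g. `{devices=[1,4]<=[4]}` or `{devices=[1,4]<=[2,2]T(1,0)}`. The small helper below (its name is made up; the expansion rule is inferred from the V1/V2 pairs asserted in the tests) expands the iota notation back into the explicit V1 device order.

import numpy as np


def expand_iota(reshape_dims, transpose_perm=None):
  # "<=[d0,d1,...]" reads as iota(prod(dims)) reshaped to dims; an optional
  # "T(perm)" transposes it before the device ids are read off in order.
  devices = np.arange(np.prod(reshape_dims)).reshape(reshape_dims)
  if transpose_perm is not None:
    devices = devices.transpose(transpose_perm)
  return devices.flatten().tolist()


print(expand_iota([4]))             # [0, 1, 2, 3] -> "...]0,1,2,3"
print(expand_iota([2, 2], (1, 0)))  # [0, 2, 1, 3] -> "...]0,2,1,3"

The second case matches the transposed-mesh tests above, where the V1 device order comes from `mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()` on a (1, 2, 2) mesh.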
