91 changes: 91 additions & 0 deletions test/spmd/test_spmd_debugging.py
@@ -17,6 +17,7 @@
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor
from torch_xla.distributed.spmd import Mesh
from torch_xla.distributed.spmd.debugging import construct_v1_sharding_str

import test_xla_sharding_base

@@ -822,6 +823,96 @@ def test_multi_host_replicated_cpu(self):
fake_output = fake_capture.get()
assert output == fake_output


class ConvertV2ShardingToV1Test(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
super().setUpClass()
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"

def run_test(self):
mesh = self._get_mesh(self.device_mesh_shape)
t = torch.randn(self.tensor_shape).to(torch_xla.device())
xs.mark_sharding(t, mesh, self.partition_spec)
actual_str = construct_v1_sharding_str(t)
self.assertEqual(self.expected_str, actual_str)

def test_tiled_sharding(self):
self.device_mesh_shape = (1, self.n_devices)
self.tensor_shape = (1, 128)
self.partition_spec = (0, 1)
if self.n_devices == 1:
# Any tiled sharding on a single device should be treated as replicated.
self.expected_str = '{replicated}'
else:
self.expected_str = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
[str(i) for i in range(self.n_devices)]))
self.run_test()

self.partition_spec = (1, 0)
if self.n_devices == 1:
self.expected_str = '{replicated}'
else:
self.expected_str = '{devices=[%d,1]%s}' % (self.n_devices, ','.join(
[str(i) for i in range(self.n_devices)]))
self.run_test()

@unittest.skipIf(xr.global_runtime_device_count() < 2,
f"Requires at least 2 devices.")
def test_tupled_tiled_sharding(self):
self.device_mesh_shape = (2, self.n_devices // 2)
self.tensor_shape = (16,)
self.partition_spec = ((0, 1),)
self.expected_str = "{devices=[%d]%s}" % (self.n_devices, ','.join(
str(x) for x in range(self.n_devices)))
self.run_test()

def test_replicated_sharding(self):
self.device_mesh_shape = (1, self.n_devices)
self.tensor_shape = (4, 4)
self.partition_spec = (None, None)
self.expected_str = '{replicated}'
self.run_test()

@unittest.skipIf(xr.global_runtime_device_count() < 4,
f"Requires at least 4 devices.")
def test_partial_replication_sharding(self):
self.device_mesh_shape = (2, self.n_devices // 2)
self.tensor_shape = (4, 4)
self.partition_spec = (0, None)
self.expected_str = '{devices=[2,1,%d]%s last_tile_dim_replicate}' % (
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
self.run_test()

self.partition_spec = (None, 0)
self.expected_str = '{devices=[1,2,%d]%s last_tile_dim_replicate}' % (
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
self.run_test()

@unittest.skipIf(xr.global_runtime_device_count() < 4,
f"Requires at least 4 devices.")
def test_tupled_partial_replication_sharding(self):
self.device_mesh_shape = (1, 2, self.n_devices // 2)
self.tensor_shape = (16, 16)
self.partition_spec = ((0, 1), None)
self.expected_str = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
self.run_test()

@unittest.skipIf(xr.global_runtime_device_count() < 4,
f"Requires at least 4 devices.")
def test_tupled_partial_replication_sharding_with_transpose(self):
self.device_mesh_shape = (1, 2, self.n_devices // 2)
self.tensor_shape = (16, 16)
self.partition_spec = (None, (2, 1))
device_order = self.device_ids.reshape(self.device_mesh_shape).transpose(
(2, 1, 0)).flatten()
self.expected_str = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
str(x) for x in device_order))
self.run_test()


if __name__ == '__main__':
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
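The ConvertV2ShardingToV1Test cases above exercise construct_v1_sharding_str, which renders a tensor's sharding in the classic V1 HloSharding text form: {devices=[tile dims]device list}, with {replicated} for full replication and a trailing last_tile_dim_replicate when the last tile dimension is replicated. Below is a minimal sketch of how the expected strings in these tests can be assembled from a mesh shape and device order; v1_tiled_annotation is a hypothetical helper for illustration only, not the implementation behind construct_v1_sharding_str, and it skips the last_tile_dim_replicate case.

import numpy as np

def v1_tiled_annotation(mesh_shape, device_order=None):
  # Hypothetical helper: builds the V1-style annotation the tests above expect
  # for a plain tiled layout.
  n = int(np.prod(mesh_shape))
  if n == 1:
    # A tiling over a single device is reported as replicated,
    # matching test_tiled_sharding.
    return '{replicated}'
  order = range(n) if device_order is None else device_order
  return '{devices=[%s]%s}' % (','.join(str(d) for d in mesh_shape),
                               ','.join(str(d) for d in order))

# v1_tiled_annotation((1, 4)) == '{devices=[1,4]0,1,2,3}'
# v1_tiled_annotation((1,))   == '{replicated}'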
99 changes: 63 additions & 36 deletions test/spmd/test_xla_sharding.py
@@ -31,6 +31,7 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY")

def test_xla_sharded_tensor(self):
partition_spec = (0, 1)
@@ -238,6 +239,8 @@ def test_custom_tile_assignment(self):
if self.n_devices > 1:
annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
[str(i) for i in reversed(range(self.n_devices))]))
if self.convert_to_shardy:
annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

def test_mark_sharding_2d(self):
@@ -252,6 +255,8 @@ def test_mark_sharding_2d(self):
if self.n_devices > 1:
annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
[str(i) for i in range(self.n_devices)]))
if self.convert_to_shardy:
annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))

actual = (xt1 + xt2).cpu()
@@ -271,6 +276,9 @@ def test_mark_sharding_4d(self):
annotation = '{devices=[1,1,%d,%d]%s}' % (
z_dim, self.n_devices // z_dim, ','.join(
[str(i) for i in range(self.n_devices)]))
if self.convert_to_shardy:
annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
z_dim, self.n_devices)
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

actual = (xt + xt).cpu()
@@ -403,9 +411,11 @@ def test_tupled_partition_spec(self):
mesh = self._get_mesh((2, self.n_devices // 2))
t = torch.randn(16).to('xla')
xs.mark_sharding(t, mesh, ((0, 1),))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
(self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
"Multiple devices required for tupled partition spec")
@@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self):
# Shard the first dimension on `r` and `b`, replicate the second dimension
t = torch.randn(16, 16).to('xla')
xs.mark_sharding(t, mesh, (('r', 'b'), None))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t),
"{devices=[2,1,%d]%s last_tile_dim_replicate}" %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
self.n_devices // 2, self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

# Replicate the first dimension, shard the second on `b` and `m`
u = torch.randn(16, 16).to('xla')
xs.mark_sharding(u, mesh, (None, ('b', 'm')))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" %
(self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation)

# Replicate the first dimension, shard the second on `r` and `m`
v = torch.randn(16, 16).to('xla')
xs.mark_sharding(v, mesh, (None, ('r', 'm')))
device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v),
"{devices=[1,%d,2]%s last_tile_dim_replicate}" %
(self.n_devices // 2, ','.join(str(x) for x in device_order)))
annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
self.n_devices // 2, ','.join(str(x) for x in device_order))
if self.convert_to_shardy:
annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
self.n_devices // 2, self.n_devices // 2)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

# Replicate the first dimension, shard the second on `m` and `b`
v = torch.randn(16, 16).to('xla')
xs.mark_sharding(v, mesh, (None, ('m', 'b')))
device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" %
(self.n_devices, ','.join(str(x) for x in device_order)))
annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
str(x) for x in device_order))
if self.convert_to_shardy:
annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices,
self.n_devices // 2)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
'Multiple devices required for tupled partition spec')
@@ -452,19 +471,25 @@ def test_multiple_tuples_in_spec(self):
('a', 'b', 'c', 'd'))
t = torch.randn(2, 2).to('xla')
xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
'At least 2 devices needed for 2D mesh')
def test_3d_tensor_2d_mesh(self):
mesh = self._get_mesh((2, self.n_devices // 2))
t = torch.randn(16, 16, 16).to('xla')
xs.mark_sharding(t, mesh, (None, 0, 1))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2,
self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

def test_partial_replication_addmm(self):
device = torch_xla.device()
@@ -983,18 +1008,20 @@ def test_op_sharding_cache(self):

t = torch.randn(1, self.n_devices).to('xla')
xs.mark_sharding(t, mesh, (0, 1))
self.assertIn("CreateOpSharding", met.counter_names())
self.assertEqual(met.counter_value("CreateOpSharding"), 1)
counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding"
self.assertIn(counter_name, met.counter_names())
self.assertEqual(met.counter_value(counter_name), 1)

# Sharding with the same partition spec should not result in another call
u = torch.randn(1, self.n_devices).to('xla')
xs.mark_sharding(u, mesh, (0, 1))
self.assertEqual(met.counter_value("CreateOpSharding"), 1)
self.assertEqual(met.counter_value(counter_name), 1)

# Changing the partition spec will result in another CreateOpSharding
# Changing the partition spec will result in another
# CreateOpSharding or CreateIotaOpSharding call
v = torch.randn(1, self.n_devices).to('xla')
xs.mark_sharding(v, mesh, (0, None))
self.assertEqual(met.counter_value("CreateOpSharding"), 2)
self.assertEqual(met.counter_value(counter_name), 2)

def test_from_cpu_shards_replicated(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
@@ -1397,10 +1424,10 @@ def test_data_loader_with_sharding(self):
input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
data, _ = iter(train_device_loader).__next__()
self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(data),
f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
)
annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
if self.convert_to_shardy:
annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

@unittest.skipUnless(
xr.global_runtime_device_count() > 1,
@@ -1420,10 +1447,10 @@ def test_data_loader_with_non_batch_size(self):
input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
data, _ = iter(train_device_loader).__next__()
self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(data),
f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
)
annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
if self.convert_to_shardy:
annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

@unittest.skipUnless(
xr.global_runtime_device_count() > 1,
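For reference on the <=[...] annotations asserted above when CONVERT_SHLO_TO_SHARDY is set: the iota form {devices=[tile dims]<=[reshape dims]T(perm)} describes the device list implicitly: enumerate devices 0..N-1, reshape them to the bracketed dims, apply the transpose permutation if one is given, and flatten. Below is a small sketch that expands such a spec into the explicit V1-style device list; expand_iota is an illustrative helper, not an API introduced by this PR.

import numpy as np

def expand_iota(reshape_dims, transpose_perm=None):
  # Expand an iota device assignment such as '<=[2,2]T(1,0)' into an explicit list.
  devices = np.arange(int(np.prod(reshape_dims))).reshape(reshape_dims)
  if transpose_perm is not None:
    devices = devices.transpose(transpose_perm)
  return devices.flatten().tolist()

# '{devices=[1,4]<=[4]}'         expands to 0,1,2,3
# '{devices=[1,4]<=[2,2]T(1,0)}' expands to 0,2,1,3, matching the transposed
# device_order computed in test_named_partial_tupled_partition_spec.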