import sys
import unittest
import torch
import torch_xla
from torch_xla.core.xla_builder import create_placeholder_tensor
import torch_xla.debug.metrics as met
import re
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

import test_xla_sharding_base


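# Tests for create_placeholder_tensor combined with SPMD sharding annotations.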
class TestSPMDPlaceholder(test_xla_sharding_base.XlaShardingTest):

  def setUp(self):
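    # XlaShardingTest.setUpClass is expected to initialize the SPMD test
    # environment for each test.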
    super().setUpClass()

  def test_create_placeholder(self):
    num_devices = self.n_devices
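    # Exercise placeholders across several tensor ranks and dtypes.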
    for shape, dtype in zip(
        ((num_devices, num_devices), (num_devices, num_devices, 2),
         (num_devices, num_devices, 2, 2)),
        (torch.float32, torch.bfloat16, torch.int8),
    ):
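      # Build a 2D (data, model) mesh, padded with singleton axes so the mesh
      # rank matches the tensor rank.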
      model_axis = max(1, self.n_devices // 2)
      data_axis = self.n_devices // model_axis
      mesh_shape = (data_axis, model_axis) + (1,) * (len(shape) - 2)
      axis_names = ('x', 'y') + tuple(f'z{i}' for i in range(1, len(shape) - 1))
      mesh = self._get_mesh(mesh_shape, axis_names=axis_names)

      p = create_placeholder_tensor(shape, dtype)
      xs.mark_sharding(p, mesh, axis_names)
      self.assertIsInstance(p, torch.Tensor)
      self.assertEqual(p.device, torch_xla.device())
      self.assertEqual(p.dtype, dtype)
      self.assertEqual(p.shape, shape)
      self.assertTrue(torch_xla._XLAC._is_placeholder(p))

  def test_read_value_crashes(self):
    mesh = self._get_mesh((self.n_devices,), axis_names=('x',))
    p = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
    xs.mark_sharding(p, mesh, ('x',))
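    # A placeholder has no materialized value, so transferring it to the CPU
    # should raise.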
    with self.assertRaises(RuntimeError):
      p.cpu()

  def test_trace_graph(self):
    met.clear_all()
    self.assertFalse(met.metric_data("TransferToDeviceTime"))

    model_axis = max(1, self.n_devices // 2)
    data_axis = self.n_devices // model_axis
    mesh_shape = (data_axis, model_axis)
    mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))

    p1 = create_placeholder_tensor((128, 32), torch.bfloat16)
    xs.mark_sharding(p1, mesh, ('x', 'y'))
    a = torch.sin(p1)

    p2 = create_placeholder_tensor((32, 64), torch.bfloat16)
    xs.mark_sharding(p2, mesh, ('x', 'y'))
    # We use p1 once and p2 twice, but the graph should still have only two
    # parameters.
    b = (a @ p2) @ p2.T
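    # Each placeholder should appear as exactly one device_data node in the
    # lazy tensor IR.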
    ir: str = torch_xla._XLAC._get_xla_tensors_text([b])
    self.assertEqual(ir.count("xla::device_data()"), 2)
    self.assertEqual(ir.count("bf16[32,64]{1,0} xla::device_data()"), 1)
    self.assertEqual(ir.count("bf16[128,32]{1,0} xla::device_data()"), 1)
    hlo: str = torch_xla._XLAC._get_xla_tensors_hlo([b])
    regex = r'\(p.*: bf16\[32,64\], p.*: bf16\[128,32\]\) -> \(bf16\[128,32\]\)'
    self.assertIsNotNone(re.search(regex, hlo))

    # There should be no buffers transferred to the device during tracing.
    self.assertFalse(met.metric_data("TransferToDeviceTime"))

  def test_placeholder_handle_unique(self):
    mesh = self._get_mesh((self.n_devices,), axis_names=('x',))

    p1 = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
    xs.mark_sharding(p1, mesh, ('x',))

    p2 = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
    xs.mark_sharding(p2, mesh, ('x',))

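    # Distinct placeholders should be backed by distinct device buffer handles.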
    h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
    self.assertNotEqual(h1, h2)


if __name__ == "__main__":
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)