import sys
import unittest
import torch
import numpy as np

from torch.distributed.tensor import DeviceMesh
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate, Shard
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor
import test_xla_sharding_base


class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest):
  """
  Test suite for XLAShardedTensor.to_local(), which converts a sharded XLA
  tensor back into a regular global tensor when using an XLA device mesh.
  """

  @classmethod
  def setUpClass(cls):
    # Initialize the shared XLA SPMD test environment from the base class.
    super().setUpClass()

  def test_to_local(self):
    from torch.distributed.tensor import distribute_tensor
    world_size = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(world_size)))

    big_tensor = torch.randn(100000, 88)
    sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)])

    local_tensor = sharded_tensor.to_local()

    # Verify the shapes are the same
    self.assertEqual(local_tensor.shape, big_tensor.shape)

    # Check the values match; check_device=False because the returned tensor
    # may live on the XLA device while big_tensor stays on CPU.
    torch.testing.assert_close(local_tensor, big_tensor, check_device=False)

  def test_to_local_requires_grad(self):
    """Test that gradients flow correctly through to_local()."""
    # Create a tensor with requires_grad=True
    world_size = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(world_size)))

    tensor = torch.randn(100_000, 88, requires_grad=True)

    # Create XLAShardedTensor
    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])

    # Verify requires_grad is set
    self.assertTrue(sharded_tensor.requires_grad)

    res = sharded_tensor.sum()
    res.backward()

    # Verify gradients were computed
    self.assertTrue(sharded_tensor.grad is not None)

    # Convert back to a regular local tensor
    local_tensor = sharded_tensor.to_local()

    # Verify requires_grad is preserved
    self.assertTrue(local_tensor.requires_grad)

    # All gradients should be 1.0 since we did a sum()
    self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor)))

    print("Gradient flow test successful")


if __name__ == "__main__":
  result = unittest.main(exit=False)
  sys.exit(0 if result.result.wasSuccessful() else 1)