@@ -31,6 +31,7 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
3131  @classmethod  
3232  def  setUpClass (cls ):
3333    super ().setUpClass ()
34+     cls .convert_to_shardy  =  xu .check_env_flag ("CONVERT_SHLO_TO_SHARDY" )
3435
3536  def  test_xla_sharded_tensor (self ):
3637    partition_spec  =  (0 , 1 )
@@ -238,6 +239,8 @@ def test_custom_tile_assignment(self):
238239    if  self .n_devices  >  1 :
239240      annotation  =  '{devices=[1,%d]%s}'  %  (self .n_devices , ',' .join (
240241          [str (i ) for  i  in  reversed (range (self .n_devices ))]))
242+       if  self .convert_to_shardy :
243+         annotation  =  '{devices=[1,%d]<=[%d]}'  %  (self .n_devices , self .n_devices )
241244      self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt ))
242245
243246  def  test_mark_sharding_2d (self ):
@@ -252,6 +255,8 @@ def test_mark_sharding_2d(self):
252255    if  self .n_devices  >  1 :
253256      annotation  =  '{devices=[1,%d]%s}'  %  (self .n_devices , ',' .join (
254257          [str (i ) for  i  in  range (self .n_devices )]))
258+       if  self .convert_to_shardy :
259+         annotation  =  '{devices=[1,%d]<=[%d]}'  %  (self .n_devices , self .n_devices )
255260      self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt1 ))
256261
257262    actual  =  (xt1  +  xt2 ).cpu ()
@@ -271,6 +276,9 @@ def test_mark_sharding_4d(self):
271276      annotation  =  '{devices=[1,1,%d,%d]%s}'  %  (
272277          z_dim , self .n_devices  //  z_dim , ',' .join (
273278              [str (i ) for  i  in  range (self .n_devices )]))
279+       if  self .convert_to_shardy :
280+         annotation  =  '{devices=[1,1,%d,%d]<=[%d]}'  %  (z_dim , self .n_devices  // 
281+                                                       z_dim , self .n_devices )
274282      self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt ))
275283
276284    actual  =  (xt  +  xt ).cpu ()
@@ -403,9 +411,11 @@ def test_tupled_partition_spec(self):
403411    mesh  =  self ._get_mesh ((2 , self .n_devices  //  2 ))
404412    t  =  torch .randn (16 ).to ('xla' )
405413    xs .mark_sharding (t , mesh , ((0 , 1 ),))
406-     self .assertEqual (
407-         torch_xla ._XLAC ._get_xla_sharding_spec (t ), "{devices=[%d]%s}"  % 
408-         (self .n_devices , ',' .join (str (x ) for  x  in  range (self .n_devices ))))
414+     annotation  =  "{devices=[%d]%s}"  %  (self .n_devices , ',' .join (
415+         str (x ) for  x  in  range (self .n_devices )))
416+     if  self .convert_to_shardy :
417+       annotation  =  "{devices=[%d]<=[%d]}"  %  (self .n_devices , self .n_devices )
418+     self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
409419
410420  @unittest .skipUnless (xr .global_runtime_device_count () >=  4 , 
411421                       "Multiple devices required for tupled partition spec" ) 
@@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self):
415425    # Shard the first dimension on `r` and `b`, replicate the second dimension 
416426    t  =  torch .randn (16 , 16 ).to ('xla' )
417427    xs .mark_sharding (t , mesh , (('r' , 'b' ), None ))
418-     self .assertEqual (
419-         torch_xla ._XLAC ._get_xla_sharding_spec (t ),
420-         "{devices=[2,1,%d]%s last_tile_dim_replicate}"  % 
421-         (self .n_devices  //  2 , ',' .join (str (x ) for  x  in  range (self .n_devices ))))
428+     annotation  =  "{devices=[2,1,%d]%s last_tile_dim_replicate}"  %  (
429+         self .n_devices  //  2 , ',' .join (str (x ) for  x  in  range (self .n_devices )))
430+     if  self .convert_to_shardy :
431+       annotation  =  "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}"  %  (
432+           self .n_devices  //  2 , self .n_devices )
433+     self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
422434
423435    # Replicate the first dimension, shard the second on `b` and `m` 
424436    u  =  torch .randn (16 , 16 ).to ('xla' )
425437    xs .mark_sharding (u , mesh , (None , ('b' , 'm' )))
426-     self .assertEqual (
427-         torch_xla ._XLAC ._get_xla_sharding_spec (u ), "{devices=[1,%d]%s}"  % 
428-         (self .n_devices , ',' .join (str (x ) for  x  in  range (self .n_devices ))))
438+     annotation  =  "{devices=[1,%d]%s}"  %  (self .n_devices , ',' .join (
439+         str (x ) for  x  in  range (self .n_devices )))
440+     if  self .convert_to_shardy :
441+       annotation  =  "{devices=[1,%d]<=[%d]}"  %  (self .n_devices , self .n_devices )
442+     self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (u ), annotation )
429443
430444    # Replicate the first dimension, shard the second on `r` and `m` 
431445    v  =  torch .randn (16 , 16 ).to ('xla' )
432446    xs .mark_sharding (v , mesh , (None , ('r' , 'm' )))
433447    device_order  =  mesh .get_logical_mesh ().transpose ((0 , 2 , 1 )).flatten ()
434-     self .assertEqual (
435-         torch_xla ._XLAC ._get_xla_sharding_spec (v ),
436-         "{devices=[1,%d,2]%s last_tile_dim_replicate}"  % 
437-         (self .n_devices  //  2 , ',' .join (str (x ) for  x  in  device_order )))
448+     annotation  =  "{devices=[1,%d,2]%s last_tile_dim_replicate}"  %  (
449+         self .n_devices  //  2 , ',' .join (str (x ) for  x  in  device_order ))
450+     if  self .convert_to_shardy :
451+       annotation  =  "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}"  %  (
452+           self .n_devices  //  2 , self .n_devices  //  2 )
453+     self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (v ), annotation )
438454
439455    # Replicate the first dimension, shard the second on `m` and `b` 
440456    v  =  torch .randn (16 , 16 ).to ('xla' )
441457    xs .mark_sharding (v , mesh , (None , ('m' , 'b' )))
442458    device_order  =  mesh .get_logical_mesh ().transpose ((2 , 1 , 0 )).flatten ()
443-     self .assertEqual (
444-         torch_xla ._XLAC ._get_xla_sharding_spec (v ), "{devices=[1,%d]%s}"  % 
445-         (self .n_devices , ',' .join (str (x ) for  x  in  device_order )))
459+     annotation  =  "{devices=[1,%d]%s}"  %  (self .n_devices , ',' .join (
460+         str (x ) for  x  in  device_order ))
461+     if  self .convert_to_shardy :
462+       annotation  =  "{devices=[1,%d]<=[2,%d]T(1,0)}"  %  (self .n_devices ,
463+                                                        self .n_devices  //  2 )
464+     self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (v ), annotation )
446465
447466  @unittest .skipUnless (xr .global_runtime_device_count () >  1 , 
448467                       'Multiple devices required for tupled partition spec' ) 
@@ -452,19 +471,25 @@ def test_multiple_tuples_in_spec(self):
452471        ('a' , 'b' , 'c' , 'd' ))
453472    t  =  torch .randn (2 , 2 ).to ('xla' )
454473    xs .mark_sharding (t , mesh , (('a' , 'b' ), ('c' , 'd' )))
455-     self .assertEqual (
456-         torch_xla ._XLAC ._get_xla_sharding_spec (t ), "{devices=[2,%d]%s}"  % 
457-         (self .n_devices  //  2 , ',' .join (str (x ) for  x  in  range (self .n_devices ))))
474+     annotation  =  "{devices=[2,%d]%s}"  %  (self .n_devices  //  2 , ',' .join (
475+         str (x ) for  x  in  range (self .n_devices )))
476+     if  self .convert_to_shardy :
477+       annotation  =  "{devices=[2,%d]<=[%d]}"  %  (self .n_devices  //  2 ,
478+                                                self .n_devices )
479+     self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
458480
459481  @unittest .skipUnless (xr .global_runtime_device_count () >  1 , 
460482                       'At least 2 devices needed for 2D mesh' ) 
461483  def  test_3d_tensor_2d_mesh (self ):
462484    mesh  =  self ._get_mesh ((2 , self .n_devices  //  2 ))
463485    t  =  torch .randn (16 , 16 , 16 ).to ('xla' )
464486    xs .mark_sharding (t , mesh , (None , 0 , 1 ))
465-     self .assertEqual (
466-         torch_xla ._XLAC ._get_xla_sharding_spec (t ), '{devices=[1,2,%d]%s}'  % 
467-         (self .n_devices  //  2 , ',' .join (str (x ) for  x  in  range (self .n_devices ))))
487+     annotation  =  '{devices=[1,2,%d]%s}'  %  (self .n_devices  //  2 , ',' .join (
488+         str (x ) for  x  in  range (self .n_devices )))
489+     if  self .convert_to_shardy :
490+       annotation  =  '{devices=[1,2,%d]<=[%d]}'  %  (self .n_devices  //  2 ,
491+                                                  self .n_devices )
492+     self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
468493
469494  def  test_partial_replication_addmm (self ):
470495    device  =  torch_xla .device ()
@@ -984,18 +1009,20 @@ def test_op_sharding_cache(self):
9841009
9851010    t  =  torch .randn (1 , self .n_devices ).to ('xla' )
9861011    xs .mark_sharding (t , mesh , (0 , 1 ))
987-     self .assertIn ("CreateOpSharding" , met .counter_names ())
988-     self .assertEqual (met .counter_value ("CreateOpSharding" ), 1 )
1012+     counter_name  =  "CreateIotaOpSharding"  if  self .convert_to_shardy  else  "CreateOpSharding" 
1013+     self .assertIn (counter_name , met .counter_names ())
1014+     self .assertEqual (met .counter_value (counter_name ), 1 )
9891015
9901016    # Sharding with the same partition spec should not result in another call 
9911017    u  =  torch .randn (1 , self .n_devices ).to ('xla' )
9921018    xs .mark_sharding (u , mesh , (0 , 1 ))
993-     self .assertEqual (met .counter_value ("CreateOpSharding" ), 1 )
1019+     self .assertEqual (met .counter_value (counter_name ), 1 )
9941020
995-     # Changing the partition spec will result in another CreateOpSharding 
1021+     # Changing the partition spec will result in another 
1022+     # CreateOpSharding or CreatingIotaOpSharding call 
9961023    v  =  torch .randn (1 , self .n_devices ).to ('xla' )
9971024    xs .mark_sharding (v , mesh , (0 , None ))
998-     self .assertEqual (met .counter_value ("CreateOpSharding" ), 2 )
1025+     self .assertEqual (met .counter_value (counter_name ), 2 )
9991026
10001027  def  test_from_cpu_shards_replicated (self ):
10011028    from_cpu_shards  =  torch_xla ._XLAC ._global_tensor_from_cpu_shards 
@@ -1398,10 +1425,10 @@ def test_data_loader_with_sharding(self):
13981425        input_sharding = xs .ShardingSpec (mesh , ('data' , None , None , None )))
13991426    data , _  =  iter (train_device_loader ).__next__ ()
14001427    self .assertEqual (data .size (), torch .Size ([8 , 3 , 64 , 64 ]))
1401-     self . assertEqual ( 
1402-          torch_xla . _XLAC . _get_xla_sharding_spec ( data ), 
1403-          f"{{devices=[{ mesh .size ()}  ,1,1,1]{ ',' . join ([ str ( i )  for   i   in   range ( mesh .size ())]) }  }}" 
1404-     )
1428+     annotation   =   f"{{devices=[ { mesh . size () } ,1,1,1] { ',' . join ([ str ( i )  for   i   in   range ( mesh . size ())]) } }}" 
1429+     if   self . convert_to_shardy : 
1430+       annotation   =  f"{{devices=[{ mesh .size ()}  ,1,1,1]<=[ { mesh .size ()} ] }}" 
1431+     self . assertEqual ( torch_xla . _XLAC . _get_xla_sharding_spec ( data ),  annotation )
14051432
14061433  @unittest .skipUnless ( 
14071434      xr .global_runtime_device_count () >  1 , 
@@ -1421,10 +1448,10 @@ def test_data_loader_with_non_batch_size(self):
14211448        input_sharding = xs .ShardingSpec (mesh , ('data' , None , None , None )))
14221449    data , _  =  iter (train_device_loader ).__next__ ()
14231450    self .assertEqual (data .size (), torch .Size ([mesh .size () -  1 , 3 , 64 , 64 ]))
1424-     self . assertEqual ( 
1425-          torch_xla . _XLAC . _get_xla_sharding_spec ( data ), 
1426-          f"{{devices=[{ mesh .size ()}  ,1,1,1]{ ',' . join ([ str ( i )  for   i   in   range ( mesh .size ())]) }  }}" 
1427-     )
1451+     annotation   =   f"{{devices=[ { mesh . size () } ,1,1,1] { ',' . join ([ str ( i )  for   i   in   range ( mesh . size ())]) } }}" 
1452+     if   self . convert_to_shardy : 
1453+       annotation   =  f"{{devices=[{ mesh .size ()}  ,1,1,1]<=[ { mesh .size ()} ] }}" 
1454+     self . assertEqual ( torch_xla . _XLAC . _get_xla_sharding_spec ( data ),  annotation )
14281455
14291456  @unittest .skipUnless ( 
14301457      xr .global_runtime_device_count () >  1 , 
0 commit comments