66import unittest
77from unittest .mock import patch
88import sys
9- import os
109
1110import torch
1211from torch import nn
2726from torch_xla ._internal import tpu
2827
2928
30- def should_convert_to_shardy ():
31- return os .environ .get ("CONVERT_SHLO_TO_SHARDY" ,
32- "" ).lower () in ("1" , "true" , "yes" )
33-
34-
3529class BasicXlaShardingTest (test_xla_sharding_base .XlaShardingTest ):
3630
3731 @classmethod
3832 def setUpClass (cls ):
3933 super ().setUpClass ()
34+ cls .convert_to_shardy = xu .check_env_flag ("CONVERT_SHLO_TO_SHARDY" )
4035
4136 def test_xla_sharded_tensor (self ):
4237 partition_spec = (0 , 1 )
@@ -244,7 +239,7 @@ def test_custom_tile_assignment(self):
244239 if self .n_devices > 1 :
245240 annotation = '{devices=[1,%d]%s}' % (self .n_devices , ',' .join (
246241 [str (i ) for i in reversed (range (self .n_devices ))]))
247- if should_convert_to_shardy () :
242+ if self . convert_to_shardy :
248243 annotation = '{devices=[1,%d]<=[%d]}' % (self .n_devices , self .n_devices )
249244 self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt ))
250245
@@ -260,7 +255,7 @@ def test_mark_sharding_2d(self):
260255 if self .n_devices > 1 :
261256 annotation = '{devices=[1,%d]%s}' % (self .n_devices , ',' .join (
262257 [str (i ) for i in range (self .n_devices )]))
263- if should_convert_to_shardy () :
258+ if self . convert_to_shardy :
264259 annotation = '{devices=[1,%d]<=[%d]}' % (self .n_devices , self .n_devices )
265260 self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt1 ))
266261
@@ -281,7 +276,7 @@ def test_mark_sharding_4d(self):
281276 annotation = '{devices=[1,1,%d,%d]%s}' % (
282277 z_dim , self .n_devices // z_dim , ',' .join (
283278 [str (i ) for i in range (self .n_devices )]))
284- if should_convert_to_shardy () :
279+ if self . convert_to_shardy :
285280 annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim , self .n_devices //
286281 z_dim , self .n_devices )
287282 self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt ))
@@ -418,7 +413,7 @@ def test_tupled_partition_spec(self):
418413 xs .mark_sharding (t , mesh , ((0 , 1 ),))
419414 annotation = "{devices=[%d]%s}" % (self .n_devices , ',' .join (
420415 str (x ) for x in range (self .n_devices )))
421- if should_convert_to_shardy () :
416+ if self . convert_to_shardy :
422417 annotation = "{devices=[%d]<=[%d]}" % (self .n_devices , self .n_devices )
423418 self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
424419
@@ -432,7 +427,7 @@ def test_named_partial_tupled_partition_spec(self):
432427 xs .mark_sharding (t , mesh , (('r' , 'b' ), None ))
433428 annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
434429 self .n_devices // 2 , ',' .join (str (x ) for x in range (self .n_devices )))
435- if should_convert_to_shardy () :
430+ if self . convert_to_shardy :
436431 annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
437432 self .n_devices // 2 , self .n_devices )
438433 self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
@@ -442,7 +437,7 @@ def test_named_partial_tupled_partition_spec(self):
442437 xs .mark_sharding (u , mesh , (None , ('b' , 'm' )))
443438 annotation = "{devices=[1,%d]%s}" % (self .n_devices , ',' .join (
444439 str (x ) for x in range (self .n_devices )))
445- if should_convert_to_shardy () :
440+ if self . convert_to_shardy :
446441 annotation = "{devices=[1,%d]<=[%d]}" % (self .n_devices , self .n_devices )
447442 self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (u ), annotation )
448443
@@ -452,7 +447,7 @@ def test_named_partial_tupled_partition_spec(self):
452447 device_order = mesh .get_logical_mesh ().transpose ((0 , 2 , 1 )).flatten ()
453448 annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
454449 self .n_devices // 2 , ',' .join (str (x ) for x in device_order ))
455- if should_convert_to_shardy () :
450+ if self . convert_to_shardy :
456451 annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
457452 self .n_devices // 2 , self .n_devices // 2 )
458453 self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (v ), annotation )
@@ -463,7 +458,7 @@ def test_named_partial_tupled_partition_spec(self):
463458 device_order = mesh .get_logical_mesh ().transpose ((2 , 1 , 0 )).flatten ()
464459 annotation = "{devices=[1,%d]%s}" % (self .n_devices , ',' .join (
465460 str (x ) for x in device_order ))
466- if should_convert_to_shardy () :
461+ if self . convert_to_shardy :
467462 annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self .n_devices ,
468463 self .n_devices // 2 )
469464 self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (v ), annotation )
@@ -478,7 +473,7 @@ def test_multiple_tuples_in_spec(self):
478473 xs .mark_sharding (t , mesh , (('a' , 'b' ), ('c' , 'd' )))
479474 annotation = "{devices=[2,%d]%s}" % (self .n_devices // 2 , ',' .join (
480475 str (x ) for x in range (self .n_devices )))
481- if should_convert_to_shardy () :
476+ if self . convert_to_shardy :
482477 annotation = "{devices=[2,%d]<=[%d]}" % (self .n_devices // 2 ,
483478 self .n_devices )
484479 self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
@@ -491,7 +486,7 @@ def test_3d_tensor_2d_mesh(self):
491486 xs .mark_sharding (t , mesh , (None , 0 , 1 ))
492487 annotation = '{devices=[1,2,%d]%s}' % (self .n_devices // 2 , ',' .join (
493488 str (x ) for x in range (self .n_devices )))
494- if should_convert_to_shardy () :
489+ if self . convert_to_shardy :
495490 annotation = '{devices=[1,2,%d]<=[%d]}' % (self .n_devices // 2 ,
496491 self .n_devices )
497492 self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
@@ -1013,8 +1008,7 @@ def test_op_sharding_cache(self):
10131008
10141009 t = torch .randn (1 , self .n_devices ).to ('xla' )
10151010 xs .mark_sharding (t , mesh , (0 , 1 ))
1016- counter_name = "CreateIotaOpSharding" if should_convert_to_shardy (
1017- ) else "CreateOpSharding"
1011+ counter_name = "CreateIotaOpSharding" if self .convert_to_shardy else "CreateOpSharding"
10181012 self .assertIn (counter_name , met .counter_names ())
10191013 self .assertEqual (met .counter_value (counter_name ), 1 )
10201014
@@ -1435,7 +1429,7 @@ def test_data_loader_with_sharding(self):
14351429 data , _ = iter (train_device_loader ).__next__ ()
14361430 self .assertEqual (data .size (), torch .Size ([8 , 3 , 64 , 64 ]))
14371431 annotation = f"{{devices=[{ mesh .size ()} ,1,1,1]{ ',' .join ([str (i ) for i in range (mesh .size ())])} }}"
1438- if should_convert_to_shardy () :
1432+ if self . convert_to_shardy :
14391433 annotation = f"{{devices=[{ mesh .size ()} ,1,1,1]<=[{ mesh .size ()} ]}}"
14401434 self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (data ), annotation )
14411435
@@ -1458,7 +1452,7 @@ def test_data_loader_with_non_batch_size(self):
14581452 data , _ = iter (train_device_loader ).__next__ ()
14591453 self .assertEqual (data .size (), torch .Size ([mesh .size () - 1 , 3 , 64 , 64 ]))
14601454 annotation = f"{{devices=[{ mesh .size ()} ,1,1,1]{ ',' .join ([str (i ) for i in range (mesh .size ())])} }}"
1461- if should_convert_to_shardy () :
1455+ if self . convert_to_shardy :
14621456 annotation = f"{{devices=[{ mesh .size ()} ,1,1,1]<=[{ mesh .size ()} ]}}"
14631457 self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (data ), annotation )
14641458
0 commit comments