@@ -57,24 +57,26 @@ def visualize_sharding(sharding: str,
5757 # eg: '{devices=[2,2]0,1,2,3}'
5858 # eg: '{replicated}'
5959 # eg: '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}'
60+ print (f"Visualizing { sharding } (showing up to the first two dimensions)" )
6061 if sharding == '{replicated}' or len (sharding ) == 0 :
6162 heights = 1
6263 widths = 1
6364 num_devices = xr .global_runtime_device_count ()
6465 device_ids = list (range (num_devices ))
6566 slices .setdefault ((0 , 0 ), device_ids )
6667 else :
67- sharding_spac = sharding [sharding .index ('[' ):sharding .index (']' ) + 1 ]
68+ sharding_spec = sharding [sharding .index ('[' ) +
69+ 1 :sharding .index (']' )].split ("," )
6870 device_list_original = sharding .split (' last_tile_dim_replicate' )
6971 if len (device_list_original ) == 2 and device_list_original [1 ] == '}' :
7072 try :
7173 device_list_original_first = device_list_original [0 ]
7274 device_list = device_list_original_first [device_list_original_first .
7375 index (']' ) + 1 :]
7476 device_indices_map = [int (s ) for s in device_list .split (',' )]
75- heights = int (sharding_spac [ 1 ])
76- widths = int (sharding_spac [ 3 ])
77- last_dim_depth = int (sharding_spac [ 5 ])
77+ heights = int (sharding_spec [ 0 ])
78+ widths = int (sharding_spec [ 1 ])
79+ last_dim_depth = int (sharding_spec [ - 1 ])
7880 devices_len = len (device_indices_map )
7981 len_after_dim_down = devices_len // last_dim_depth
8082 for i in range (len_after_dim_down ):
@@ -96,8 +98,8 @@ def visualize_sharding(sharding: str,
9698 device_list = device_list_original_first [device_list_original_first .
9799 index (']' ) + 1 :- 1 ]
98100 device_indices_map = [int (i ) for i in device_list .split (',' )]
99- heights = int (sharding_spac [ 1 ])
100- widths = int (sharding_spac [ 3 ])
101+ heights = int (sharding_spec [ 0 ])
102+ widths = int (sharding_spec [ 1 ])
101103 devices_len = len (device_indices_map )
102104 for i in range (devices_len ):
103105 slices .setdefault ((i // widths , i % widths ), device_indices_map [i ])
0 commit comments