@@ -51,7 +51,8 @@ TEST_F(XLAShardingTest, GetShardShape) {
       {2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
 
@@ -60,7 +61,7 @@ TEST_F(XLAShardingTest, GetShardShape) {
   EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));
 
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For replicated sharding, each dimension should be preserved
@@ -78,7 +79,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       {2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
@@ -108,7 +110,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
     }
   }
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices(
@@ -126,6 +128,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
 TEST_F(XLAShardingTest, ShardTensor) {
   std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3",
                                       "TPU:4", "TPU:5", "TPU:6", "TPU:7"};
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
 
   // 1D tiled
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
@@ -136,7 +139,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
           CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
           devices.size())
           .ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -155,7 +158,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
       {4, 5, 6, 7},
   });
   xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -168,7 +171,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // size should be smaller in dim=1 because it's not evenly divisible.
   xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
   xla_sharding = xla::HloSharding::Tile(cube).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -178,7 +181,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
 
   // Replicated, all shards should be identical.
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -194,7 +197,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
   xla_sharding = xla::HloSharding::Tile(tesseract).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -219,7 +222,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
   xla_sharding = xla::HloSharding::Tile(hypercube).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -248,7 +251,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {6, 7, 2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {4, 5, 0, 1, 6, 7, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   // For devices at the start of the mesh, all shards should have the same
@@ -266,7 +270,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {2, 3, 6, 7},
   });
   xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  denormalized_tile_assignment = {0, 1, 4, 5, 2, 3, 6, 7};
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -295,7 +300,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
   });
 
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
       sharding, global_shape, /*minibatch=*/true);
   auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
@@ -314,14 +320,15 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
                                                  {4, 5, 6, 7},
                                              })
                           .ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape);
   xla_sharding =
       xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape);
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec replicated(sharding, tensor_shape);
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
   EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
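
Taken together, the hunks above make one mechanical change: each test now constructs torch_xla::OpSharding with an explicit denormalized tile assignment instead of std::nullopt. The sketch below shows that constructor usage in isolation. It is not part of the patch: the helper name is ours, the include paths are assumed, and the only thing taken from the diff is that the second constructor argument accepts an optional std::vector<int64_t> device ordering.

// Illustrative sketch only, not from the patch. Mirrors the constructor usage
// exercised by the updated tests under the assumptions stated above.
#include <cstdint>
#include <vector>

#include "torch_xla/csrc/torch_xla_op_sharding.h"  // assumed header location
#include "xla/array2d.h"
#include "xla/hlo/ir/hlo_sharding.h"

torch_xla::OpSharding MakeTiledOpSharding() {
  // 2x2 logical device mesh, as in the GetShardShape test above.
  xla::Array2D<int64_t> mesh({
      {0, 1},
      {2, 3},
  });
  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
  // Explicit device order; the tests previously passed std::nullopt here.
  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
  return torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
}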