@@ -413,20 +413,28 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   xla::XlaComputation xla_computation =
       GetValueOrThrow(b.Build(/*remove_dynamic_dimensions=*/false));
 
-  std::vector<torch::lazy::BackendDataPtr> parameters_data;
-  parameters_data.push_back(
+  std::vector<XLATensorPtr> tensors{XLATensor::Create(
       torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
-          bridge::GetDefaultDevice()->toString(), std::move(shape)));
+          bridge::GetDefaultDevice()->toString(), std::move(shape)))};
+  std::vector<std::vector<int64_t>> denormalized_tile_assignments;
+  for (auto tensor : tensors) {
+    auto sharding_spec = tensor->sharding_spec();
+    if (sharding_spec) {
+      denormalized_tile_assignments.push_back(
+          sharding_spec->sharding.GetDenormalizedTileAssignment());
+    }
+  }
 
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
-  instances.push_back({std::move(xla_computation),
-                       bridge::GetDefaultDevice()->toString(),
-                       {bridge::GetDefaultDevice()->toString()},
-                       &shape,
-                       /*should_wrap_parameter=*/false,
-                       /*is_sharded=*/true,
-                       /*allow_spmd_sharding_propagation_to_output=*/true,
-                       /*parameters_data=*/parameters_data});
+  instances.push_back(
+      {std::move(xla_computation),
+       bridge::GetDefaultDevice()->toString(),
+       {bridge::GetDefaultDevice()->toString()},
+       &shape,
+       /*should_wrap_parameter=*/false,
+       /*is_sharded=*/true,
+       /*allow_spmd_sharding_propagation_to_output=*/true,
+       /*denormalized_tile_assignments=*/denormalized_tile_assignments});
 
   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -437,9 +445,6 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
437445          " add"  , std::move (computations[0 ]->move_computation ()));
438446
439447  //  Prepare output sharding propagation, expect a sharded output placeholder.
440-   std::vector<XLATensorPtr> tensors{XLATensor::Create (
441-       torch_xla::runtime::GetComputationClientOrDie ()->CreateDataPlaceholder (
442-           bridge::GetDefaultDevice ()->toString (), std::move (shape)))};
443448  std::vector<torch::lazy::BackendDataPtr> data_placeholders;
444449  std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
445450  ShardingUtil::PrepareOutputShardingPropagation (