Merged
Changes from all commits (24 commits)
3e683dc  Add some more functions to OpSharding pybinding (hshahTT, Jun 19, 2025)
9d21f15  Add some log statements (hshahTT, Jun 22, 2025)
bcbbc92  Merge branch 'master' of https://github.com/pytorch/xla into hshah/op… (hshahTT, Jun 22, 2025)
99be1d9  Add some working code for V2 sharding for TILED types (need to clean up) (hshahTT, Jun 23, 2025)
9331e24  Add some code to handle Partial sharding in V2 format (doesn't work yet) (hshahTT, Jun 24, 2025)
af5cd2d  Code mostly works now, some partition specs with tuples are facing is… (hshahTT, Jun 26, 2025)
35e7fa6  Merge branch 'master' of https://github.com/pytorch/xla into hshah/op… (hshahTT, Jun 26, 2025)
ee0f1fb  Add gspmd->shardy pass within pjrt client (hshahTT, Jun 26, 2025)
fb6b25b  Merge branch 'master' of https://github.com/pytorch/xla into hshah/op… (hshahTT, Jun 28, 2025)
9c009ee  Merge branch 'hshah/op-sharding-v2' of github.com:hshahTT/pytorch-xla… (hshahTT, Jun 28, 2025)
db5e795  Merge branch 'master' of https://github.com/pytorch/xla into hshah/op… (hshahTT, Jun 30, 2025)
cfbd78f  Merge branch 'hshah/op-sharding-v2' of github.com:hshahTT/pytorch-xla… (hshahTT, Jun 30, 2025)
582d7cc  Remove previous v2 implementations (hshahTT, Jun 30, 2025)
a3d3297  Merge branch 'hshah/opsharding-v2' of github.com:tenstorrent/pytorch-… (hshahTT, Jun 30, 2025)
cba7922  Add some more functions to OpSharding pybinding (hshahTT, Jun 19, 2025)
39cf5f1  Add some log statements (hshahTT, Jun 22, 2025)
d20c68a  Add some working code for V2 sharding for TILED types (need to clean up) (hshahTT, Jun 23, 2025)
4c47abd  Add some code to handle Partial sharding in V2 format (doesn't work yet) (hshahTT, Jun 24, 2025)
0456513  Code mostly works now, some partition specs with tuples are facing is… (hshahTT, Jun 26, 2025)
18da0db  Remove previous v2 implementations (hshahTT, Jun 30, 2025)
c47024f  Merge branch 'hshah/opsharding-v2' of github.com:tenstorrent/pytorch-… (hshahTT, Jul 12, 2025)
87572da  Cleanup code (hshahTT, Jul 18, 2025)
4b50249  Merge branch 'master' of https://github.com/pytorch/xla into hshah/ad… (hshahTT, Jul 18, 2025)
ea1b945  Fix formatting (hshahTT, Jul 18, 2025)
7 changes: 7 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1504,12 +1504,19 @@ void InitXlaModuleBindings(py::module m) {

// Define the _XLAC.OpSharding class.
PythonScope<py::class_<xla::OpSharding>>(m, "OpSharding")
// Constructor for V1 shardings
.def_init([](const py::list& tile_assignment,
const py::list& group_assignment,
const py::list& replication_groups, int sharding_type) {
return ShardingUtil::CreateOpSharding(
tile_assignment, group_assignment, replication_groups,
ShardingUtil::ShardingType(sharding_type));
})
// Constructor for V2 shardings.
.def_init([](const py::list& dims, const py::list& reshape_dims,
const py::list& transpose_perm) {
return ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
transpose_perm);
});

// Define the _XLAC.PjRtPlugin class.
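For reference, a minimal sketch of how the two constructor overloads are exercised from Python (the argument values are illustrative, not taken from this PR; the 4-argument form is the pre-existing V1 path, the 3-argument form is the V2 path added here):

```python
import torch_xla

# V1 constructor: tile_assignment, group_assignment, replication_groups, sharding_type.
# The integer 3 is assumed to correspond to ShardingType.TILED.
v1 = torch_xla._XLAC.OpSharding([[0, 1], [2, 3]], [], [], 3)

# V2 (iota) constructor added in this change: dims, reshape_dims, transpose_perm.
# Encodes the device order iota(4).reshape(2, 2).transpose(0, 1), tiled as a 2x2 grid.
v2 = torch_xla._XLAC.OpSharding([2, 2], [2, 2], [0, 1])
```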
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -365,6 +365,7 @@ cc_library(
"@xla//xla/mlir_hlo:all_passes",
"@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
"@xla//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
"@xla//xla/service/spmd/shardy/stablehlo_round_trip:stablehlo_import",
],
)

4 changes: 4 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cpp
@@ -13,6 +13,7 @@
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/pjrt_registry.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/util.h"
@@ -641,6 +642,9 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
mlir::ModuleOp mlir_module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
ConvertHloToStableHlo(instance.computation.mutable_proto(), &mlir_module);
if (runtime::sys_util::GetEnvBool("CONVERT_SHLO_TO_SHARDY", false)) {
ConvertStableHloToSdy(&mlir_module);
}
executable = util::RaisePythonValueErrorOnFailure([&] {
return fake_xla_compile_
? fake_xla_compile_()
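As a usage note, the new branch is gated on an environment variable read with GetEnvBool, so it must be set before the first compilation. A minimal sketch; the value "1" is assumed to satisfy both the C++ GetEnvBool check here and the Python os.environ.get check used later in this PR:

```python
import os

# Enable the StableHLO -> Shardy (SDY) conversion inside PjRtComputationClient::Compile.
# Set this before any compilation is triggered; equivalently, launch the process with
# CONVERT_SHLO_TO_SHARDY=1 in the environment.
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"
```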
10 changes: 10 additions & 0 deletions torch_xla/csrc/runtime/stablehlo_helper.cpp
@@ -18,6 +18,7 @@
#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h"

namespace torch_xla {

@@ -89,6 +90,7 @@ static absl::Status mhloToStablehloHelper(mlir::ModuleOp* mlir_module,
torch_xla::runtime::CreateRemoveXlaMarkTensorOpsPass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());

if (!mlir::succeeded(pm.run(*mlir_module))) {
return absl::Status(
absl::StatusCode::kInternal,
Expand All @@ -111,6 +113,14 @@ void ConvertHloToStableHlo(const xla::HloModuleProto* proto,
<< getHloModuleStr(proto);
}

void ConvertStableHloToSdy(mlir::ModuleOp* mlir_module) {
mlir::PassManager pm(mlir_module->getContext());
xla::sdy::addStablehloImportPipeline(pm, false, false);
if (!mlir::succeeded(pm.run(*mlir_module))) {
XLA_ERROR() << "StableHLO -> SDY conversion failed.\n";
}
}

std::string hloToStablehlo(const xla::HloModuleProto* proto,
bool emit_bytecode) {
mlir::MLIRContext context;
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/stablehlo_helper.h
@@ -13,6 +13,8 @@ namespace torch_xla {
std::string hloToStablehlo(const xla::HloModuleProto* proto,
bool emit_bytecode);

void ConvertStableHloToSdy(mlir::ModuleOp* mlir_module);

void ConvertHloToStableHlo(const xla::HloModuleProto* proto,
mlir::ModuleOp* mlir_module);

17 changes: 17 additions & 0 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -218,6 +218,23 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a,
return xla::protobuf_util::HaveSameSerialization(a, b);
}

xla::OpSharding ShardingUtil::CreateIotaOpSharding(
const py::list& dims, const py::list& reshape_dims,
const py::list& transpose_perm) {
auto dims_vec = dims.cast<std::vector<int64_t>>();
auto reshape_dims_vec = reshape_dims.cast<std::vector<int64_t>>();
auto transpose_perm_vec = transpose_perm.cast<std::vector<int>>();
std::vector<xla::OpSharding::Type> subgroup_types;
if (dims_vec.size() > transpose_perm.size()) {
subgroup_types.push_back(xla::OpSharding::REPLICATED);
}
return xla::HloSharding::Subgroup(
xla::TileAssignment(dims_vec, reshape_dims_vec,
transpose_perm_vec),
subgroup_types)
.ToProto();
}

xla::OpSharding ShardingUtil::CreateOpSharding(
const py::list& tile_assignment, const py::list& group_assignment,
const py::list& replication_groups, ShardingType sharding_type) {
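A worked example of the V2 (iota) encoding this function consumes, assuming a 2x4 mesh (8 devices) and a rank-2 partition_spec of (0, None): the Python helper added later in this PR produces dims=[2, 1, 4], reshape_dims=[2, 4], transpose_perm=[0, 1]. Because len(dims) > len(transpose_perm), the trailing dimension becomes a REPLICATED subgroup, so tensor dim 0 is split across mesh axis 0 and the four devices along mesh axis 1 each replicate their shard. The sketch below just expands the iota tile assignment with NumPy to show which devices hold which shard:

```python
import numpy as np

# V2 encoding for mesh_shape=(2, 4), partition_spec=(0, None),
# with values as computed by _get_op_sharding_args_v2.
dims, reshape_dims, transpose_perm = [2, 1, 4], [2, 4], [0, 1]

# devices[i, j, :] lists the replica group holding tile (i, j).
devices = np.arange(8).reshape(reshape_dims).transpose(transpose_perm).reshape(dims)
print(devices)  # [[[0 1 2 3]], [[4 5 6 7]]] -> devices 0-3 hold shard 0, devices 4-7 hold shard 1
```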
5 changes: 5 additions & 0 deletions torch_xla/csrc/xla_sharding_util.h
@@ -51,6 +51,11 @@ class ShardingUtil {
const py::list& group_assignment,
const py::list& replication_groups,
ShardingType sharding_type);
// Creates an xla::OpSharding for TILED and PARTIAL types using the
// HloShardingV2 system.
static xla::OpSharding CreateIotaOpSharding(const py::list& dims,
const py::list& reshape_dims,
const py::list& transpose_perm);

// Returns the shape of the resulting shards of `tensor` after applying
// `sharding`. This assumes the shards will be padded to ensure they all
67 changes: 66 additions & 1 deletion torch_xla/distributed/spmd/xla_sharding.py
@@ -1,6 +1,7 @@
import collections
from collections.abc import Generator, MutableMapping
import math
import os
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field
import torch
@@ -118,9 +119,18 @@ def get_axis_name_idx(self, name: str) -> int:
return None
return self.axis_names.index(name)

def _validate_translated_partition_spec(self, partition_spec: tuple):
flat_specs = np.hstack([d for d in partition_spec])
specs = [d for d in flat_specs if d is not None]
assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

@functools.lru_cache(maxsize=None)
def _get_op_sharding_args(self, partition_spec: PartitionSpec):
partition_spec = _translate_named_partition_spec(self, partition_spec)
self._validate_translated_partition_spec(partition_spec)
flat_specs = np.hstack([d for d in partition_spec])
specs = [d for d in flat_specs if d is not None]
assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
@@ -142,6 +152,57 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec):
sharding_type = int(sharding_type)
return tile_assignment, group_assignment, replication_groups, sharding_type

@functools.lru_cache(maxsize=None)
def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
"""
Returns the appropriate dims, reshape_dims, and transpose_perm for the given partition spec.
"""
partition_spec = _translate_named_partition_spec(self, partition_spec)
self._validate_translated_partition_spec(partition_spec)

dims = []
used_axes = OrderedDict()
for axis in partition_spec:
if isinstance(axis, tuple):
dim_size = 1
for i in axis:
assert i is not None, "None not allowed within tuple"
dim_size *= self.mesh_shape[i]
used_axes[i] = True
dims.append(dim_size)
elif axis is not None:
assert isinstance(axis, int), "Axis must be an int or a tuple of ints"
dims.append(self.mesh_shape[axis])
used_axes[axis] = True
else:
# Replicated mesh axis
dims.append(1)

transpose_perm = [k for k in used_axes.keys()]
for i in range(len(self.mesh_shape)):
if i not in used_axes:
dims.append(self.mesh_shape[i])
transpose_perm.append(i)
reshape_dims = list(self.mesh_shape)

return dims, reshape_dims, transpose_perm

@functools.lru_cache(maxsize=None)
def get_op_sharding_v2(
self, partition_spec: PartitionSpec) -> torch_xla._XLAC.OpSharding:
"""
Return the OpSharding for the given partition spec using V2 annotations.
"""
if len(partition_spec) == 0:
return torch_xla._XLAC.OpSharding([], [], [], ShardingType.REPLICATED)
sharding_type = _get_sharding_type(partition_spec, self.size())
if sharding_type not in (ShardingType.TILED, ShardingType.PARTIAL):
return torch_xla._XLAC.OpSharding([], [], [0], sharding_type)

dims, reshape_dims, transpose_perm = self._get_op_sharding_args_v2(
partition_spec)
return torch_xla._XLAC.OpSharding(dims, reshape_dims, transpose_perm)

@functools.lru_cache(maxsize=None)
def get_op_sharding(
self, partition_spec: PartitionSpec) -> torch_xla._XLAC.OpSharding:
@@ -157,6 +218,7 @@ def get_op_sharding(

tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
partition_spec)

return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
replication_groups, sharding_type)

@@ -648,7 +710,10 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
t.shard_(NamedSharding(jmesh, P(*partition_spec)))
return t

op_sharding = mesh.get_op_sharding(partition_spec)
if os.environ.get('CONVERT_SHLO_TO_SHARDY', False):

Reviewer: should this be hidden under a different env var or does shardy inherently only understand V2?

Author (hshahTT): Shardy only understands V2. I wasn't able to get the pass working with a V1 graph, and also Kevin mentioned that V2 is a required work item for getting the Shardy pass working: pytorch#9348 (comment)

Reviewer: ok, awesome, then it can remain under the CONVERT_SHLO_TO_SHARDY flag.
op_sharding = mesh.get_op_sharding_v2(partition_spec)
else:
op_sharding = mesh.get_op_sharding(partition_spec)
annotate_func = torch_xla._XLAC._xla_mark_sharding
annotate_func(unwrap_sharded_tensor(t), op_sharding)
return wrap_as_sharded_tensor(t)
Expand Down