Skip to content

Commit f823b86

Browse files
authored
feat: more JLL changes (#1014)
* feat: more JLL changes * Update deps/ReactantExtra/API.cpp
1 parent 464c94e commit f823b86

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,6 +2029,68 @@ extern "C" void ifrt_sharding_to_index_domains(
20292029
}
20302030
}
20312031

2032+
extern "C" bool hlo_sharding_is_tuple(xla::HloSharding *hloSharding) {
2033+
return hloSharding->IsTuple();
2034+
}
2035+
2036+
extern "C" bool hlo_sharding_is_replicated(xla::HloSharding *hloSharding) {
2037+
return hloSharding->IsReplicated();
2038+
}
2039+
2040+
extern "C" bool hlo_sharding_is_manual(xla::HloSharding *hloSharding) {
2041+
return hloSharding->IsManual();
2042+
}
2043+
2044+
extern "C" bool hlo_sharding_is_unknown(xla::HloSharding *hloSharding) {
2045+
return hloSharding->IsUnknown();
2046+
}
2047+
2048+
extern "C" bool hlo_sharding_is_tiled(xla::HloSharding *hloSharding) {
2049+
return hloSharding->IsTiled();
2050+
}
2051+
2052+
extern "C" bool hlo_sharding_is_maximal(xla::HloSharding *hloSharding) {
2053+
return hloSharding->IsTileMaximal();
2054+
}
2055+
2056+
extern "C" bool
2057+
hlo_sharding_replicate_on_last_tile_dim(xla::HloSharding *hloSharding) {
2058+
return hloSharding->ReplicateOnLastTileDim();
2059+
}
2060+
2061+
extern "C" int32_t
2062+
hlo_sharding_tile_assignment_dimensions_size(xla::HloSharding *hloSharding) {
2063+
return static_cast<int32_t>(hloSharding->tile_assignment().num_dimensions());
2064+
}
2065+
2066+
extern "C" int32_t
2067+
hlo_sharding_tile_assignment_devices_size(xla::HloSharding *hloSharding) {
2068+
return static_cast<int32_t>(hloSharding->tile_assignment().num_elements());
2069+
}
2070+
2071+
extern "C" void
2072+
hlo_sharding_tile_assignment_dimensions(xla::HloSharding *hloSharding,
2073+
int64_t *dims, int32_t size) {
2074+
auto tileAssignmentDims = hloSharding->tile_assignment().dimensions();
2075+
for (int32_t i = 0; i < size; i++) {
2076+
dims[i] = tileAssignmentDims[i];
2077+
}
2078+
}
2079+
2080+
extern "C" void
2081+
hlo_sharding_tile_assignment_devices(xla::HloSharding *hloSharding,
2082+
int64_t *devices, int32_t size) {
2083+
auto tileAssignmentDevices = hloSharding->tile_assignment().array().data();
2084+
for (int32_t i = 0; i < size; i++) {
2085+
devices[i] = tileAssignmentDevices[i];
2086+
}
2087+
}
2088+
2089+
extern "C" bool hlo_sharding_check_eq(xla::HloSharding *hloSharding,
2090+
xla::HloSharding *other) {
2091+
return *hloSharding == *other;
2092+
}
2093+
20322094
#pragma endregion
20332095

20342096
typedef ifrt::Future<> IfRtFutureType;

deps/ReactantExtra/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,7 @@ cc_library(
495495
"-Wl,-exported_symbol,_op_sharding_*",
496496
"-Wl,-exported_symbol,_hloShardingToTensorShardingAttr",
497497
"-Wl,-exported_symbol,_dump_operation",
498+
"-Wl,-exported_symbol,_hlo_sharding_*",
498499
]}),
499500
deps = [
500501
"@enzyme//:EnzymeMLIR",

0 commit comments

Comments
 (0)