@@ -2029,6 +2029,68 @@ extern "C" void ifrt_sharding_to_index_domains(
2029
2029
}
2030
2030
}
2031
2031
2032
+ extern " C" bool hlo_sharding_is_tuple (xla::HloSharding *hloSharding) {
2033
+ return hloSharding->IsTuple ();
2034
+ }
2035
+
2036
+ extern " C" bool hlo_sharding_is_replicated (xla::HloSharding *hloSharding) {
2037
+ return hloSharding->IsReplicated ();
2038
+ }
2039
+
2040
+ extern " C" bool hlo_sharding_is_manual (xla::HloSharding *hloSharding) {
2041
+ return hloSharding->IsManual ();
2042
+ }
2043
+
2044
+ extern " C" bool hlo_sharding_is_unknown (xla::HloSharding *hloSharding) {
2045
+ return hloSharding->IsUnknown ();
2046
+ }
2047
+
2048
+ extern " C" bool hlo_sharding_is_tiled (xla::HloSharding *hloSharding) {
2049
+ return hloSharding->IsTiled ();
2050
+ }
2051
+
2052
+ extern " C" bool hlo_sharding_is_maximal (xla::HloSharding *hloSharding) {
2053
+ return hloSharding->IsTileMaximal ();
2054
+ }
2055
+
2056
+ extern " C" bool
2057
+ hlo_sharding_replicate_on_last_tile_dim (xla::HloSharding *hloSharding) {
2058
+ return hloSharding->ReplicateOnLastTileDim ();
2059
+ }
2060
+
2061
+ extern " C" int32_t
2062
+ hlo_sharding_tile_assignment_dimensions_size (xla::HloSharding *hloSharding) {
2063
+ return static_cast <int32_t >(hloSharding->tile_assignment ().num_dimensions ());
2064
+ }
2065
+
2066
+ extern " C" int32_t
2067
+ hlo_sharding_tile_assignment_devices_size (xla::HloSharding *hloSharding) {
2068
+ return static_cast <int32_t >(hloSharding->tile_assignment ().num_elements ());
2069
+ }
2070
+
2071
+ extern " C" void
2072
+ hlo_sharding_tile_assignment_dimensions (xla::HloSharding *hloSharding,
2073
+ int64_t *dims, int32_t size) {
2074
+ auto tileAssignmentDims = hloSharding->tile_assignment ().dimensions ();
2075
+ for (int32_t i = 0 ; i < size; i++) {
2076
+ dims[i] = tileAssignmentDims[i];
2077
+ }
2078
+ }
2079
+
2080
+ extern " C" void
2081
+ hlo_sharding_tile_assignment_devices (xla::HloSharding *hloSharding,
2082
+ int64_t *devices, int32_t size) {
2083
+ auto tileAssignmentDevices = hloSharding->tile_assignment ().array ().data ();
2084
+ for (int32_t i = 0 ; i < size; i++) {
2085
+ devices[i] = tileAssignmentDevices[i];
2086
+ }
2087
+ }
2088
+
2089
+ extern " C" bool hlo_sharding_check_eq (xla::HloSharding *hloSharding,
2090
+ xla::HloSharding *other) {
2091
+ return *hloSharding == *other;
2092
+ }
2093
+
2032
2094
#pragma endregion
2033
2095
2034
2096
typedef ifrt::Future<> IfRtFutureType;
0 commit comments