1
- # xla::ifrt::HloSharding (distinct from xla::HloSharding)
2
- mutable struct HloSharding
1
+ mutable struct Sharding
3
2
ptr:: Ptr{Cvoid}
4
3
5
- function HloSharding (ptr:: Ptr{Cvoid} )
4
+ function Sharding (ptr:: Ptr{Cvoid} )
6
5
@assert ptr != C_NULL
7
- # return finalizer(free_hlo_sharding, new(ptr))
8
- return new (ptr)
9
- end
10
- end
11
-
12
- function free_hlo_sharding (hlo_sharding:: HloSharding )
13
- @ccall MLIR. API. mlir_c. free_ifrt_hlo_sharding (hlo_sharding. ptr:: Ptr{Cvoid} ):: Cvoid
14
- end
15
-
16
- function Base. convert (:: Type{XLA.HloSharding} , sharding:: HloSharding )
17
- GC. @preserve sharding begin
18
- return XLA. HloSharding (
19
- @ccall MLIR. API. mlir_c. ifrt_hlo_sharding_to_xla_hlo_sharding (
20
- sharding. ptr:: Ptr{Cvoid}
21
- ):: Ptr{Cvoid}
22
- )
6
+ return finalizer (free_sharding, new (ptr))
23
7
end
24
8
end
25
9
26
- function HloSharding (
27
- device_list:: AbstractVector{<:Device} , xla_hlo_sharding:: XLA.HloSharding
28
- )
10
+ function Sharding (device_list:: AbstractVector{<:Device} , xla_hlo_sharding:: XLA.HloSharding )
29
11
addressable_devices = filter (XLA. is_addressable, device_list)
30
12
default_memory_kind = convert (MemoryKind, XLA. default_memory (addressable_devices))
31
- return HloSharding (device_list, xla_hlo_sharding, default_memory_kind)
32
- end
33
-
34
- function HloSharding (
35
- device_list:: AbstractVector{<:Device} ,
36
- xla_hlo_sharding:: XLA.HloSharding ,
37
- memoy_kind:: AbstractString ,
38
- )
39
- return HloSharding (device_list, xla_hlo_sharding, MemoryKind (memoy_kind))
13
+ return Sharding (device_list, xla_hlo_sharding, default_memory_kind)
40
14
end
41
15
42
- function HloSharding (
16
+ function Sharding (
43
17
device_list:: AbstractVector{<:Device} ,
44
18
xla_hlo_sharding:: XLA.HloSharding ,
45
- memory_kind:: MemoryKind ,
19
+ memory_kind:: Union{AbstractString, MemoryKind} ,
46
20
)
21
+ memory_kind isa AbstractString && (memory_kind = MemoryKind (memory_kind))
47
22
client = XLA. client (device_list)
48
23
GC. @preserve device_list memory_kind xla_hlo_sharding client begin
49
- return HloSharding (
50
- @ccall MLIR. API. mlir_c. ifrt_hlo_sharding_from_xla_hlo_sharding (
24
+ return Sharding (
25
+ @ccall MLIR. API. mlir_c. ifrt_sharding_from_xla_hlo_sharding (
51
26
client. client:: Ptr{Cvoid} ,
52
27
[d. device for d in device_list]:: Ptr{Ptr{Cvoid}} ,
53
28
length (device_list):: Int32 ,
@@ -58,87 +33,39 @@ function HloSharding(
58
33
end
59
34
end
60
35
61
- function Base. string (hlo_sharding:: HloSharding )
62
- GC. @preserve hlo_sharding begin
63
- str = @ccall MLIR. API. mlir_c. ifrt_hlo_sharding_to_string (
64
- hlo_sharding. ptr:: Ptr{Cvoid}
65
- ):: Cstring
66
- end
67
- return XLA. unsafe_string_and_free (str)
68
- end
69
-
70
- function Base. show (io:: IO , :: MIME"text/plain" , hlo_sharding:: HloSharding )
71
- print (io, " XLA.IFRT.HloSharding(\" " , string (hlo_sharding), " \" )" )
72
- return nothing
73
- end
74
-
75
- # HloSharding is more specific than Sharding. But Sharding is a neater way to deal with
76
- # most of the IFRT APIs.
77
- mutable struct Sharding
78
- ptr:: Ptr{Cvoid}
79
-
80
- function Sharding (ptr:: Ptr{Cvoid} )
81
- @assert ptr != C_NULL
82
- # return finalizer(free_sharding, new(ptr))
83
- return new (ptr)
84
- end
85
- end
86
-
87
- function Sharding (device_list:: AbstractVector{<:Device} , xla_hlo_sharding:: XLA.HloSharding )
88
- return convert (Sharding, HloSharding (device_list, xla_hlo_sharding))
89
- end
90
-
91
- function Sharding (
92
- device_list:: AbstractVector{<:Device} ,
93
- xla_hlo_sharding:: XLA.HloSharding ,
94
- memory_kind:: Union{AbstractString,MemoryKind} ,
95
- )
96
- return convert (Sharding, HloSharding (device_list, xla_hlo_sharding, memory_kind))
97
- end
98
-
99
36
function free_sharding (sharding:: Sharding )
100
37
@ccall MLIR. API. mlir_c. free_ifrt_sharding (sharding. ptr:: Ptr{Cvoid} ):: Cvoid
101
38
end
102
39
103
- function XLA. devices (sharding:: Sharding )
40
+ function XLA. num_devices (sharding:: Sharding )
104
41
GC. @preserve sharding begin
105
- ndevices = @ccall MLIR. API. mlir_c. ifrt_sharding_devices_size (
42
+ return @ccall MLIR. API. mlir_c. ifrt_sharding_devices_size (
106
43
sharding. ptr:: Ptr{Cvoid}
107
44
):: Int32
108
45
end
46
+ end
47
+
48
+ function XLA. devices (sharding:: Sharding )
49
+ ndevices = XLA. num_devices (sharding)
109
50
devices = Ref {NTuple{Int64(ndevices),Ptr{Cvoid}}} ()
110
51
GC. @preserve sharding devices begin
111
52
@ccall MLIR. API. mlir_c. ifrt_sharding_to_device_list (
112
53
sharding. ptr:: Ptr{Cvoid} , devices:: Ptr{Ptr{Cvoid}}
113
54
):: Cvoid
114
55
end
115
- return [ Device (device) for device in devices[]]
56
+ return map (Device, devices[])
116
57
end
117
58
118
- function Base. convert (:: Type{Sharding} , hlo_sharding:: HloSharding )
119
- GC. @preserve hlo_sharding begin
120
- return Sharding (
121
- @ccall MLIR. API. mlir_c. ifrt_sharding_from_ifrt_hlo_sharding (
122
- hlo_sharding. ptr:: Ptr{Cvoid}
123
- ):: Ptr{Cvoid}
124
- )
125
- end
126
- end
127
-
128
- function Base. convert (:: Type{HloSharding} , sharding:: Sharding )
59
+ function Base. convert (:: Type{XLA.HloSharding} , sharding:: Sharding )
129
60
GC. @preserve sharding begin
130
- return HloSharding (
131
- @ccall MLIR. API. mlir_c. ifrt_sharding_to_ifrt_hlo_sharding (
61
+ return XLA . HloSharding (
62
+ @ccall MLIR. API. mlir_c. ifrt_sharding_to_xla_hlo_sharding (
132
63
sharding. ptr:: Ptr{Cvoid}
133
64
):: Ptr{Cvoid}
134
65
)
135
66
end
136
67
end
137
68
138
- function Base. convert (:: Type{XLA.HloSharding} , sharding:: Sharding )
139
- return convert (XLA. HloSharding, convert (HloSharding, sharding))
140
- end
141
-
142
69
function Base. string (sharding:: Sharding )
143
70
GC. @preserve sharding begin
144
71
str = @ccall MLIR. API. mlir_c. ifrt_sharding_to_string (
0 commit comments