Skip to content

Commit a334a73

Browse files
authored
fix: run finalizer on array (#1076)
* fix: run finalizer on array * fix: more missing finalizer * fix: avoid double free of HloSharding * fix: remove the ifrt::HloSharding middleman * Update Project.toml
1 parent 2d20337 commit a334a73

File tree

3 files changed

+26
-101
lines changed

3 files changed

+26
-101
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.59"
4+
version = "0.2.60"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -87,7 +87,7 @@ PythonCall = "0.9"
8787
Random = "1.10"
8888
Random123 = "1.7"
8989
ReactantCore = "0.1.9"
90-
Reactant_jll = "0.0.108"
90+
Reactant_jll = "0.0.109"
9191
ScopedValues = "1.3.0"
9292
Scratch = "1.2"
9393
Sockets = "1.10"

src/xla/IFRT/Array.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ mutable struct Array <: XLA.AbstractBuffer
22
buffer::Ptr{Cvoid}
33

44
function Array(buffer::Ptr{Cvoid})
5-
# return finalizer(free_ifrt_array, new(buffer))
6-
return new(buffer)
5+
return finalizer(free_ifrt_array, new(buffer))
76
end
87
end
98

@@ -128,9 +127,8 @@ function Array(
128127
end
129128

130129
@inline function free_ifrt_array(buffer::Array)
131-
sbuffer = buffer.buffer
132-
if sbuffer != C_NULL
133-
@ccall MLIR.API.mlir_c.ifrt_free_array(sbuffer::Ptr{Cvoid})::Cvoid
130+
if buffer.buffer != C_NULL
131+
@ccall MLIR.API.mlir_c.ifrt_free_array(buffer.buffer::Ptr{Cvoid})::Cvoid
134132
end
135133
end
136134

@@ -243,7 +241,7 @@ function disassemble_into_single_device_arrays(array::Array, only_addressable_de
243241
arrays = GC.@preserve array begin
244242
@ccall MLIR.API.mlir_c.ifrt_array_disassemble_into_single_device_arrays(
245243
array.buffer::Ptr{Cvoid},
246-
Int32(0)::Int32,
244+
0::Int32,
247245
c_single_device_shard_semantics::Int32,
248246
narrays::Ptr{Int32},
249247
)::Ptr{Ptr{Cvoid}}

src/xla/IFRT/Sharding.jl

Lines changed: 20 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,28 @@
1-
# xla::ifrt::HloSharding (distinct from xla::HloSharding)
2-
mutable struct HloSharding
1+
mutable struct Sharding
32
ptr::Ptr{Cvoid}
43

5-
function HloSharding(ptr::Ptr{Cvoid})
4+
function Sharding(ptr::Ptr{Cvoid})
65
@assert ptr != C_NULL
7-
# return finalizer(free_hlo_sharding, new(ptr))
8-
return new(ptr)
9-
end
10-
end
11-
12-
function free_hlo_sharding(hlo_sharding::HloSharding)
13-
@ccall MLIR.API.mlir_c.free_ifrt_hlo_sharding(hlo_sharding.ptr::Ptr{Cvoid})::Cvoid
14-
end
15-
16-
function Base.convert(::Type{XLA.HloSharding}, sharding::HloSharding)
17-
GC.@preserve sharding begin
18-
return XLA.HloSharding(
19-
@ccall MLIR.API.mlir_c.ifrt_hlo_sharding_to_xla_hlo_sharding(
20-
sharding.ptr::Ptr{Cvoid}
21-
)::Ptr{Cvoid}
22-
)
6+
return finalizer(free_sharding, new(ptr))
237
end
248
end
259

26-
function HloSharding(
27-
device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding
28-
)
10+
function Sharding(device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding)
2911
addressable_devices = filter(XLA.is_addressable, device_list)
3012
default_memory_kind = convert(MemoryKind, XLA.default_memory(addressable_devices))
31-
return HloSharding(device_list, xla_hlo_sharding, default_memory_kind)
32-
end
33-
34-
function HloSharding(
35-
device_list::AbstractVector{<:Device},
36-
xla_hlo_sharding::XLA.HloSharding,
37-
memoy_kind::AbstractString,
38-
)
39-
return HloSharding(device_list, xla_hlo_sharding, MemoryKind(memoy_kind))
13+
return Sharding(device_list, xla_hlo_sharding, default_memory_kind)
4014
end
4115

42-
function HloSharding(
16+
function Sharding(
4317
device_list::AbstractVector{<:Device},
4418
xla_hlo_sharding::XLA.HloSharding,
45-
memory_kind::MemoryKind,
19+
memory_kind::Union{AbstractString,MemoryKind},
4620
)
21+
memory_kind isa AbstractString && (memory_kind = MemoryKind(memory_kind))
4722
client = XLA.client(device_list)
4823
GC.@preserve device_list memory_kind xla_hlo_sharding client begin
49-
return HloSharding(
50-
@ccall MLIR.API.mlir_c.ifrt_hlo_sharding_from_xla_hlo_sharding(
24+
return Sharding(
25+
@ccall MLIR.API.mlir_c.ifrt_sharding_from_xla_hlo_sharding(
5126
client.client::Ptr{Cvoid},
5227
[d.device for d in device_list]::Ptr{Ptr{Cvoid}},
5328
length(device_list)::Int32,
@@ -58,87 +33,39 @@ function HloSharding(
5833
end
5934
end
6035

61-
function Base.string(hlo_sharding::HloSharding)
62-
GC.@preserve hlo_sharding begin
63-
str = @ccall MLIR.API.mlir_c.ifrt_hlo_sharding_to_string(
64-
hlo_sharding.ptr::Ptr{Cvoid}
65-
)::Cstring
66-
end
67-
return XLA.unsafe_string_and_free(str)
68-
end
69-
70-
function Base.show(io::IO, ::MIME"text/plain", hlo_sharding::HloSharding)
71-
print(io, "XLA.IFRT.HloSharding(\"", string(hlo_sharding), "\")")
72-
return nothing
73-
end
74-
75-
# HloSharding is more specific than Sharding. But Sharding is a neater way to deal with
76-
# most of the IFRT APIs.
77-
mutable struct Sharding
78-
ptr::Ptr{Cvoid}
79-
80-
function Sharding(ptr::Ptr{Cvoid})
81-
@assert ptr != C_NULL
82-
# return finalizer(free_sharding, new(ptr))
83-
return new(ptr)
84-
end
85-
end
86-
87-
function Sharding(device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding)
88-
return convert(Sharding, HloSharding(device_list, xla_hlo_sharding))
89-
end
90-
91-
function Sharding(
92-
device_list::AbstractVector{<:Device},
93-
xla_hlo_sharding::XLA.HloSharding,
94-
memory_kind::Union{AbstractString,MemoryKind},
95-
)
96-
return convert(Sharding, HloSharding(device_list, xla_hlo_sharding, memory_kind))
97-
end
98-
9936
function free_sharding(sharding::Sharding)
10037
@ccall MLIR.API.mlir_c.free_ifrt_sharding(sharding.ptr::Ptr{Cvoid})::Cvoid
10138
end
10239

103-
function XLA.devices(sharding::Sharding)
40+
function XLA.num_devices(sharding::Sharding)
10441
GC.@preserve sharding begin
105-
ndevices = @ccall MLIR.API.mlir_c.ifrt_sharding_devices_size(
42+
return @ccall MLIR.API.mlir_c.ifrt_sharding_devices_size(
10643
sharding.ptr::Ptr{Cvoid}
10744
)::Int32
10845
end
46+
end
47+
48+
function XLA.devices(sharding::Sharding)
49+
ndevices = XLA.num_devices(sharding)
10950
devices = Ref{NTuple{Int64(ndevices),Ptr{Cvoid}}}()
11051
GC.@preserve sharding devices begin
11152
@ccall MLIR.API.mlir_c.ifrt_sharding_to_device_list(
11253
sharding.ptr::Ptr{Cvoid}, devices::Ptr{Ptr{Cvoid}}
11354
)::Cvoid
11455
end
115-
return [Device(device) for device in devices[]]
56+
return map(Device, devices[])
11657
end
11758

118-
function Base.convert(::Type{Sharding}, hlo_sharding::HloSharding)
119-
GC.@preserve hlo_sharding begin
120-
return Sharding(
121-
@ccall MLIR.API.mlir_c.ifrt_sharding_from_ifrt_hlo_sharding(
122-
hlo_sharding.ptr::Ptr{Cvoid}
123-
)::Ptr{Cvoid}
124-
)
125-
end
126-
end
127-
128-
function Base.convert(::Type{HloSharding}, sharding::Sharding)
59+
function Base.convert(::Type{XLA.HloSharding}, sharding::Sharding)
12960
GC.@preserve sharding begin
130-
return HloSharding(
131-
@ccall MLIR.API.mlir_c.ifrt_sharding_to_ifrt_hlo_sharding(
61+
return XLA.HloSharding(
62+
@ccall MLIR.API.mlir_c.ifrt_sharding_to_xla_hlo_sharding(
13263
sharding.ptr::Ptr{Cvoid}
13364
)::Ptr{Cvoid}
13465
)
13566
end
13667
end
13768

138-
function Base.convert(::Type{XLA.HloSharding}, sharding::Sharding)
139-
return convert(XLA.HloSharding, convert(HloSharding, sharding))
140-
end
141-
14269
function Base.string(sharding::Sharding)
14370
GC.@preserve sharding begin
14471
str = @ccall MLIR.API.mlir_c.ifrt_sharding_to_string(

0 commit comments

Comments
 (0)