Skip to content

Commit e446ebd

Browse files
committed
Safeguard Memory for v1.10
1 parent 68f1453 commit e446ebd

File tree

3 files changed

+111
-104
lines changed

3 files changed

+111
-104
lines changed

src/Tracing.jl

Lines changed: 70 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1811,95 +1811,97 @@ Base.@nospecializeinfer function make_tracer(
18111811
return res
18121812
end
18131813

1814-
1815-
Base.@nospecializeinfer function make_tracer(
1816-
seen,
1817-
@nospecialize(prev::Memory),
1818-
@nospecialize(path),
1819-
mode;
1820-
@nospecialize(track_numbers::Type = Union{}),
1821-
@nospecialize(sharding = Sharding.NoSharding()),
1822-
@nospecialize(runtime = nothing),
1823-
@nospecialize(device = nothing),
1824-
@nospecialize(client = nothing),
1825-
kwargs...,
1826-
)
1827-
RT = Core.Typeof(prev)
1828-
# XXX: If someone wants to shard the same array with different shardings, we need to
1829-
# somehow handle this correctly... Right now we just use the first sharding.
1830-
if mode != NoStopTracedTrack && haskey(seen, prev)
1831-
if mode == TracedToTypes
1832-
visited = seen[prev]
1833-
push!(path, visited)
1834-
return nothing
1814+
if isdefined(Base, :Memory)
1815+
Base.@nospecializeinfer function make_tracer(
1816+
seen,
1817+
@nospecialize(prev::Memory),
1818+
@nospecialize(path),
1819+
mode;
1820+
@nospecialize(track_numbers::Type = Union{}),
1821+
@nospecialize(sharding = Sharding.NoSharding()),
1822+
@nospecialize(runtime = nothing),
1823+
@nospecialize(device = nothing),
1824+
@nospecialize(client = nothing),
1825+
kwargs...,
1826+
)
1827+
RT = Core.Typeof(prev)
1828+
# XXX: If someone wants to shard the same array with different shardings, we need to
1829+
# somehow handle this correctly... Right now we just use the first sharding.
1830+
if mode != NoStopTracedTrack && haskey(seen, prev)
1831+
if mode == TracedToTypes
1832+
visited = seen[prev]
1833+
push!(path, visited)
1834+
return nothing
1835+
end
1836+
return seen[prev]
18351837
end
1836-
return seen[prev]
1837-
end
1838-
if eltype(RT) <: ReactantPrimitive
1839-
if mode == ArrayToConcrete
1840-
runtime isa Val{:PJRT} &&
1841-
(return seen[prev] = ConcretePJRTArray(prev; sharding, device, client))
1842-
runtime isa Val{:IFRT} &&
1843-
(return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client))
1844-
error("Unsupported runtime $runtime")
1838+
if eltype(RT) <: ReactantPrimitive
1839+
if mode == ArrayToConcrete
1840+
runtime isa Val{:PJRT} &&
1841+
(return seen[prev] = ConcretePJRTArray(prev; sharding, device, client))
1842+
runtime isa Val{:IFRT} &&
1843+
(return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client))
1844+
error("Unsupported runtime $runtime")
1845+
elseif mode == TracedToTypes
1846+
# Original array can get mutated so we store a copy:
1847+
push!(path, copy(prev))
1848+
seen[prev] = VisitedObject(length(seen) + 1)
1849+
return nothing
1850+
end
18451851
elseif mode == TracedToTypes
1846-
# Original array can get mutated so we store a copy:
1847-
push!(path, copy(prev))
1848-
seen[prev] = VisitedObject(length(seen) + 1)
1852+
push!(path, RT)
1853+
for I in eachindex(prev)
1854+
if isassigned(prev, I)
1855+
pv = prev[I]
1856+
make_tracer(
1857+
seen,
1858+
pv,
1859+
path,
1860+
mode;
1861+
track_numbers,
1862+
sharding,
1863+
runtime,
1864+
device,
1865+
client,
1866+
kwargs...,
1867+
)
1868+
end
1869+
end
18491870
return nothing
18501871
end
1851-
elseif mode == TracedToTypes
1852-
push!(path, RT)
1872+
TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime)
1873+
newa = Array{TT,ndims(RT)}(undef, size(prev))
1874+
seen[prev] = newa
1875+
same = true
18531876
for I in eachindex(prev)
18541877
if isassigned(prev, I)
18551878
pv = prev[I]
1856-
make_tracer(
1879+
nv = make_tracer(
18571880
seen,
18581881
pv,
1859-
path,
1882+
append_path(path, I),
18601883
mode;
18611884
track_numbers,
1862-
sharding,
1885+
sharding=Base.getproperty(sharding, I),
18631886
runtime,
18641887
device,
18651888
client,
18661889
kwargs...,
18671890
)
1891+
if pv !== nv
1892+
same = false
1893+
end
1894+
@inbounds newa[I] = nv
18681895
end
18691896
end
1870-
return nothing
1871-
end
1872-
TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime)
1873-
newa = Array{TT,ndims(RT)}(undef, size(prev))
1874-
seen[prev] = newa
1875-
same = true
1876-
for I in eachindex(prev)
1877-
if isassigned(prev, I)
1878-
pv = prev[I]
1879-
nv = make_tracer(
1880-
seen,
1881-
pv,
1882-
append_path(path, I),
1883-
mode;
1884-
track_numbers,
1885-
sharding=Base.getproperty(sharding, I),
1886-
runtime,
1887-
device,
1888-
client,
1889-
kwargs...,
1890-
)
1891-
if pv !== nv
1892-
same = false
1893-
end
1894-
@inbounds newa[I] = nv
1897+
if same
1898+
seen[prev] = prev
1899+
return prev
18951900
end
1901+
return newa
18961902
end
1897-
if same
1898-
seen[prev] = prev
1899-
return prev
1900-
end
1901-
return newa
19021903
end
1904+
19031905
Base.@nospecializeinfer function make_tracer(
19041906
seen,
19051907
@nospecialize(prev::Sharding.Mesh),

src/Types.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -192,18 +192,20 @@ function ConcretePJRTArray(
192192
return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
193193
end
194194

195-
function ConcretePJRTArray(
196-
data::Memory{T};
197-
client::Union{Nothing,XLA.PJRT.Client}=nothing,
198-
idx::Union{Int,Nothing}=nothing,
199-
device::Union{Nothing,XLA.PJRT.Device}=nothing,
200-
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
201-
) where {T}
202-
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
203-
sharded_data, shardinfo = sharding(theclient, thedevice, data)
204-
shape = size(data)
205-
nsharded = length(sharded_data)
206-
return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
195+
if isdefined(Base, :Memory)
196+
function ConcretePJRTArray(
197+
data::Memory{T};
198+
client::Union{Nothing,XLA.PJRT.Client}=nothing,
199+
idx::Union{Int,Nothing}=nothing,
200+
device::Union{Nothing,XLA.PJRT.Device}=nothing,
201+
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
202+
) where {T}
203+
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
204+
sharded_data, shardinfo = sharding(theclient, thedevice, data)
205+
shape = size(data)
206+
nsharded = length(sharded_data)
207+
return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
208+
end
207209
end
208210

209211
Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
@@ -334,17 +336,19 @@ function ConcreteIFRTArray(
334336
return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding)
335337
end
336338

337-
function ConcreteIFRTArray(
338-
data::Memory{T};
339-
client::Union{Nothing,XLA.IFRT.Client}=nothing,
340-
idx::Union{Int,Nothing}=nothing,
341-
device::Union{Nothing,XLA.IFRT.Device}=nothing,
342-
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
343-
) where {T}
344-
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
345-
sharded_data, shardinfo, padding = sharding(theclient, nothing, data)
346-
shape = size(data)
347-
return ConcreteIFRTArray{T,1,typeof(shardinfo)}(sharded_data, shape, shardinfo)
339+
if isdefined(Base, :Memory)
340+
function ConcreteIFRTArray(
341+
data::Memory{T};
342+
client::Union{Nothing,XLA.IFRT.Client}=nothing,
343+
idx::Union{Int,Nothing}=nothing,
344+
device::Union{Nothing,XLA.IFRT.Device}=nothing,
345+
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
346+
) where {T}
347+
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
348+
sharded_data, shardinfo, padding = sharding(theclient, nothing, data)
349+
shape = size(data)
350+
return ConcreteIFRTArray{T,1,typeof(shardinfo)}(sharded_data, shape, shardinfo)
351+
end
348352
end
349353

350354
# Assemble data from multiple arrays. Needed in distributed setting where each process wont

src/xla/PJRT/Buffer.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,21 @@ function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N}
2121
return Buffer(buffer)
2222
end
2323

24-
25-
function Buffer(client::Client, memory::Memory{T}, device::Device) where {T}
26-
sizear = collect(Int64, reverse(size(memory)))
27-
buffer = GC.@preserve memory sizear begin
28-
@ccall MLIR.API.mlir_c.ArrayFromHostBuffer(
29-
client.client::Ptr{Cvoid},
30-
pointer(memory)::Ptr{T},
31-
XLA.primitive_type(T)::UInt64,
32-
1::Csize_t,
33-
pointer(sizear)::Ptr{Int64},
34-
device.device::Ptr{Cvoid},
35-
)::Ptr{Cvoid}
24+
if isdefined(Base, :Memory)
25+
function Buffer(client::Client, memory::Memory{T}, device::Device) where {T}
26+
sizear = collect(Int64, reverse(size(memory)))
27+
buffer = GC.@preserve memory sizear begin
28+
@ccall MLIR.API.mlir_c.ArrayFromHostBuffer(
29+
client.client::Ptr{Cvoid},
30+
pointer(memory)::Ptr{T},
31+
XLA.primitive_type(T)::UInt64,
32+
1::Csize_t,
33+
pointer(sizear)::Ptr{Int64},
34+
device.device::Ptr{Cvoid},
35+
)::Ptr{Cvoid}
36+
end
37+
return Buffer(buffer)
3638
end
37-
return Buffer(buffer)
3839
end
3940

4041
function Base.similar(a::Buffer)

0 commit comments

Comments
 (0)