Skip to content

Commit 10ffcc8

Browse files
committed
FixedSizeArray -> ConcreteArray
1 parent c4dd54e commit 10ffcc8

File tree

4 files changed

+139
-10
lines changed

4 files changed

+139
-10
lines changed

ext/ReactantFixedSizeArraysExt.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,25 @@ using Reactant: TracedRArray, TracedRNumber, Ops
66
using ReactantCore: ReactantCore
77

88
# Diff hunk for the FixedSizeArrays extension's `traced_type_inner` method.
# Lines following a `N-` marker are the removed (pre-commit) text; lines following a `N+`
# marker are the added (post-commit) text; unmarked pairs are unchanged context.
#
# Post-commit behaviour: for the type `FixedSizeArrayDefault{T, N}` the method returns
# `FixedSizeArrayDefault{Reactant.TracedRNumber{T}, N}`, i.e. only the element type is
# wrapped. `seen`, `mode`, `track_numbers`, `sharding` and `runtime` are accepted but not
# read in the body.
# Pre-commit behaviour: the method matched `FixedSizeArray{T, N, Memory{I}}` and wrapped
# both `T` and the `Memory` parameter `I` in `TracedRNumber`.
function Reactant.traced_type_inner(
9-
@nospecialize(_::Type{FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}}),
9+
@nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T, N}}),
1010
seen,
1111
@nospecialize(mode::Reactant.TraceMode),
1212
@nospecialize(track_numbers::Type),
1313
@nospecialize(sharding),
1414
@nospecialize(runtime)
15-
) where {T, N, I}
15+
) where {T, N}
1616
T2 = Reactant.TracedRNumber{T}
17-
I2 = Reactant.TracedRNumber{I}
18-
return FixedSizeArrays.FixedSizeArray{T2, N, Memory{I2}}
17+
return FixedSizeArrays.FixedSizeArrayDefault{T2, N}
1918
end
2019

2120
# Diff hunk for the FixedSizeArrays extension's `make_tracer` method.
# Lines following a `N-` marker are removed (pre-commit); lines following a `N+` marker are
# added (post-commit); the rest is unchanged context.
#
# Both versions trace `parent(prev)` with `1` appended to `path`, forwarding `kwargs` and
# forcing `track_numbers=Number` (the explicit keyword comes after `kwargs...`, so it
# overrides any `track_numbers` the caller supplied).
# Pre-commit: the traced parent was wrapped back into a `FixedSizeArrays.FixedSizeArray`.
# Post-commit: the result of tracing the parent is returned directly, without re-wrapping.
Base.@nospecializeinfer function Reactant.make_tracer(
2221
seen,
23-
@nospecialize(prev::FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}),
22+
@nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T, N}),
2423
@nospecialize(path),
2524
mode; kwargs...
26-
) where {T, N, I}
27-
return FixedSizeArrays.FixedSizeArray(
28-
Reactant.make_tracer(
29-
seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number
30-
)
25+
) where {T, N}
26+
return Reactant.make_tracer(
27+
seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number
3128
)
3229
end
3330

src/Tracing.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1811,6 +1811,95 @@ Base.@nospecializeinfer function make_tracer(
18111811
return res
18121812
end
18131813

1814+
1815+
# Trace a `Memory` buffer.
#
# Behaviour by case:
# * `prev` already in `seen` (and `mode != NoStopTracedTrack`): reuse the cached entry; in
#   `TracedToTypes` mode the cached entry is pushed onto `path` and `nothing` is returned.
# * element type is a `ReactantPrimitive`:
#     - `ArrayToConcrete`: build a `ConcretePJRTArray` / `ConcreteIFRTArray` according to
#       `runtime`, cache it in `seen`, and return it; any other runtime is an error.
#     - `TracedToTypes`: push a copy of `prev` onto `path`, register a `VisitedObject` in
#       `seen`, return `nothing`.
#     - any other mode falls through to the generic element-wise path below.
# * non-primitive elements in `TracedToTypes`: push the memory type onto `path`, recurse
#   into every assigned element, return `nothing`.
# * generic path: recurse into every assigned element and collect the results in a new
#   `Array`; if every traced element is identical (`===`) to its source, `prev` itself is
#   returned (and cached) instead of the new array.
Base.@nospecializeinfer function make_tracer(
    seen,
    @nospecialize(prev::Memory),
    @nospecialize(path),
    mode;
    @nospecialize(track_numbers::Type = Union{}),
    @nospecialize(sharding = Sharding.NoSharding()),
    @nospecialize(runtime = nothing),
    @nospecialize(device = nothing),
    @nospecialize(client = nothing),
    kwargs...,
)
    memtype = Core.Typeof(prev)
    elT = eltype(memtype)

    # XXX: If someone wants to shard the same array with different shardings, we need to
    #      somehow handle this correctly... Right now we just use the first sharding.
    if mode != NoStopTracedTrack && haskey(seen, prev)
        cached = seen[prev]
        mode == TracedToTypes || return cached
        push!(path, cached)
        return nothing
    end

    if elT <: ReactantPrimitive
        if mode == ArrayToConcrete
            if runtime isa Val{:PJRT}
                concrete = ConcretePJRTArray(prev; sharding, device, client)
                seen[prev] = concrete
                return concrete
            end
            if runtime isa Val{:IFRT}
                concrete = ConcreteIFRTArray(prev; sharding, device, client)
                seen[prev] = concrete
                return concrete
            end
            error("Unsupported runtime $runtime")
        elseif mode == TracedToTypes
            # Original array can get mutated so we store a copy:
            push!(path, copy(prev))
            seen[prev] = VisitedObject(length(seen) + 1)
            return nothing
        end
    elseif mode == TracedToTypes
        push!(path, memtype)
        for idx in eachindex(prev)
            isassigned(prev, idx) || continue
            make_tracer(
                seen,
                prev[idx],
                path,
                mode;
                track_numbers,
                sharding,
                runtime,
                device,
                client,
                kwargs...,
            )
        end
        return nothing
    end

    TT = traced_type(elT, Val(mode), track_numbers, sharding, runtime)
    traced = Array{TT,ndims(memtype)}(undef, size(prev))
    # Register the result before recursing so nested references to `prev` find it in `seen`.
    seen[prev] = traced
    unchanged = true
    for idx in eachindex(prev)
        isassigned(prev, idx) || continue
        source_elem = prev[idx]
        traced_elem = make_tracer(
            seen,
            source_elem,
            append_path(path, idx),
            mode;
            track_numbers,
            sharding=Base.getproperty(sharding, idx),
            runtime,
            device,
            client,
            kwargs...,
        )
        unchanged = unchanged && (source_elem === traced_elem)
        @inbounds traced[idx] = traced_elem
    end

    unchanged || return traced
    # Nothing was replaced: hand back the original object and cache it as its own trace.
    seen[prev] = prev
    return prev
end
18141903
Base.@nospecializeinfer function make_tracer(
18151904
seen,
18161905
@nospecialize(prev::Sharding.Mesh),

src/Types.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,20 @@ function ConcretePJRTArray(
192192
return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
193193
end
194194

195+
"""
    ConcretePJRTArray(data::Memory{T}; client=nothing, idx=nothing, device=nothing,
                      sharding=Sharding.NoSharding())

Construct a one-dimensional `ConcretePJRTArray` from the `Memory` buffer `data`.

The client and device are resolved via `_select_client_and_device`, `sharding` is then
called with them and `data` to obtain the sharded buffers and the shard info, and the
result records `size(data)` as its shape and the number of sharded buffers as a type
parameter.
"""
function ConcretePJRTArray(
    data::Memory{T};
    client::Union{Nothing,XLA.PJRT.Client}=nothing,
    idx::Union{Int,Nothing}=nothing,
    device::Union{Nothing,XLA.PJRT.Device}=nothing,
    sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
) where {T}
    resolved_client, resolved_device = _select_client_and_device(
        client, idx, device, sharding
    )
    buffers, info = sharding(resolved_client, resolved_device, data)
    return ConcretePJRTArray{T,1,length(buffers),typeof(info)}(buffers, size(data), info)
end
208+
195209
# Apply `wait` to every element of `x.data`.
Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
196210
# Delegate the client lookup to the underlying `x.data`.
XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data)
197211
function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber})
@@ -320,6 +334,19 @@ function ConcreteIFRTArray(
320334
return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding)
321335
end
322336

337+
"""
    ConcreteIFRTArray(data::Memory{T}; client=nothing, idx=nothing, device=nothing,
                      sharding=Sharding.NoSharding())

Construct a one-dimensional `ConcreteIFRTArray` from the `Memory` buffer `data`.

The client is resolved via `_select_client_and_device`; `sharding` is then called with that
client, `nothing` in the device position, and `data`, and yields the sharded data, the shard
info and the padding. All three, together with `size(data)`, are forwarded to the
`ConcreteIFRTArray` constructor, in the same four-argument form the `Array` method uses.
"""
function ConcreteIFRTArray(
    data::Memory{T};
    client::Union{Nothing,XLA.IFRT.Client}=nothing,
    idx::Union{Int,Nothing}=nothing,
    device::Union{Nothing,XLA.IFRT.Device}=nothing,
    sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
) where {T}
    # Only the client is used below; the call is kept as-is so `idx`/`device` are still
    # passed through the same selection helper as before.
    theclient, _ = _select_client_and_device(client, idx, device, sharding)
    sharded_data, shardinfo, padding = sharding(theclient, nothing, data)
    shape = size(data)
    # Forward `padding`: it was previously computed and then silently dropped.
    return ConcreteIFRTArray{T,1,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding)
end
349+
323350
# Assemble data from multiple arrays. Needed in distributed setting where each process won't
324351
# have enough host memory to hold all the arrays. We assume that the data is only provided
325352
# for all of the addressable devices.

src/xla/PJRT/Buffer.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,22 @@ function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N}
2121
return Buffer(buffer)
2222
end
2323

24+
25+
"""
    Buffer(client::Client, memory::Memory{T}, device::Device)

Create a `Buffer` from the host data in `memory` by calling `ArrayFromHostBuffer` in the
`mlir_c` library with the client handle, a pointer to the data, the XLA primitive type for
`T`, a rank of `1`, the dimension sizes, and the device handle.
"""
function Buffer(client::Client, memory::Memory{T}, device::Device) where {T}
    dims = Int64[d for d in reverse(size(memory))]
    # Keep `memory` and `dims` rooted while their raw pointers are in use by the C call.
    handle = GC.@preserve memory dims @ccall MLIR.API.mlir_c.ArrayFromHostBuffer(
        client.client::Ptr{Cvoid},
        pointer(memory)::Ptr{T},
        XLA.primitive_type(T)::UInt64,
        1::Csize_t,
        pointer(dims)::Ptr{Int64},
        device.device::Ptr{Cvoid},
    )::Ptr{Cvoid}
    return Buffer(handle)
end
39+
2440
function Base.similar(a::Buffer)
2541
buffer = GC.@preserve a begin
2642
@ccall MLIR.API.mlir_c.UninitPJRTBuffer(

0 commit comments

Comments
 (0)