Skip to content

Commit 454822c

Browse files
authored
feat: FillArrays support (#1568)
* feat: FillArrays support
* fix: Ones/Zeros
1 parent 5d2f6d8 commit 454822c

File tree

9 files changed

+170
-3
lines changed

9 files changed

+170
-3
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3232
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3333
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3434
DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
35+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
3536
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
3637
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
3738
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
@@ -54,6 +55,7 @@ ReactantAbstractFFTsExt = "AbstractFFTs"
5455
ReactantArrayInterfaceExt = "ArrayInterface"
5556
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
5657
ReactantDLFP8TypesExt = "DLFP8Types"
58+
ReactantFillArraysExt = "FillArrays"
5759
ReactantFloat8sExt = "Float8s"
5860
ReactantKernelAbstractionsExt = "KernelAbstractions"
5961
ReactantMPIExt = "MPI"
@@ -77,6 +79,7 @@ Downloads = "1.6"
7779
EnumX = "1"
7880
Enzyme = "0.13.49"
7981
EnzymeCore = "0.8.11"
82+
FillArrays = "1.13"
8083
Float8s = "0.1"
8184
Functors = "0.5"
8285
GPUArraysCore = "0.2"
@@ -103,9 +106,9 @@ Scratch = "1.2"
103106
Sockets = "1.10"
104107
SpecialFunctions = "2.4"
105108
Statistics = "1.10"
106-
unzip_jll = "6"
107109
YaoBlocks = "0.13, 0.14"
108110
julia = "1.10"
111+
unzip_jll = "6"
109112

110113
[extras]
111114
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/ReactantFillArraysExt.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
module ReactantFillArraysExt
2+
3+
using Reactant: Reactant, TracedUtils, TracedRNumber, Ops, Sharding, unwrapped_eltype
4+
using ReactantCore: ReactantCore
5+
using FillArrays: FillArrays, AbstractFill, Fill, Ones, Zeros, OneElement
6+
using GPUArraysCore: @allowscalar
7+
8+
# Tracing
9+
Reactant._parent_type(T::Type{<:AbstractFill}) = T
10+
Reactant._parent_type(T::Type{<:OneElement}) = T
11+
12+
for AT in (Fill, Ones, Zeros)
13+
@eval Base.@nospecializeinfer function Reactant.traced_type_inner(
14+
@nospecialize(FA::Type{$(AT){T,N,Axes}}),
15+
seen,
16+
mode::Reactant.TraceMode,
17+
@nospecialize(track_numbers::Type),
18+
@nospecialize(sharding),
19+
@nospecialize(runtime)
20+
) where {T,N,Axes}
21+
# T will be a number so we need to trace it
22+
return $(AT){
23+
Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime),N,Axes
24+
}
25+
end
26+
end
27+
28+
Base.@nospecializeinfer function Reactant.make_tracer(
29+
seen, @nospecialize(prev::Fill{T,N,Axes}), @nospecialize(path), mode; kwargs...
30+
) where {T,N,Axes}
31+
return Fill(
32+
Reactant.make_tracer(
33+
seen, prev.value, (path..., 1), mode; kwargs..., track_numbers=Number
34+
),
35+
prev.axes,
36+
)
37+
end
38+
39+
Base.@nospecializeinfer function Reactant.make_tracer(
40+
seen,
41+
@nospecialize(prev::Ones{T,N,Axes}),
42+
@nospecialize(path),
43+
mode;
44+
@nospecialize(sharding = Sharding.NoSharding()),
45+
@nospecialize(runtime = nothing),
46+
kwargs...,
47+
) where {T,N,Axes}
48+
return Ones(
49+
Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime), prev.axes
50+
)
51+
end
52+
53+
Base.@nospecializeinfer function Reactant.make_tracer(
54+
seen,
55+
@nospecialize(prev::Zeros{T,N,Axes}),
56+
@nospecialize(path),
57+
mode;
58+
@nospecialize(sharding = Sharding.NoSharding()),
59+
@nospecialize(runtime = nothing),
60+
kwargs...,
61+
) where {T,N,Axes}
62+
return Zeros(
63+
Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime), prev.axes
64+
)
65+
end
66+
67+
Base.@nospecializeinfer function Reactant.traced_type_inner(
68+
@nospecialize(FA::Type{OneElement{T,N,I,A}}),
69+
seen,
70+
mode::Reactant.TraceMode,
71+
@nospecialize(track_numbers::Type),
72+
@nospecialize(sharding),
73+
@nospecialize(runtime)
74+
) where {T,N,I,A}
75+
# T will be a number so we need to trace it
76+
return OneElement{
77+
Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime),N,I,A
78+
}
79+
end
80+
81+
Base.@nospecializeinfer function Reactant.make_tracer(
82+
seen, @nospecialize(prev::OneElement{T,N,I,A}), @nospecialize(path), mode; kwargs...
83+
) where {T,N,I,A}
84+
return OneElement(
85+
Reactant.make_tracer(
86+
seen, prev.val, (path..., 1), mode; kwargs..., track_numbers=Number
87+
),
88+
prev.ind,
89+
prev.axes,
90+
)
91+
end
92+
93+
# Materialize into a dense array
94+
function ReactantCore.materialize_traced_array(x::Fill{T}) where {T}
95+
return TracedUtils.broadcast_to_size(
96+
TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(T)}, x.value), size(x)
97+
)
98+
end
99+
100+
function ReactantCore.materialize_traced_array(x::Ones{T}) where {T}
101+
return TracedUtils.broadcast_to_size(unwrapped_eltype(T)(1), size(x))
102+
end
103+
104+
function ReactantCore.materialize_traced_array(x::Zeros{T}) where {T}
105+
return TracedUtils.broadcast_to_size(unwrapped_eltype(T)(0), size(x))
106+
end
107+
108+
function ReactantCore.materialize_traced_array(x::OneElement{T}) where {T}
109+
y = TracedUtils.broadcast_to_size(unwrapped_eltype(T)(0), size(x))
110+
@allowscalar setindex!(y, x.val, x.ind...)
111+
return y
112+
end
113+
114+
# some functions to avoid bad performance
115+
for AT in (Fill, Ones, Zeros, OneElement)
116+
@eval function Base.similar(x::$AT{<:TracedRNumber}, ::Type{T}, dims::Dims) where {T}
117+
return TracedUtils.broadcast_to_size(unwrapped_eltype(T)(0), dims)
118+
end
119+
end
120+
121+
end

src/Compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ function create_result(
288288
sym = Symbol("result", var_idx[])
289289
var_idx[] += 1
290290

291-
@assert haskey(result_stores, path)
291+
@assert haskey(result_stores, path) "Expected $(path) in $(keys(result_stores))"
292292
restore = result_stores[path]
293293
delete!(result_stores, path)
294294
if path_to_shard_info !== nothing && haskey(path_to_shard_info, path)

src/ConcreteRArray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ for T in Base.uniontypes(ReactantPrimitive)
8383
end
8484

8585
function Base.convert(::Type{T}, x::AbstractConcreteNumber) where {T<:Number}
86+
T == typeof(x) && return x
8687
return convert(T, to_number(x))
8788
end
8889

src/Reactant.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,20 @@ function ancestor(T::Type{<:AbstractArray})
4242
p_T == T && return T
4343
return ancestor(p_T)
4444
end
45+
if applicable(_parent_type, T)
46+
p_T = _parent_type(T)
47+
p_T == T && return T
48+
return ancestor(p_T)
49+
end
4550
@warn "`Adapt.parent_type` is not implemented for $(T). Assuming $T isn't a wrapped \
4651
array." maxlog = 1
4752
return T
4853
end
4954

55+
# A lot of packages don't define `Adapt.parent_type`. We use `_parent_type` as a way to
56+
# define the parent type of an array without type-piracy.
57+
function _parent_type end
58+
5059
include("accelerators/Accelerators.jl")
5160

5261
using .Accelerators.TPU: has_tpu

src/xla/PJRT/Buffer.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ function XLA.buffer_on_cpu(buffer::Buffer)
167167
end
168168

169169
function XLA.to_host(buffer::Buffer, data, sharding)
170-
GC.@preserve buffer begin
170+
@assert data !== C_NULL
171+
@assert buffer.buffer !== C_NULL
172+
GC.@preserve buffer data begin
171173
@ccall MLIR.API.mlir_c.BufferToHost(
172174
buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid}
173175
)::Cvoid

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
66
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
77
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
88
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
9+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
910
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
1011
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1112
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"

test/integration/fillarrays.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using Reactant, Test, FillArrays
2+
3+
fn(x, y) = (2 .* x .- 3) * y'
4+
5+
@testset "Fill" begin
6+
x = Fill(2.0f0, 4, 5)
7+
rx = Reactant.to_rarray(x)
8+
9+
@test @jit(fn(rx, rx)) fn(x, x)
10+
11+
@testset "Ones" begin
12+
y = Ones(Float32, 4, 5)
13+
ry = Reactant.to_rarray(y)
14+
@test @jit(fn(rx, ry)) fn(x, y)
15+
end
16+
17+
@testset "Zeros" begin
18+
y = Zeros(Float32, 4, 5)
19+
ry = Reactant.to_rarray(y)
20+
@test @jit(fn(rx, ry)) fn(x, y)
21+
end
22+
end
23+
24+
@testset "OneElement" begin
25+
x = OneElement(3.4f0, (3, 4), (32, 32))
26+
rx = Reactant.to_rarray(x)
27+
28+
@test @jit(fn(rx, rx)) fn(x, x)
29+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5050
@safetestset "Random" include("integration/random.jl")
5151
@safetestset "Python" include("integration/python.jl")
5252
@safetestset "Optimisers" include("integration/optimisers.jl")
53+
@safetestset "FillArrays" include("integration/fillarrays.jl")
5354
end
5455

5556
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"

0 commit comments

Comments
 (0)