Skip to content

Commit 658a3d7

Browse files
committed
support for N dim arrays, tests
1 parent 702ac1d commit 658a3d7

File tree

4 files changed

+35
-2
lines changed

4 files changed

+35
-2
lines changed

ext/ReactantFixedSizeArraysExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ Base.@nospecializeinfer function Reactant.make_tracer(
2424
mode;
2525
kwargs...,
2626
) where {T,N}
27-
return Reactant.make_tracer(
27+
shape = size(prev)
28+
return reshape(Reactant.make_tracer(
2829
seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number
29-
)
30+
), shape)
3031
end
3132

3233
end
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
using Reactant, Test, FixedSizeArrays
3+
4+
fn(x, y) = (2 .* x .- 3) * y'
5+
6+
@testset "FixedSizeArrays" begin
7+
@testset "1D" begin
8+
x = FixedSizeArray(fill(3.0f0, 100))
9+
rx = Reactant.to_rarray(x)
10+
@test @jit(fn(rx, rx)) fn(x, x)
11+
end
12+
@testset "2D" begin
13+
x = FixedSizeArray(fill(3.0f0, (4,5)))
14+
rx = Reactant.to_rarray(x)
15+
@test @jit(fn(rx, rx)) fn(x, x)
16+
end
17+
end

test/memory.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using Reactant, Test
2+
3+
fn(x,y) = sin.(x) .+ cos.(y)
4+
5+
@testset "Memory test" begin
6+
x = Memory{Float32}(fill(2.0f0, 10))
7+
x_ra = Reactant.to_rarray(x)
8+
9+
@test @jit(fn(x_ra,x_ra)) fn(x,x)
10+
end
11+

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
3838
@safetestset "Config" include("config.jl")
3939
@safetestset "Batching" include("batching.jl")
4040
@safetestset "QA" include("qa.jl")
41+
if isdefined(Base, :Memory)
42+
@safetestset "Memory" include("memory.jl")
43+
end
4144
end
4245

4346
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
@@ -52,6 +55,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5255
@safetestset "Python" include("integration/python.jl")
5356
@safetestset "Optimisers" include("integration/optimisers.jl")
5457
@safetestset "FillArrays" include("integration/fillarrays.jl")
58+
@safetestset "FixedSizeArrays" include("integration/fixedsizearrays.jl")
5559
end
5660

5761
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"

0 commit comments

Comments
 (0)