Skip to content

Commit 93266c7

Browse files
authored
fix: empty buffers (#980)
* fix: empty buffers * test: empty buffer check
1 parent 2fc78b2 commit 93266c7

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

src/TracedUtils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ function make_mlir_fn(
290290
seen_results,
291291
result,
292292
(:result,),
293-
concretein ? Reactant.TracedTrack : Reactant.TracedSetPath;
293+
concretein ? Reactant.NoStopTracedTrack : Reactant.TracedSetPath;
294294
runtime,
295295
)
296296

@@ -300,7 +300,7 @@ function make_mlir_fn(
300300
seen_results,
301301
traced_args[i],
302302
concretein ? (:resargs, i) : (),
303-
Reactant.TracedTrack;
303+
Reactant.NoStopTracedTrack;
304304
runtime,
305305
)
306306
end
@@ -311,6 +311,7 @@ function make_mlir_fn(
311311
(args_in_result != :all && has_argidx(v)) && continue
312312
push!(linear_results, v)
313313
end
314+
314315
if args_in_result == :mutated
315316
append!(linear_results, linear_args[mutated_args])
316317
end

test/integration/optimisers.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using Test, Reactant, Optimisers
2+
3+
is_empty_buffer(x::ConcreteIFRTNumber) = x.data === C_NULL
4+
is_empty_buffer(x::ConcretePJRTNumber) = any(x === C_NULL for x in x.data)
5+
6+
@testset "No Empty Buffers #861" begin
7+
opt = Descent(Reactant.to_rarray(0.1; track_numbers=Number))
8+
x = Reactant.to_rarray((ones(10), ones(3), ones(5)))
9+
opt_state = @jit Optimisers.setup(opt, x)
10+
11+
@test !is_empty_buffer(opt_state[1].rule.eta)
12+
@test !is_empty_buffer(opt_state[2].rule.eta)
13+
@test !is_empty_buffer(opt_state[3].rule.eta)
14+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
7474
@safetestset "SpecialFunctions" include("integration/special_functions.jl")
7575
@safetestset "Random" include("integration/random.jl")
7676
@safetestset "Python" include("integration/python.jl")
77+
@safetestset "Optimisers" include("integration/optimisers.jl")
7778
end
7879

7980
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"

0 commit comments

Comments
 (0)