@@ -5,6 +5,11 @@ function loss_function(model, x, y, ps, st)
5
5
return CrossEntropyLoss ()(y_hat, y)
6
6
end
7
7
8
# Label-free loss: the sum of squared model outputs. Used by the gradient
# tests below that differentiate w.r.t. `ps` without needing targets.
function loss_function(model, x, ps, st)
    predictions, _ = model(x, ps, st)
    return sum(abs2, predictions)
end
12
+
8
13
function gradient_loss_function (model, x, y, ps, st)
9
14
dps = Enzyme. make_zero (ps)
10
15
_, res = Enzyme. autodiff (
@@ -20,6 +25,20 @@ function gradient_loss_function(model, x, y, ps, st)
20
25
return res, dps
21
26
end
22
27
28
# Reverse-mode gradient of the label-free `loss_function` with respect to the
# parameters `ps`, computed with Enzyme. Returns `(primal_and_loss, grads)`,
# matching the labelled `gradient_loss_function` method above.
function gradient_loss_function(model, x, ps, st)
    grads = Enzyme.make_zero(ps)
    # Runtime activity is required here because activity of some values cannot
    # be determined statically in this call graph.
    mode = set_runtime_activity(ReverseWithPrimal)
    _, primal = Enzyme.autodiff(
        mode,
        loss_function,
        Active,
        Const(model),
        Const(x),
        Duplicated(ps, grads),
        Const(st),
    )
    return primal, grads
end
41
+
23
42
@testset " Lux.jl Integration" begin
24
43
# Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
25
44
noisy = rand (Float32, 2 , 1000 ) # 2×1000 Matrix{Float32}
67
86
@test_skip dps1 ≈ dps2 atol = 1e-3 rtol = 1e-2
68
87
end
69
88
end
89
+
90
# Integration test: differentiate a Recurrence/RNNCell model through
# Reactant's compiler. Fix: the original drew parameters and inputs from an
# unseeded RNG (`Random.default_rng()` / bare `rand`), so runs were not
# reproducible; a seeded RNG is threaded through setup and data generation.
@testset "RNN Integration" begin
    using Reactant, Lux, Enzyme, Random

    # Seeded RNG keeps parameters and inputs identical across test runs.
    rng = Random.Xoshiro(0)

    model = Recurrence(RNNCell(4 => 4); ordering=BatchLastIndex())
    ps, st = Reactant.to_rarray(Lux.setup(rng, model))

    # 4×16×12 input; batch is the last dimension per `BatchLastIndex`
    # (assumed (features, timesteps, batch) — confirm against Lux docs).
    x = Reactant.to_rarray(rand(rng, Float32, 4, 16, 12))

    # This test requires running optimizations between the enzyme autodiff passes
    res, ∂ps = @jit gradient_loss_function(model, x, ps, st)
    @test res isa Reactant.ConcreteRNumber
end
0 commit comments