using Reactant, Lux, Random, Statistics, Enzyme, Functors, OneHotArrays, Test

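# Plain cross-entropy, -sum(y .* log.(ŷ)), written out by hand. It assumes every
# entry of ŷ is strictly positive (the model below ends in softmax); a zero
# probability would turn 0 * log(0) into NaN, which is why the xlogy-aware
# CrossEntropyLoss is avoided here.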
function crossentropy(ŷ, y)
    logŷ = log.(ŷ)
    result = y .* logŷ
    return -sum(result)
end

function loss_function(model, x, y, ps, st)
    y_hat, _ = model(x, ps, st)
    # return CrossEntropyLoss()(y_hat, y)  # <-- needs xlogx/xlogy handling from LuxOps
    return crossentropy(y_hat, y)
end

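# Reverse-mode gradient with Enzyme. ReverseWithPrimal returns the derivatives
# together with the primal loss value; Duplicated(ps, dps) accumulates the
# parameter gradient into the zero-initialized shadow dps, and Const arguments
# are not differentiated.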
function gradient_loss_function(model, x, y, ps, st)
    dps = Enzyme.make_zero(ps)
    _, res = Enzyme.autodiff(
        ReverseWithPrimal,
        loss_function,
        Active,
        Const(model),
        Const(x),
        Const(y),
        Duplicated(ps, dps),
        Const(st),
    )
    return res, dps
end

@testset "Lux.jl Integration" begin
    # Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
    noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
    truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool}

    # Define our model, a multi-layer perceptron with one hidden layer of size 3:
    model = Lux.Chain(
        Lux.Dense(2 => 3, tanh), # activation function inside layer
        Lux.BatchNorm(3, sigmoid),
        Lux.Dense(3 => 2),
        softmax,
    )
    ps, st = Lux.setup(Xoshiro(123), model)

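    # Reference forward pass on plain Julia arrays; Lux.testmode(st) puts
    # BatchNorm into inference mode so it uses its running statistics.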
    origout, _ = model(noisy, ps, Lux.testmode(st))

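    # Convert the model, parameters, states, and input to Reactant's
    # ConcreteRArray representation so they can be traced and compiled.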
    cmodel = Reactant.to_rarray(model)
    cps = Reactant.to_rarray(ps)
    cst = Reactant.to_rarray(Lux.testmode(st))
    cst2 = Reactant.to_rarray(st)
    cnoisy = Reactant.ConcreteRArray(noisy)

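    # Compile the forward pass: Reactant traces the closure against these
    # concrete example arguments and lowers it to XLA. The compiled f must be
    # called with arguments of matching types and shapes.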
    f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cst))

    comp = f(cmodel, cnoisy, cps, cst)

    @test comp ≈ origout atol = 1e-5 rtol = 1e-2

    target = onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix

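    # Materialize the lazy OneHotMatrix as a dense Float32 array before the
    # transfer; the direct to_rarray conversion is kept below as a commented
    # alternative.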
    ctarget = Reactant.ConcreteRArray(Array{Float32}(target))
    # ctarget = Reactant.to_rarray(target)

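    # Ground truth: the gradient computed eagerly with Enzyme on plain arrays.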
    res, dps = gradient_loss_function(model, noisy, target, ps, st)

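    # The same gradient computation, traced and compiled through Reactant; the
    # training-mode state cst2 matches the eager call above.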
    compiled_gradient = Reactant.compile(
        gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst2)
    )

    res_reactant, dps_reactant = compiled_gradient(cmodel, cnoisy, ctarget, cps, cst2)

    @test res ≈ res_reactant atol = 1e-5 rtol = 1e-2
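    # fleaves (from Functors) flattens the nested gradient trees so the leaf
    # arrays can be compared pairwise.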
    for (dps1, dps2) in zip(fleaves(dps), fleaves(dps_reactant))
        @test dps1 ≈ dps2 atol = 1e-5 rtol = 1e-2
    end
end