|
| 1 | +using OptimizationBase, Optimization |
| 2 | +using OptimizationBase.SciMLBase: solve, OptimizationFunction, OptimizationProblem |
| 3 | +using OptimizationSophia |
| 4 | +using Lux, MLUtils, Random, ComponentArrays |
| 5 | +using SciMLSensitivity |
| 6 | +using Test |
| 7 | +using Zygote |
| 8 | +using OrdinaryDiffEqTsit5 |
| 9 | + |
| 10 | +function dudt_(u, p, t) |
| 11 | + ann(u, p, st)[1] .* u |
| 12 | +end |
| 13 | + |
| 14 | +function newtons_cooling(du, u, p, t) |
| 15 | + temp = u[1] |
| 16 | + k, temp_m = p |
| 17 | + du[1] = dT = -k * (temp - temp_m) |
| 18 | +end |
| 19 | + |
| 20 | +function true_sol(du, u, p, t) |
| 21 | + true_p = [log(2) / 8.0, 100.0] |
| 22 | + newtons_cooling(du, u, true_p, t) |
| 23 | +end |
| 24 | + |
| 25 | +function callback(state, l) #callback function to observe training |
| 26 | + display(l) |
| 27 | + return l < 1e-2 |
| 28 | +end |
| 29 | + |
| 30 | +function predict_adjoint(fullp, time_batch) |
| 31 | + Array(solve(prob, Tsit5(), p = fullp, saveat = time_batch)) |
| 32 | +end |
| 33 | + |
| 34 | +function loss_adjoint(fullp, p) |
| 35 | + (batch, time_batch) = p |
| 36 | + pred = predict_adjoint(fullp, time_batch) |
| 37 | + sum(abs2, batch .- pred) |
| 38 | +end |
| 39 | + |
| 40 | +u0 = Float32[200.0] |
| 41 | +datasize = 30 |
| 42 | +tspan = (0.0f0, 1.5f0) |
| 43 | +rng = Random.default_rng() |
| 44 | + |
| 45 | +ann = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1, tanh)) |
| 46 | +pp, st = Lux.setup(rng, ann) |
| 47 | +pp = ComponentArray(pp) |
| 48 | + |
| 49 | +prob = ODEProblem{false}(dudt_, u0, tspan, pp) |
| 50 | + |
| 51 | +t = range(tspan[1], tspan[2], length = datasize) |
| 52 | +true_prob = ODEProblem(true_sol, u0, tspan) |
| 53 | +ode_data = Array(solve(true_prob, Tsit5(), saveat = t)) |
| 54 | + |
| 55 | +k = 10 |
| 56 | +train_loader = MLUtils.DataLoader((ode_data, t), batchsize = k) |
| 57 | + |
| 58 | +l1 = loss_adjoint(pp, (train_loader.data[1], train_loader.data[2]))[1] |
| 59 | + |
| 60 | +optfun = OptimizationFunction(loss_adjoint, |
| 61 | + OptimizationBase.AutoZygote()) |
| 62 | +optprob = OptimizationProblem(optfun, pp, train_loader) |
| 63 | + |
| 64 | +res1 = solve(optprob, |
| 65 | + OptimizationSophia.Sophia(), callback = callback, |
| 66 | + maxiters = 2000) |
| 67 | +@test 10res1.objective < l1 |
| 68 | + |
| 69 | +# Test Sophia with ComponentArrays + Enzyme (shadow generation fix) |
| 70 | +using ComponentArrays |
| 71 | +x0_comp = ComponentVector(a = 0.0, b = 0.0) |
| 72 | +rosenbrock_comp(x, p = nothing) = (1 - x.a)^2 + 100 * (x.b - x.a^2)^2 |
| 73 | + |
| 74 | +optf_sophia = OptimizationFunction(rosenbrock_comp, AutoEnzyme()) |
| 75 | +prob_sophia = OptimizationProblem(optf_sophia, x0_comp) |
| 76 | +res_sophia = solve(prob_sophia, OptimizationSophia.Sophia(η=0.01, k=5), maxiters = 50) |
| 77 | +@test res_sophia.objective < rosenbrock_comp(x0_comp) # Test optimization progress |
| 78 | +@test res_sophia.retcode == Optimization.SciMLBase.ReturnCode.Success |
0 commit comments