
Commit 2d1e942

Merge pull request #984 from ParamThakkar123/refactor
Added a new Subpackage for Augmented Lagrangian
2 parents a474c18 + fc48f5e
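
In user code the practical effect of the refactor is simply where the solver is loaded from: AugLag now lives in the OptimizationAuglag subpackage instead of inside Optimization.jl itself. A minimal sketch of constructing it after this change, based on the diffs below, with Adam from OptimizationOptimisers as the inner optimizer as in the new test:

using OptimizationAuglag        # previously: the solver was defined inside Optimization.jl
using OptimizationOptimisers    # supplies Adam(), used as the inner optimizer in the tests

# AugLag is a @kwdef struct; pass the inner (sub-problem) optimizer by keyword.
opt = OptimizationAuglag.AugLag(; inner = OptimizationOptimisers.Adam())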

File tree

6 files changed: +73, -10 lines


.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@ jobs:
      matrix:
        group:
          - Core
+         - OptimizationAuglag
          - OptimizationBBO
          - OptimizationCMAEvolutionStrategy
          - OptimizationEvolutionary

lib/OptimizationAuglag/Project.toml

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+name = "OptimizationAuglag"
+uuid = "2ea93f80-9333-43a1-a68d-1f53b957a421"
+authors = ["paramthakkar123 <[email protected]>"]
+version = "0.1.0"
+
+[deps]
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
+OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
+OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
+
+[extras]
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[compat]
+ForwardDiff = "1.0.1"
+MLUtils = "0.4.8"
+Optimization = "4.4.0"
+OptimizationBase = "2.10.0"
+OptimizationOptimisers = "0.3.8"
+Test = "1.10.0"
+
+[targets]
+test = ["Test"]
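
Since the subpackage sits inside the repository under lib/, one way to try it out from a local checkout is Pkg's develop-by-path. This is only a sketch of one possible workflow, not a documented setup step, and it assumes the commands are run from the repository root:

using Pkg

# Make the in-repo subpackage available to the active environment,
# then run the test target declared in its Project.toml above.
Pkg.develop(path = "lib/OptimizationAuglag")
Pkg.test("OptimizationAuglag")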

src/auglag.jl renamed to lib/OptimizationAuglag/src/OptimizationAuglag.jl

Lines changed: 11 additions & 3 deletions

@@ -1,3 +1,9 @@
+module OptimizationAuglag
+
+using Optimization
+using OptimizationBase.SciMLBase: OptimizationProblem, OptimizationFunction, OptimizationStats
+using OptimizationBase.LinearAlgebra: norm
+
 @kwdef struct AugLag
     inner::Any
     τ = 0.5
@@ -20,7 +26,7 @@ SciMLBase.requiresgradient(::AugLag) = true
 SciMLBase.allowsconstraints(::AugLag) = true
 SciMLBase.requiresconsjac(::AugLag) = true

-function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::AugLag;
+function __map_optimizer_args(cache::OptimizationBase.OptimizationCache, opt::AugLag;
        callback = nothing,
        maxiters::Union{Number, Nothing} = nothing,
        maxtime::Union{Number, Nothing} = nothing,
@@ -110,7 +116,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
        cache.f.cons(cons_tmp, θ)
        cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
        cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
-       opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = p)
+       opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
        if cache.callback(opt_state, x...)
            error("Optimization halted by callback.")
        end
@@ -176,10 +182,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
                break
            end
        end
-       stats = Optimization.OptimizationStats(; iterations = maxiters,
+       stats = OptimizationStats(; iterations = maxiters,
            time = 0.0, fevals = maxiters, gevals = maxiters)
        return SciMLBase.build_solution(
            cache, cache.opt, θ, x,
            stats = stats, retcode = opt_ret)
    end
 end
+
+end
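
The hunks above only touch module wiring (imports, the OptimizationBase cache type, and the stats constructor); the algorithm itself is the usual augmented-Lagrangian outer loop. Below is a schematic sketch of that scheme for equality constraints, not the package's implementation: inner_minimize is a placeholder for a run of the `inner` optimizer, and only the default τ = 0.5 is taken from the struct above.

using LinearAlgebra: dot, norm

# Schematic augmented-Lagrangian outer loop for  min f(x)  s.t.  h(x) = 0.
# `inner_minimize(L, x)` stands in for an unconstrained run of the inner optimizer.
function auglag_sketch(f, h, x0, inner_minimize; ρ = 10.0, τ = 0.5, outer_iters = 20)
    x = copy(x0)
    λ = zeros(length(h(x0)))
    prev_violation = Inf
    for _ in 1:outer_iters
        # Inner subproblem: minimize the augmented Lagrangian at fixed λ and ρ.
        Lρ = z -> f(z) + dot(λ, h(z)) + (ρ / 2) * sum(abs2, h(z))
        x = inner_minimize(Lρ, x)
        violation = norm(h(x))
        λ .+= ρ .* h(x)                      # first-order multiplier update
        # If the violation did not shrink by at least the factor τ, raise the penalty.
        violation > τ * prev_violation && (ρ *= 10)
        prev_violation = violation
    end
    return x
end
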
lib/OptimizationAuglag/test/runtests.jl

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+using OptimizationBase
+using MLUtils
+using OptimizationOptimisers
+using OptimizationAuglag
+using ForwardDiff
+using OptimizationBase: OptimizationCache
+using OptimizationBase.SciMLBase: OptimizationFunction
+using Test
+
+@testset "OptimizationAuglag.jl" begin
+    x0 = (-pi):0.001:pi
+    y0 = sin.(x0)
+    data = MLUtils.DataLoader((x0, y0), batchsize = 126)
+
+    function loss(coeffs, data)
+        ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])]
+        return sum(abs2, ypred .- data[2])
+    end
+
+    function cons1(res, coeffs, p = nothing)
+        res[1] = coeffs[1] * coeffs[5] - 1
+        return nothing
+    end
+
+    optf = OptimizationFunction(loss, OptimizationBase.AutoSparseForwardDiff(), cons = cons1)
+    callback = (st, l) -> (@show l; return false)
+
+    initpars = rand(5)
+    l0 = optf(initpars, (x0, y0))
+
+    prob = OptimizationProblem(optf, initpars, data, lcons = [-Inf], ucons = [1],
+        lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
+    opt = solve(
+        prob, OptimizationAuglag.AugLag(; inner = Adam()), maxiters = 10000, callback = callback)
+    @test opt.objective < l0
+end

src/Optimization.jl

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@ include("utils.jl")
 include("state.jl")
 include("lbfgsb.jl")
 include("sophia.jl")
-include("auglag.jl")

 export solve

test/native.jl

Lines changed: 0 additions & 6 deletions
@@ -51,12 +51,6 @@ prob = OptimizationProblem(optf, initpars, (x0, y0), lcons = [-Inf], ucons = [0.
 opt1 = solve(prob, Optimization.LBFGS(), maxiters = 1000, callback = callback)
 @test opt1.objective < l0

-prob = OptimizationProblem(optf, initpars, data, lcons = [-Inf], ucons = [1],
-    lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
-opt = solve(
-    prob, Optimization.AugLag(; inner = Adam()), maxiters = 10000, callback = callback)
-@test opt.objective < l0
-
 optf1 = OptimizationFunction(loss, AutoSparseForwardDiff())
 prob1 = OptimizationProblem(optf1, rand(5), data)
 sol1 = solve(prob1, OptimizationOptimisers.Adam(), maxiters = 1000, callback = callback)

0 commit comments
