Skip to content

Commit c2c2a70

Browse files
committed
Removed redundant _optimize function.
1 parent 8f90f83 commit c2c2a70

File tree

1 file changed

+22
-37
lines changed

1 file changed

+22
-37
lines changed

ext/TuringOptimExt.jl

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,21 @@ mle = optimize(model, MLE(), NelderMead())
127127
```
128128
"""
129129
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, options::Optim.Options=Optim.Options(); kwargs...)
130-
return _mle_optimize(model, options; kwargs...)
130+
ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
131+
f = Turing.OptimLogDensity(model, ctx)
132+
init_vals = DynamicPPL.getparams(f)
133+
optimizer = Optim.LBFGS()
134+
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
131135
end
132136
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
133-
return _mle_optimize(model, init_vals, options; kwargs...)
137+
optimizer = Optim.LBFGS()
138+
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
134139
end
135140
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...)
136-
return _mle_optimize(model, optimizer, options; kwargs...)
141+
ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
142+
f = Turing.OptimLogDensity(model, ctx)
143+
init_vals = DynamicPPL.getparams(f)
144+
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
137145
end
138146
function Optim.optimize(
139147
model::DynamicPPL.Model,
@@ -173,13 +181,21 @@ map_est = optimize(model, MAP(), NelderMead())
173181
"""
174182

175183
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, options::Optim.Options=Optim.Options(); kwargs...)
176-
return _map_optimize(model, options; kwargs...)
184+
ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
185+
f = Turing.OptimLogDensity(model, ctx)
186+
init_vals = DynamicPPL.getparams(f)
187+
optimizer = Optim.LBFGS()
188+
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
177189
end
178190
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
179-
return _map_optimize(model, init_vals, options; kwargs...)
191+
optimizer = Optim.LBFGS()
192+
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
180193
end
181194
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...)
182-
return _map_optimize(model, optimizer, options; kwargs...)
195+
ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
196+
f = Turing.OptimLogDensity(model, ctx)
197+
init_vals = DynamicPPL.getparams(f)
198+
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
183199
end
184200
function Optim.optimize(
185201
model::DynamicPPL.Model,
@@ -202,37 +218,6 @@ end
202218
203219
Estimate a mode, i.e., compute a MLE or MAP estimate.
204220
"""
205-
function _optimize(
206-
model::DynamicPPL.Model,
207-
f::Turing.OptimLogDensity,
208-
optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
209-
args...;
210-
kwargs...
211-
)
212-
return _optimize(model, f, DynamicPPL.getparams(f), optimizer, args...; kwargs...)
213-
end
214-
215-
function _optimize(
216-
model::DynamicPPL.Model,
217-
f::Turing.OptimLogDensity,
218-
options::Optim.Options=Optim.Options(),
219-
args...;
220-
kwargs...
221-
)
222-
return _optimize(model, f, DynamicPPL.getparams(f), Optim.LBFGS(), args...; kwargs...)
223-
end
224-
225-
function _optimize(
226-
model::DynamicPPL.Model,
227-
f::Turing.OptimLogDensity,
228-
init_vals::AbstractArray=DynamicPPL.getparams(f),
229-
options::Optim.Options=Optim.Options(),
230-
args...;
231-
kwargs...
232-
)
233-
return _optimize(model, f, init_vals, Optim.LBFGS(), options, args...; kwargs...)
234-
end
235-
236221
function _optimize(
237222
model::DynamicPPL.Model,
238223
f::Turing.OptimLogDensity,

0 commit comments

Comments
 (0)