Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Expand Down Expand Up @@ -85,7 +86,7 @@ LeastSquaresOptim = "0.8.5"
LineSearch = "0.1.4"
LineSearches = "7.3"
LinearAlgebra = "1.10"
LinearSolve = "3.46"
LinearSolve = "3.48"
MINPACK = "1.2"
MPI = "0.20.22"
NLSolvers = "0.5"
Expand All @@ -106,9 +107,9 @@ Random = "1.10"
ReTestItems = "1.24"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SciMLLogging = "1.3"
SIAMFANLEquations = "1.0.1"
SciMLBase = "2.127"
SciMLLogging = "1.3"
SimpleNonlinearSolve = "2.11"
SparseArrays = "1.10"
SparseConnectivityTracer = "1"
Expand Down Expand Up @@ -146,9 +147,9 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,8 @@ function set_lincache_A!(lincache, new_A)
return
end

function LinearSolve.update_tolerances!(cache::LinearSolveJLCache; kwargs...)
LinearSolve.update_tolerances!(cache.lincache; kwargs...)
end

end
2 changes: 2 additions & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ export RelTerminationMode, AbsTerminationMode,
export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogleg,
GeodesicAcceleration

export EisenstatWalkerForcing2

export NonlinearSolvePolyAlgorithm

export NonlinearVerbosity
Expand Down
18 changes: 12 additions & 6 deletions lib/NonlinearSolveBase/src/verbosity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ verbose = NonlinearVerbosity(
termination_condition
# Numerical
threshold_state
forcing
end

# Group classifications
Expand All @@ -76,7 +77,7 @@ const error_control_options = (
:termination_condition
)
const performance_options = ()
const numerical_options = (:threshold_state,)
const numerical_options = (:threshold_state,:forcing)

function option_group(option::Symbol)
if option in error_control_options
Expand Down Expand Up @@ -138,7 +139,8 @@ function NonlinearVerbosity(;
alias_u0_immutable = WarnLevel(),
linsolve_failed_noncurrent = WarnLevel(),
termination_condition = WarnLevel(),
threshold_state = WarnLevel()
threshold_state = WarnLevel(),
forcing = Silent(),
)

# Apply group-level settings
Expand Down Expand Up @@ -173,7 +175,8 @@ function NonlinearVerbosity(verbose::AbstractVerbosityPreset)
alias_u0_immutable = Silent(),
linsolve_failed_noncurrent = WarnLevel(),
termination_condition = Silent(),
threshold_state = Silent()
threshold_state = Silent(),
forcing = Silent(),
)
elseif verbose isa Standard
# Standard: Everything from Minimal + non-fatal warnings
Expand All @@ -186,7 +189,8 @@ function NonlinearVerbosity(verbose::AbstractVerbosityPreset)
alias_u0_immutable = WarnLevel(),
linsolve_failed_noncurrent = WarnLevel(),
termination_condition = WarnLevel(),
threshold_state = WarnLevel()
threshold_state = WarnLevel(),
forcing = InfoLevel(),
)
elseif verbose isa All
# All: Maximum verbosity - every possible logging message at InfoLevel
Expand All @@ -196,7 +200,8 @@ function NonlinearVerbosity(verbose::AbstractVerbosityPreset)
alias_u0_immutable = WarnLevel(),
linsolve_failed_noncurrent = WarnLevel(),
termination_condition = WarnLevel(),
threshold_state = InfoLevel()
threshold_state = InfoLevel(),
forcing = InfoLevel(),
)
end
end
Expand All @@ -208,7 +213,8 @@ end
Silent(),
Silent(),
Silent(),
Silent()
Silent(),
Silent(),
)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ using ForwardDiff: ForwardDiff, Dual # Default Forward Mode AD

include("solve.jl")
include("raphson.jl")
include("eisenstat_walker.jl")
include("gauss_newton.jl")
include("levenberg_marquardt.jl")
include("trust_region.jl")
Expand Down
111 changes: 111 additions & 0 deletions lib/NonlinearSolveFirstOrder/src/eisenstat_walker.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
EisenstatWalkerForcing2(; η₀ = 0.5, ηₘₐₓ = 0.9, γ = 0.9, α = 2, safeguard = true, safeguard_threshold = 0.1)

Algorithm 2 from the classical work by Eisenstat and Walker (1996) as described by formula (2.6):
ηₖ = γ * (||rₖ|| / ||rₖ₋₁||)^α

Here the variables denote:
rₖ residual at iteration k
η₀ ∈ [0,1) initial value for η
ηₘₐₓ ∈ [0,1) maximum value for η
γ ∈ [0,1) correction factor
α ∈ [1,2) correction exponent

Furthermore, the proposed safeguard is implemented:
ηₖ = max(ηₖ, γ*ηₖ₋₁^α) if γ*ηₖ₋₁^α > safeguard_threshold
to prevent ηₖ from shrinking too fast.
"""
@concrete struct EisenstatWalkerForcing2
η₀
ηₘₐₓ
γ
α
safeguard
safeguard_threshold
end

function EisenstatWalkerForcing2(; η₀ = 0.5, ηₘₐₓ = 0.9, γ = 0.9, α = 2, safeguard = true, safeguard_threshold = 0.1)
EisenstatWalkerForcing2(η₀, ηₘₐₓ, γ, α, safeguard, safeguard_threshold)
end


@concrete mutable struct EisenstatWalkerForcing2Cache
p::EisenstatWalkerForcing2
η
rnorm
rnorm_prev
internalnorm
verbosity
end



function pre_step_forcing!(cache::EisenstatWalkerForcing2Cache, descend_cache::NonlinearSolveBase.NewtonDescentCache, J, u, fu, iter)
@SciMLMessage("Eisenstat-Walker forcing residual norm $(cache.rnorm) with rate estimate $(cache.rnorm / cache.rnorm_prev).", cache.verbosity, :forcing)

# On the first iteration we initialize η with the default initial value and stop.
if iter == 0
cache.η = cache.p.η₀
@SciMLMessage("Eisenstat-Walker initial iteration to η=$(cache.η).", cache.verbosity, :forcing)
LinearSolve.update_tolerances!(descend_cache.lincache; reltol=cache.η)
return nothing
end

# Store previous
ηprev = cache.η

# Formula (2.6)
# ||r|| > 0 should be guaranteed by the convergence criterion
(; rnorm, rnorm_prev) = cache
(; α, γ) = cache.p
cache.η = γ * (rnorm / rnorm_prev)^α

# Safeguard 2 to prevent over-solving
if cache.p.safeguard
ηsg = γ*ηprev^α
if ηsg > cache.p.safeguard_threshold && ηsg > cache.η
cache.η = ηsg
end
end

# Far away from the root we also need to respect η ∈ [0,1)
cache.η = clamp(cache.η, 0.0, cache.p.ηₘₐₓ)

@SciMLMessage("Eisenstat-Walker iter $iter update to η=$(cache.η).", cache.verbosity, :forcing)

# Communicate new relative tolerance to linear solve
LinearSolve.update_tolerances!(descend_cache.lincache; reltol=cache.η)

return nothing
end



function post_step_forcing!(cache::EisenstatWalkerForcing2Cache, J, u, fu, δu, iter)
# Cache previous residual norm
cache.rnorm_prev = cache.rnorm
cache.rnorm = cache.internalnorm(fu)

# @SciMLMessage("Eisenstat-Walker sanity check: $(cache.internalnorm(fu + J*δu)) ≤ $(cache.η * cache.internalnorm(fu)).", cache.verbosity, :linear_verbosity)
end



function InternalAPI.init(
prob::AbstractNonlinearProblem, alg::EisenstatWalkerForcing2, f, fu, u, p,
args...; verbose, internalnorm::F = L2_NORM, kwargs...
) where {F}
fu_norm = internalnorm(fu)

return EisenstatWalkerForcing2Cache(
alg, alg.η₀, fu_norm, fu_norm, internalnorm, verbose
)
end



function InternalAPI.reinit!(
cache::EisenstatWalkerForcing2Cache; p = cache.p, kwargs...
)
cache.p = p
end
3 changes: 2 additions & 1 deletion lib/NonlinearSolveFirstOrder/src/levenberg_marquardt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ function LevenbergMarquardt(;
vjp_autodiff,
jvp_autodiff,
name = :LevenbergMarquardt,
concrete_jac = Val(true)
concrete_jac = Val(true),

)
end

Expand Down
7 changes: 5 additions & 2 deletions lib/NonlinearSolveFirstOrder/src/raphson.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""
NewtonRaphson(;
concrete_jac = nothing, linsolve = nothing, linesearch = missing,
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing,
forcing = nothing,
)

An advanced NewtonRaphson implementation with support for efficient handling of sparse
Expand All @@ -10,13 +11,15 @@ for large-scale and numerically-difficult nonlinear systems.
"""
function NewtonRaphson(;
concrete_jac = nothing, linsolve = nothing, linesearch = missing,
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing,
forcing = nothing,
)
return GeneralizedFirstOrderAlgorithm(;
linesearch,
descent = NewtonDescent(; linsolve),
autodiff, vjp_autodiff, jvp_autodiff,
concrete_jac,
forcing,
name = :NewtonRaphson
)
end
33 changes: 29 additions & 4 deletions lib/NonlinearSolveFirstOrder/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ order of convergence.
linesearch
trustregion
descent
forcing
max_shrink_times::Int

autodiff
Expand All @@ -38,12 +39,12 @@ end
function GeneralizedFirstOrderAlgorithm(;
descent, linesearch = missing, trustregion = missing, autodiff = nothing,
vjp_autodiff = nothing, jvp_autodiff = nothing, max_shrink_times::Int = typemax(Int),
concrete_jac = Val(false), name::Symbol = :unknown
concrete_jac = Val(false), forcing = nothing, name::Symbol = :unknown
)
concrete_jac = concrete_jac isa Bool ? Val(concrete_jac) :
(concrete_jac isa Val ? concrete_jac : Val(concrete_jac !== nothing))
return GeneralizedFirstOrderAlgorithm(
linesearch, trustregion, descent, max_shrink_times,
linesearch, trustregion, descent, forcing, max_shrink_times,
autodiff, vjp_autodiff, jvp_autodiff,
concrete_jac, name
)
Expand All @@ -62,6 +63,7 @@ end
# Internal Caches
jac_cache
descent_cache
forcing_cache
linesearch_cache
trustregion_cache

Expand Down Expand Up @@ -125,7 +127,7 @@ function InternalAPI.reinit_self!(
end

NonlinearSolveBase.@internal_caches(GeneralizedFirstOrderAlgorithmCache,
:jac_cache, :descent_cache, :linesearch_cache, :trustregion_cache)
:jac_cache, :descent_cache, :linesearch_cache, :trustregion_cache, :forcing_cache)

function SciMLBase.__init(
prob::AbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...;
Expand Down Expand Up @@ -196,6 +198,7 @@ function SciMLBase.__init(

has_linesearch = alg.linesearch !== missing && alg.linesearch !== nothing
has_trustregion = alg.trustregion !== missing && alg.trustregion !== nothing
has_forcing = alg.forcing !== missing && alg.forcing !== nothing

if has_trustregion && has_linesearch
error("TrustRegion and LineSearch methods are algorithmically incompatible.")
Expand All @@ -204,6 +207,7 @@ function SciMLBase.__init(
globalization = Val(:None)
linesearch_cache = nothing
trustregion_cache = nothing
forcing_cache = nothing

if has_trustregion
NonlinearSolveBase.supports_trust_region(alg.descent) ||
Expand All @@ -228,13 +232,24 @@ function SciMLBase.__init(
globalization = Val(:LineSearch)
end

if has_forcing
forcing_cache = InternalAPI.init(
prob, alg.forcing, fu, u, u, prob.p; stats, internalnorm,
autodiff = ifelse(
provided_jvp_autodiff, alg.jvp_autodiff, alg.vjp_autodiff
),
verbose,
kwargs...
)
end

trace = NonlinearSolveBase.init_nonlinearsolve_trace(
prob, alg, u, fu, J, du; kwargs...
)

cache = GeneralizedFirstOrderAlgorithmCache(
fu, u, u_cache, prob.p, alg, prob, globalization,
jac_cache, descent_cache, linesearch_cache, trustregion_cache,
jac_cache, descent_cache, forcing_cache, linesearch_cache, trustregion_cache,
stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs,
initializealg, verbose
Expand All @@ -259,6 +274,12 @@ function InternalAPI.step!(
end
end

has_forcing = cache.forcing_cache !== nothing && cache.forcing_cache !== missing

if has_forcing
pre_step_forcing!(cache.forcing_cache, cache.descent_cache, J, cache.u, cache.fu, cache.nsteps)
end

@static_timeit cache.timer "descent" begin
if cache.trustregion_cache !== nothing &&
hasfield(typeof(cache.trustregion_cache), :trust_region)
Expand Down Expand Up @@ -293,6 +314,10 @@ function InternalAPI.step!(
δu, descent_intermediates = descent_result.δu, descent_result.extras

if descent_result.success
if has_forcing
post_step_forcing!(cache.forcing_cache, J, cache.u, cache.fu, δu, cache.nsteps)
end

cache.make_new_jacobian = true
if cache.globalization isa Val{:LineSearch}
@static_timeit cache.timer "linesearch" begin
Expand Down
Loading
Loading