Skip to content

Commit 0dd16db

Browse files
committed
fix jacobians
1 parent f0dda1f commit 0dd16db

File tree

4 files changed

+15
-9
lines changed

4 files changed

+15
-9
lines changed

lib/NonlinearSolveBase/src/jacobian.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ function construct_jacobian_cache(
6161
end
6262

6363
J = if !needs_jac
64-
StatefulJacobianOperator(JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff), u, p)
64+
JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff)
6565
else
6666
if f.jac_prototype === nothing
6767
# While this is technically wasteful, it gives out the type of the Jacobian
@@ -96,7 +96,7 @@ function construct_jacobian_cache(
9696
linsolve = missing
9797
)
9898
if SciMLBase.has_jac(f) || SciMLBase.has_vjp(f) || SciMLBase.has_jvp(f)
99-
return JacobianCache(u, f, fu, u, p, stats, autodiff, nothing)
99+
return JacobianCache(fu, f, fu, p, stats, autodiff, nothing)
100100
end
101101
if autodiff === nothing
102102
throw(ArgumentError("`autodiff` argument to `construct_jacobian_cache` must be \
@@ -124,10 +124,12 @@ function InternalAPI.reinit!(cache::JacobianCache; p = cache.p, kwargs...)
124124
cache.p = p
125125
end
126126

127-
# Core Computation
128-
(cache::JacobianCache)(::Nothing) = cache.J
129-
(cache::JacobianCache{<:Number})(::Nothing) = cache.J
127+
# Deprecations
128+
(cache::JacobianCache{<:Number})(::Nothing) = error("Please report a bug to NonlinearSolve.jl")
129+
(cache::JacobianCache{<:JacobianOperator})(::Nothing) = error("Please report a bug to NonlinearSolve.jl")
130+
(cache::JacobianCache)(::Nothing) = error("Please report a bug to NonlinearSolve.jl")
130131

132+
# Core Computation
131133
## Numbers
132134
function (cache::JacobianCache{<:Number})(u)
133135
cache.stats.njacs += 1
@@ -168,6 +170,10 @@ function (cache::JacobianCache)(u)
168170
end
169171
end
170172

173+
function (cache::JacobianCache{<:JacobianOperator})(u)
174+
return StatefulJacobianOperator(cache.J, u, cache.p)
175+
end
176+
171177
# Sparse Automatic Differentiation
172178
function construct_concrete_adtype(f::NonlinearFunction, ad::AbstractADType)
173179
if f.sparsity === nothing

lib/NonlinearSolveBase/src/wrappers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,5 +106,5 @@ function construct_extension_jac(
106106

107107
initial_jacobian isa Val{false} && return J_final
108108

109-
return J_final, Jₚ(nothing)
109+
return J_final, Jₚ(u0)
110110
end

lib/NonlinearSolveFirstOrder/src/solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ function SciMLBase.__init(
170170
prob, alg, prob.f, fu, u, prob.p;
171171
stats, alg.autodiff, linsolve, alg.jvp_autodiff, alg.vjp_autodiff
172172
)
173-
J = jac_cache(nothing)
173+
J = jac_cache(u)
174174

175175
descent_cache = InternalAPI.init(
176176
prob, alg.descent, J, fu, u; stats, abstol, reltol, internalnorm,
@@ -238,7 +238,7 @@ function InternalAPI.step!(
238238
J = cache.jac_cache(cache.u)
239239
new_jacobian = true
240240
else
241-
J = cache.jac_cache(nothing)
241+
J = cache.jac_cache(cache.u)
242242
new_jacobian = false
243243
end
244244
end

lib/NonlinearSolveQuasiNewton/src/initialization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function InternalAPI.init(
125125
jac_cache = NonlinearSolveBase.construct_jacobian_cache(
126126
prob, solver, prob.f, fu, u, p; stats, autodiff, linsolve
127127
)
128-
J = alg.structure(jac_cache(nothing))
128+
J = alg.structure(jac_cache(u))
129129
return InitializedApproximateJacobianCache(
130130
J, alg.structure, alg, jac_cache, false, internalnorm
131131
)

0 commit comments

Comments
 (0)