diff --git a/src/Nonlinear/ReverseAD/forward_over_reverse.jl b/src/Nonlinear/ReverseAD/forward_over_reverse.jl index bcc1f86a98..ed083a78c5 100644 --- a/src/Nonlinear/ReverseAD/forward_over_reverse.jl +++ b/src/Nonlinear/ReverseAD/forward_over_reverse.jl @@ -50,61 +50,19 @@ function _eval_hessian_inner( @assert length(ex.hess_I) == 0 return 0 end - T = ForwardDiff.Partials{CHUNK,Float64} # This is our element type. Coloring.prepare_seed_matrix!(ex.seed_matrix, ex.rinfo) - local_to_global_idx = ex.rinfo.local_indices - input_ϵ_raw, output_ϵ_raw = d.input_ϵ, d.output_ϵ - input_ϵ = _reinterpret_unsafe(T, input_ϵ_raw) - output_ϵ = _reinterpret_unsafe(T, output_ϵ_raw) # Compute hessian-vector products num_products = size(ex.seed_matrix, 2) # number of hessian-vector products num_chunks = div(num_products, CHUNK) - @assert size(ex.seed_matrix, 1) == length(local_to_global_idx) - for k in 1:CHUNK:(CHUNK*num_chunks) - for r in 1:length(local_to_global_idx) - # set up directional derivatives - @inbounds idx = local_to_global_idx[r] - # load up ex.seed_matrix[r,k,k+1,...,k+CHUNK-1] into input_ϵ - for s in 1:CHUNK - input_ϵ_raw[(idx-1)*CHUNK+s] = ex.seed_matrix[r, k+s-1] - end - @inbounds output_ϵ[idx] = zero(T) - end - _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, T) - # collect directional derivatives - for r in 1:length(local_to_global_idx) - idx = local_to_global_idx[r] - # load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+CHUNK-1] - for s in 1:CHUNK - ex.seed_matrix[r, k+s-1] = output_ϵ_raw[(idx-1)*CHUNK+s] - end - @inbounds input_ϵ[idx] = zero(T) - end + @assert size(ex.seed_matrix, 1) == length(ex.rinfo.local_indices) + for offset in 1:CHUNK:(CHUNK*num_chunks) + _eval_hessian_chunk(d, ex, offset, CHUNK, Val(CHUNK)) end # leftover chunk remaining = num_products - CHUNK * num_chunks if remaining > 0 - k = CHUNK * num_chunks + 1 - for r in 1:length(local_to_global_idx) - # set up directional derivatives - @inbounds idx = local_to_global_idx[r] - # load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ - for s in 1:remaining - # leave junk in the unused components - input_ϵ_raw[(idx-1)*CHUNK+s] = ex.seed_matrix[r, k+s-1] - end - @inbounds output_ϵ[idx] = zero(T) - end - _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, T) - # collect directional derivatives - for r in 1:length(local_to_global_idx) - idx = local_to_global_idx[r] - # load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1] - for s in 1:remaining - ex.seed_matrix[r, k+s-1] = output_ϵ_raw[(idx-1)*CHUNK+s] - end - @inbounds input_ϵ[idx] = zero(T) - end + offset = CHUNK * num_chunks + 1 + _eval_hessian_chunk(d, ex, offset, remaining, Val(CHUNK)) end want, got = nzcount + length(ex.hess_I), length(H) if want > got @@ -127,7 +85,40 @@ function _eval_hessian_inner( return length(ex.hess_I) end -function _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, ::Type{T}) where {T} +function _eval_hessian_chunk( + d::NLPEvaluator, + ex::_FunctionStorage, + offset::Int, + chunk::Int, + ::Val{CHUNK}, +) where {CHUNK} + for r in eachindex(ex.rinfo.local_indices) + # set up directional derivatives + @inbounds idx = ex.rinfo.local_indices[r] + # load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ + for s in 1:chunk + # If `chunk < CHUNK`, leaves junk in the unused components + d.input_ϵ[(idx-1)*CHUNK+s] = ex.seed_matrix[r, offset+s-1] + end + end + _hessian_slice_inner(d, ex, Val(CHUNK)) + fill!(d.input_ϵ, 0.0) + # collect directional derivatives + for r in eachindex(ex.rinfo.local_indices) + @inbounds idx = ex.rinfo.local_indices[r] + # load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1] + for s in 1:chunk + ex.seed_matrix[r, offset+s-1] = d.output_ϵ[(idx-1)*CHUNK+s] + end + end + return +end + +function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK} + T = ForwardDiff.Partials{CHUNK,Float64} # This is our element type. + input_ϵ = _reinterpret_unsafe(T, d.input_ϵ) + fill!(d.output_ϵ, 0.0) + output_ϵ = _reinterpret_unsafe(T, d.output_ϵ) subexpr_forward_values_ϵ = _reinterpret_unsafe(T, d.subexpression_forward_values_ϵ) for i in ex.dependent_subexpressions