From 97a24899edd642dcdbbe4ff3df90dd269e3765a7 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 10 Oct 2021 20:06:34 +0200 Subject: [PATCH] Improvements to pullback_function --- src/AbstractDifferentiation.jl | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 227ae96..55504ed 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -221,31 +221,27 @@ _zero(::AbstractVector, d::AbstractMatrix) = zero(similar(d, size(d, 2))) _zero(::AbstractMatrix, d::AbstractMatrix) = zero(d) _zero(::Any, d::Any) = zero(d) +@inline _dot(x, y) = dot(x, y) +@inline function _dot(x::AbstractVector, y::UniformScaling) + @assert length(x) == 1 + return @inbounds dot(x[1], y.λ) +end +@inline function _dot(x::AbstractVector, y::AbstractMatrix) + @assert size(y, 2) == 1 + return dot(x, y) +end + function pullback_function(ab::AbstractBackend, f, xs...) return (ws) -> begin - jacs = jacobian(lowest(ab), (xs...,) -> begin + return gradient(lowest(ab), (xs...,) -> begin vs = f(xs...) if ws isa Tuple @assert length(vs) == length(ws) - return sum(zip(vs, ws)) do v, w - if w isa Union{AbstractMatrix, UniformScaling} && v isa AbstractVector - return w' * v - else - # for arbitrary arrays - return dot(w, v) - end - end + return sum(Base.splat(_dot), zip(ws, vs)) else - w, v = ws, vs - if w isa Union{AbstractMatrix, UniformScaling} && v isa AbstractVector - return w' * v - else - # for arbitrary arrays - return dot(w, v) - end + return _dot(vs, ws) end end, xs...) - return adjoint.(jacs) end end function value_and_pullback_function(