Skip to content

Commit 3c7323b

Browse files
committed
fix #100
1 parent f9517dc commit 3c7323b

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

src/AbstractDifferentiation.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ function pushforward_function(
163163
xs...,
164164
)
165165
return (ds) -> begin
166+
z = (ds isa Tuple ? _zero.(xs, ds) : _zero.(xs, (ds,)))
166167
return jacobian(lowest(ab), (xds...,) -> begin
167168
if ds isa Tuple
168169
@assert length(xs) == length(ds)
@@ -172,7 +173,7 @@ function pushforward_function(
172173
newx = only(xs) + ds * only(xds)
173174
return f(newx)
174175
end
175-
end, _zero.(xs, ds)...)
176+
end, z...)
176177
end
177178
end
178179
function value_and_pushforward_function(
@@ -224,9 +225,11 @@ function pullback_function(ab::AbstractBackend, f, xs...)
224225
return (ws) -> begin
225226
return gradient(lowest(ab), (xs...,) -> begin
226227
vs = f(xs...)
227-
if ws isa Tuple
228+
if ws isa Tuple && length(ws) > 1
228229
@assert length(vs) == length(ws)
229230
return sum(Base.splat(_dot), zip(ws, vs))
231+
elseif ws isa Tuple && length(ws) == 1
232+
return _dot(vs, only(ws))
230233
else
231234
return _dot(vs, ws)
232235
end

test/test_utils.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,14 @@ function test_jvp(backend; multiple_inputs=true, vaugmented=false, rng=Random.GL
229229
end
230230

231231
valvec1, pf1 = AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)(v[1])
232+
_valvec1, _pf1 = AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)((v[1],))
233+
@test valvec1 == _valvec1
234+
@test pf1 == _pf1
235+
232236
valvec2, pf2 = AD.value_and_pushforward_function(backend, y -> fjac(xvec, y), yvec)(v[2])
237+
_valvec2, _pf2 = AD.value_and_pushforward_function(backend, y -> fjac(xvec, y), yvec)((v[2],))
238+
@test valvec2 == _valvec2
239+
@test pf2 == _pf2
233240

234241
if test_types
235242
@test valvec1 isa Vector{Float64}
@@ -247,7 +254,13 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_
247254
w = rand(rng, length(fjac(xvec, yvec)))
248255
if multiple_inputs
249256
pb1 = AD.pullback_function(backend, fjac, xvec, yvec)(w)
257+
_pb1 = AD.pullback_function(backend, fjac, xvec, yvec)((w,))
258+
@test pb1 == _pb1
259+
250260
valvec, pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)(w)
261+
_valvec, _pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)((w,))
262+
@test valvec == _valvec
263+
@test pb2 == _pb2
251264

252265
if test_types
253266
@test valvec isa Vector{Float64}
@@ -264,7 +277,15 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_
264277
end
265278

266279
valvec1, pb1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)(w)
280+
_valvec1, _pb1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)((w,))
281+
@test valvec1 == _valvec1
282+
@test pb1 == _pb1
283+
267284
valvec2, pb2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)(w)
285+
_valvec2, _pb2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)((w,))
286+
@test valvec2 == _valvec2
287+
@test pb2 == _pb2
288+
268289
if test_types
269290
@test valvec1 isa Vector{Float64}
270291
@test valvec2 isa Vector{Float64}

0 commit comments

Comments
 (0)