Skip to content

Commit 4ec5450

Browse files
committed
improve fix
1 parent 602c23b commit 4ec5450

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/AbstractDifferentiation.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,17 @@ function pushforward_function(
163163
xs...,
164164
)
165165
return (ds) -> begin
166-
z = (ds isa Tuple ? _zero.(xs, ds) : _zero.(xs, (ds,)))
166+
if ds isa Tuple
167+
@assert length(xs) == length(ds)
168+
z = _zero.(xs, ds)
169+
elseif length(xs) == 1
170+
z = _zero.(xs, (ds,))
171+
else
172+
z = 0
173+
throw(ArgumentError("The input and tangents are not of compatible sizes."))
174+
end
167175
return jacobian(lowest(ab), (xds...,) -> begin
168176
if ds isa Tuple
169-
@assert length(xs) == length(ds)
170177
newxs = xs .+ ds .* xds
171178
return f(newxs...)
172179
else
@@ -225,7 +232,7 @@ function pullback_function(ab::AbstractBackend, f, xs...)
225232
return (ws) -> begin
226233
return gradient(lowest(ab), (xs...,) -> begin
227234
vs = f(xs...)
228-
if ws isa Tuple && length(ws) > 1
235+
if ws isa Tuple && vs isa Tuple
229236
@assert length(vs) == length(ws)
230237
return sum(Base.splat(_dot), zip(ws, vs))
231238
elseif ws isa Tuple && length(ws) == 1

0 commit comments

Comments
 (0)