Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,30 +163,32 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
end
end

function combinations(xs, n)
n < 1 && return [[]]
cs = combinations(xs, n-1)
[[x, c...] for x in xs, c in cs]
for (T, S) in [(:TrackedArray, :TrackedArray), (:TrackedArray, :AbstractArray), (:AbstractArray, :TrackedArray)]
@eval Base.vcat(A::$T, B::$S, Cs::AbstractArray...) = track(vcat, A, B, Cs...)
@eval Base.hcat(A::$T, B::$S, Cs::AbstractArray...) = track(hcat, A, B, Cs...)
end

for i = 0:2, c = combinations([:AbstractArray, :TrackedArray, :Number], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
track($f, $(cnames...), x, xs...)
for (T, S) in [(:TrackedVector, :TrackedVector), (:TrackedVector, :AbstractVector), (:AbstractVector, :TrackedVector)]
@eval Base.vcat(A::$T, B::$S, Cs::AbstractVector...) = track(vcat, A, B, Cs...)
end

for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
track($f, $(cnames...), x, xs...)
for (T, S) in [(:TrackedVecOrMat, :TrackedVecOrMat), (:TrackedVecOrMat, :AbstractVecOrMat), (:AbstractVecOrMat, :TrackedVecOrMat)]
@eval Base.vcat(A::$T, B::$S, Cs::AbstractVecOrMat...) = track(vcat, A, B, Cs...)
@eval Base.hcat(A::$T, B::$S, Cs::AbstractVecOrMat...) = track(hcat, A, B, Cs...)
end

for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
track($f, $(cnames...), x, xs...)
for (T, S) in [(:TrackedArray, :Real), (:Real, :TrackedArray), (:TrackedArray, :TrackedArray)]
@eval Base.vcat(A::$T, B::$S, Cs::Union{AbstractArray, Real}...) = track(vcat, A, B, Cs...)
@eval Base.hcat(A::$T, B::$S, Cs::Union{AbstractArray, Real}...) = track(hcat, A, B, Cs...)
end
for (T, S) in [(:TrackedReal, :Real), (:Real, :TrackedReal), (:TrackedReal, :TrackedReal)]
@eval Base.vcat(A::$T, B::$S, Cs::Real...) = track(vcat, A, B, Cs...)
@eval Base.hcat(A::$T, B::$S, Cs::Real...) = track(hcat, A, B, Cs...)
end

Base.vcat(A::TrackedArray) = track(vcat, A)
Base.hcat(A::TrackedArray) = track(hcat, A)

Base.vcat(A::TrackedReal) = track(vcat, A)
Base.hcat(A::TrackedReal) = track(hcat, A)

@grad function vcat(xs...)
vcat(data.(xs)...), function (Δ)
start = 0
Expand Down Expand Up @@ -218,12 +220,12 @@ end
end
end

for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
cnames = map(_ -> gensym(), c)
@eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
track(cat, $(cnames...), x, xs..., dims = dims)
for (T, S) in [(:TrackedArray, :TrackedArray), (:TrackedArray, :AbstractArray), (:AbstractArray, :TrackedArray)]
@eval Base.cat(A::$T, B::$S, Cs::AbstractArray...; dims) = track(cat, A, B, Cs...; dims = dims)
end

Base.cat(A::TrackedArray; dims) = track(cat, A; dims = dims)

@grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ)
start = ntuple(i -> 0, Val(ndims(Δ)))
Expand Down Expand Up @@ -418,6 +420,9 @@ end
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))

# fix Matrix(Diagonal(param([1,2,3]))) after https://github.com/JuliaLang/julia/pull/44615
(::Type{Matrix})(d::Diagonal{<:Any,<:TrackedArray}) = diagm(0 => d.diag)

x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
Expand Down
4 changes: 2 additions & 2 deletions src/numeric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ function ngradient(f, xs::AbstractArray...)
return grads
end

gradcheck(f, xs...) =
gradcheck(f, xs...; rtol = 1e-5, atol = 1e-5) =
all(isapprox.(ngradient(f, xs...),
data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5))
data.(gradient(f, xs...)); rtol = rtol, atol = atol))
48 changes: 32 additions & 16 deletions test/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ using Statistics: mean, std
using Random
# using StatsBase

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
gradtest(f, xs::AbstractArray...; kw...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...; kw...)
gradtest(f, dims...; kw...) = gradtest(f, rand.(Float64, dims)...; kw...)

@testset "Tracker" begin # overall testset, rest of the file
@testset "gradtests 1" begin

@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W) -> σ.(W*x), 5, (2,5))
Expand Down Expand Up @@ -45,20 +45,24 @@ end
@test gradtest(logdet, map((x) -> x*x', (rand(4, 4),))[1])
@test gradtest((x) -> logabsdet(x)[1], (4, 4))

end # @testset gradtests

@testset "indexing & slicing" begin
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
@test gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
end

function promotiontest(f, A, B, C)
r0 = f(A, B, C)
r1 = f(param(A), B, C)
r2 = f(A, param(B), C)
r3 = f(A, B, param(C))
# r3 = f(A, B, param(C)) # no longer cater to tracked array in 3rd position
r4 = f(param(A), param(B), param(C))

@test !isa(r0, TrackedArray)
@test all(isa.([r1,r2,r3,r4], TrackedArray))
@test r1 == r2 == r3 == r4
# @test all(isa.([r1,r2,r3,r4], TrackedArray))
# @test r1 == r2 == r3 == r4
@test all(isa.([r1,r2,r4], TrackedArray))
@test r1 == r2 == r4
@test r0 == Tracker.data(r4)
end

Expand All @@ -68,7 +72,7 @@ end
rvcat(x...) = reduce(vcat, x)
rhcat(x...) = reduce(hcat, x)

@testset for vcatf in [vcat, cat1, rvcat]
@testset "2-arg $vcatf" for vcatf in [vcat, cat1, rvcat]
@test gradtest(vcatf, rand(5), rand(3))
@test gradtest(vcatf, rand(5), rand(3), rand(8))
@test gradtest(vcatf, rand(5)', rand(5)')
Expand All @@ -79,7 +83,7 @@ end
end


@testset for hcatf in [hcat, cat2, rhcat]
@testset "2-arg $hcatf" for hcatf in [hcat, cat2, rhcat]
@test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(5)', rand(5)')
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
Expand All @@ -89,7 +93,7 @@ end
@test gradtest(hcatf, rand(5), rand(5,2))
end

@testset for catf in [vcat, cat1, rvcat, hcat, cat2, rhcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
@testset "1-arg $catf" for catf in [vcat, cat1, rvcat, hcat, cat2, rhcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
@test gradtest(catf, rand(5))
@test gradtest(catf, rand(5)')
@test gradtest(catf, rand(2,5))
Expand Down Expand Up @@ -133,6 +137,13 @@ end
@test hcat(1, param([1 2 3;])) isa TrackedArray
@test vcat(param(1), 2) isa TrackedArray
end

@testset "ambiguities" begin
@test vcat(param([1, 2, 3]), [2,3]) isa TrackedArray
@test vcat(param([1, 2, 3]), [2.0, 3.0]) isa TrackedArray
@test hcat(param([1 2 3]), [2, 3]') isa TrackedArray
@test hcat(param([1 2 3]), [2.0, 3.0]') isa TrackedArray
end

end

Expand All @@ -141,6 +152,8 @@ end
@test gradtest(x->x[z], randn(MersenneTwister(123456), 3))
end

@testset "gradtests 2" begin

@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6))

Expand All @@ -159,6 +172,7 @@ end
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

@test gradtest(x -> diagm(0 => x), rand(3))
@test gradtest(x -> Matrix(Diagonal(x)), rand(3))

@test gradtest(W -> inv(log.(W * W)), (5,5))
@test gradtest((A, B) -> A / B , (1,5), (5,5))
Expand All @@ -178,6 +192,8 @@ end
gradtest(A -> log.(A * A) \ exp.(B * B), (5, 5))
end

end # @testset "gradtests 2"

@testset "mean" begin
@test gradtest(mean, rand(2, 3))

Expand Down Expand Up @@ -208,6 +224,8 @@ end
@test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
end

@testset "gradtests 3" begin

@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, dims = 1), rand(5,5))
@test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5))
Expand All @@ -224,6 +242,8 @@ end
2y + x
end

end # @testset "gradtests 3"

@testset "transpose" begin
w = Tracker.TrackedArray(rand(5,5))
x = Tracker.TrackedArray(rand(5,5))
Expand Down Expand Up @@ -299,17 +319,15 @@ end
@test transpose(w)*transpose(x) isa TrackedArray
end

@testset "conv" begin
for spatial_rank in (1, 2, 3)
@testset "conv, $(spatial_rank)d" for spatial_rank in (1, 2, 3)
x = rand(repeat([10], spatial_rank)..., 3, 2)
w = rand(repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
@test gradtest((x, w) -> conv(x, w, cdims), x, w)
y = conv(x, w, cdims)
@test gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
dcdims = DepthwiseConvDims(x, w)
@test gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)
end
@test_skip gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)
end

@testset "pooling" begin
Expand All @@ -321,7 +339,6 @@ end
end
end


@test gradtest(x -> Float64.(x), 5)

@testset "equality & order" begin
Expand Down Expand Up @@ -480,4 +497,3 @@ end
@test size(y) == (5, 3)
end

end # overall testset