src/optimise/train.jl (8 changes: 6 additions & 2 deletions)

@@ -1,5 +1,5 @@
 using ProgressLogging: @progress, @withprogress, @logprogress
-import Zygote: Params, gradient
+import Zygote: Params, gradient, withgradient


 """

@@ -126,9 +126,13 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
   n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
   @withprogress for (i, d) in enumerate(data)
     try
-      gs = gradient(ps) do
+      l, gs = withgradient(ps) do
        loss(batchmemaybe(d)...)
      end
+      if !isfinite(l)
+        @warn "Loss is $l on item $i, stopping training"
+        break
+      end
      update!(opt, ps, gs)
      cb()
    catch ex
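For context: the substantive change above is swapping Zygote's gradient for withgradient, which returns the loss value alongside the gradients from a single forward pass, so the finiteness check costs no extra evaluation. A minimal sketch of that API, outside train! (the two-element array W is illustrative, not from this PR):

using Flux, Zygote

W = randn(Float32, 2)
ps = Flux.params(W)

# withgradient returns the loss together with a Grads object;
# destructuring it as `l, gs` mirrors the patched train! loop.
l, gs = Zygote.withgradient(() -> sum(abs2, W), ps)

isfinite(l)  # the new guard: warn and break before update! when this is false
gs[W]        # the same gradients that gradient(...) alone would have returned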
test/optimise.jl (12 changes: 12 additions & 0 deletions)

@@ -87,6 +87,18 @@ end
   Flux.train!(loss, Flux.params(r), (r,), Descent())
 end

+@testset "Stop on NaN" begin
+  m = Dense(1 => 1)
+  m.weight .= 0
+  CNT = 0
+  Flux.train!(Flux.params(m), 1:100, Descent(0.1)) do i
+    CNT += 1
+    i == 51 ? NaN32 : sum(m([1f0]))
+  end
+  @test CNT == 51  # stopped early
+  @test m.weight[1] ≈ -5  # did not corrupt weights
+end
+
 @testset "ExpDecay" begin

   @testset "Sanity Check" begin
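Why the expected values in this test hold (a re-derivation for the reader, not part of the PR): the loss sum(m([1f0])) equals weight + bias, so its gradient with respect to the weight is 1, and each Descent(0.1) step subtracts 0.1 from it. Items 1 through 50 each apply an update; item 51 returns NaN32 and breaks out before update! runs:

expected_calls  = 51              # the loss closure runs once on the NaN item too
expected_weight = 0 - 50 * 0.1    # == -5.0, matching m.weight[1] ≈ -5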