I am following the simple example of fitting a line. When I use slightly different sample points than in the tutorial, the training completely diverges to infinity. Here is the working code, exactly as in the tutorial:
```julia
using Flux
using Statistics
using Flux: train!

Flux.Random.seed!(42)

f(x) = 4x + 2
m = Dense(1 => 1)
loss(model, x, y) = mean(abs2.(model(x) .- y))

opt = Descent()
x = hcat(0:5 ...)
d = [(x, f.(x))]

train!(loss, m, d, opt); loss(m, x, f.(x))
m.weight, m.bias

for epoch in 1:200
    train!(loss, m, d, opt)
end
@show loss(m, x, f.(x))
```
However, when I use a slightly different set of sample points, for example `x = hcat(1:6 ...)`, `x = hcat(0:6 ...)`, or `x = hcat(1:10 ...)`, or anything else than in the tutorial, the loss reported after `train!` grows larger and larger, diverging to infinity. I would expect the training to be more stable, especially considering that this is an extremely simple model. (`-1:5` converges, `-1:6` diverges, `0:8` diverges, ...)
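For what it's worth, a likely explanation (my assumption, not something the tutorial states): `Descent()` uses a fixed step size (η = 0.1 by default in Flux), and plain gradient descent on a quadratic loss diverges once η exceeds 2 / λmax of the loss Hessian. Larger sample values inflate `mean(x.^2)` and hence λmax. Here is a minimal sketch without Flux that reproduces the behavior; `gd_fit` and `threshold` are hypothetical helper names I made up for illustration:

```julia
# Plain gradient descent on mean-squared error for y = w*x + b,
# mimicking Descent() with a fixed step size eta.
using Statistics, LinearAlgebra

function gd_fit(xs, eta; epochs = 200)
    ys = 4 .* xs .+ 2                    # same target as f(x) = 4x + 2
    w, b = 0.0, 0.0
    for _ in 1:epochs
        r = w .* xs .+ b .- ys           # residuals
        w -= eta * 2 * mean(r .* xs)     # d/dw of mean(r .^ 2)
        b -= eta * 2 * mean(r)           # d/db of mean(r .^ 2)
    end
    mean((w .* xs .+ b .- ys) .^ 2)      # final loss
end

# Gradient descent on a quadratic converges iff eta < 2 / lambda_max of the
# Hessian of the loss, which here is 2 * [mean(x.^2) mean(x); mean(x) 1].
threshold(xs) = 2 / maximum(eigvals(2 .* [mean(xs .^ 2) mean(xs); mean(xs) 1.0]))
```

With `x = 0:5` the threshold is ≈ 0.101, just above the default 0.1, so training converges; with `x = 0:6` it drops to ≈ 0.073 and the same step size blows up, matching the pattern above. Shrinking the learning rate, e.g. `Descent(0.01)`, or switching to an adaptive optimiser such as `Adam()` should make the example stable.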