
no gradients if we save Flux.params into a variable #1346

@MariusDrulea

Description

See the following MWE:

using Flux

model = Dense(2, 2)

xt = rand(Float32, 2, 4) # batch size of 4
yt = rand(Float32, 2, 4)

ps = Flux.params(model)

# captures the global `ps` collected above; the argument `m` is never read
loss_fun(m, x, y) = 1/2*sum(p->sum(p.^2), ps)

# reads the fields of `m` directly
loss_fun_explicit(m, x, y) = 1/2*sum(m.weight.^2) + 1/2*sum(m.bias.^2)

# collects the parameters of `m` inside the loss on every call
loss_fun_slow(m, x, y) = 1/2*sum(p->sum(p.^2), Flux.params(m))

∇m = gradient(m->loss_fun(m, xt, yt), model)
∇m_explicit = gradient(m->loss_fun_explicit(m, xt, yt), model)
∇m_slow = gradient(m->loss_fun_slow(m, xt, yt), model)

@show ∇m
@show ∇m_explicit
@show ∇m_slow

The values of the gradients are below. ∇m_explicit and ∇m_slow are equal and correct, but ∇m is nothing.

∇m = (nothing,)
∇m_explicit = ((weight = Float32[0.69311625 -1.0913904; -0.12783962 -0.15561718], bias = Float32[0.0, 0.0], σ = nothing),)
∇m_slow = ((weight = Float32[0.69311625 -1.0913904; -0.12783962 -0.15561718], bias = Float32[0.0, 0.0], σ = nothing),)
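
Presumably this happens because loss_fun never reads its m argument: every array it sums comes in through the captured global ps, so with respect to the m being differentiated the loss is a constant and Zygote returns nothing. A minimal sketch of a workaround in the implicit-parameters style, assuming the model, ps, xt, yt and loss_fun defined above and Zygote's Params/Grads API:

using Flux  # `gradient` here is Zygote's

# differentiate with respect to the stored Params rather than the (unused) model argument
∇ps = gradient(() -> loss_fun(model, xt, yt), ps)

# the result is a Grads object indexed by the parameter arrays themselves
@show ∇ps[model.weight]
@show ∇ps[model.bias]

Alternatively, the explicit forms loss_fun_explicit and loss_fun_slow above already give the expected gradients, since they actually use the m that is being differentiated.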
