
getting sane Lux use cases fully working #55

@ExpandingMan

Description

More details to come, but here is a good, fairly minimal example showing how the most basic use case in Lux fails.

using Lux, LuxCUDA, Enzyme, Optimisers, Random, Static
using MLUtils, ADTypes
using Zygote
using OneHotArrays

# Returns an Expr so the example can be run at top level with eval(_main()).
_main() = quote
    device = gpu_device()
    rng = Xoshiro(999)

    # one-hot encoded inputs (4 classes) and dense targets
    Xdat = onehotbatch(rand(rng, 1:4, 128), 1:4)
    ydat = randn(rng, 4, 128)

    # the DataLoader moves each batch, including the OneHotMatrix, to the GPU
    data = DataLoader((Xdat, ydat), batchsize=32, partial=false) |> device

    f = Chain(
        Dense(4=>8, relu), Dense(8=>8, relu), Dense(8=>4),
    )

    (θ, ψ) = Lux.setup(rng, f) |> device

    # the forward pass goes through without complaint
    (x, y) = first(data)
    (ŷ, _) = Lux.apply(f, x, θ, ψ)

    opt = Adam(0.001f0)
    s = Training.TrainState(f, θ, ψ, opt)

    # the reverse pass here is what blows up (see the stack trace below)
    (_, ℓ, _, s) = Training.single_train_step!(AutoZygote(), MSELoss(), (x, y), s)
end
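
For reference, the quoted block is executed by evaluating the returned expression at top level, which is where the eval frames at the bottom of the trace come from:

eval(_main())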

This fails with:

ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] errorscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:151
  [3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:124
  [4] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:112
  [5] getindex
    @ ~/.julia/packages/GPUArrays/uiVyU/src/host/indexing.jl:50 [inlined]
  [6] getindex
    @ ~/.julia/dev/OneHotArrays/src/array.jl:67 [inlined]
  [7] getindex
    @ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/adjtrans.jl:334 [inlined]
  [8] _generic_matmatmul!(C::CuArray{…}, A::CuArray{…}, B::LinearAlgebra.Transpose{…}, _add::LinearAlgebra.MulAddMul{…})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:894
  [9] generic_matmatmul!
    @ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:868 [inlined]
 [10] _mul!
    @ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
 [11] mul!
    @ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
 [12] mul!
    @ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
 [13] matmul!
    @ ~/.julia/packages/LuxLib/Kj0os/src/impl/matmul.jl:119 [inlined]
 [14] matmul!
    @ ~/.julia/packages/LuxLib/Kj0os/src/impl/matmul.jl:112 [inlined]
 [15] matmul
    @ ~/.julia/packages/LuxLib/Kj0os/src/impl/matmul.jl:57 [inlined]
 [16] matmul
    @ ~/.julia/packages/LuxLib/Kj0os/src/impl/matmul.jl:51 [inlined]
 [17] ∇matmul_bias
    @ ~/.julia/packages/LuxLib/Kj0os/src/impl/dense.jl:287 [inlined]
 [18] ∇matmul_bias(∂y::CuArray{…}, weight::CuArray{…}, x::OneHotMatrix{…}, bias::CuArray{…})
    @ LuxLib.Impl ~/.julia/packages/LuxLib/Kj0os/src/impl/dense.jl:286
 [19] (::LuxLib.Impl.var"#96#99"{…})(Δ::CuArray{…})
    @ LuxLib.Impl ~/.julia/packages/LuxLib/Kj0os/src/impl/dense.jl:87
 [20] ZBack
    @ ~/.julia/packages/Zygote/wfLOG/src/compiler/chainrules.jl:222 [inlined]
 [21] fused_dense
    @ ~/.julia/packages/LuxLib/Kj0os/src/impl/dense.jl:16 [inlined]
 [22] fused_dense_bias_activation
    @ ~/.julia/packages/LuxLib/Kj0os/src/api/dense.jl:36 [inlined]
 [23] Dense
    @ ~/.julia/packages/Lux/L2VO7/src/layers/basic.jl:357 [inlined]
 [24] apply
    @ ~/.julia/packages/LuxCore/Av7WJ/src/LuxCore.jl:155 [inlined]
 [25] applychain
    @ ~/.julia/packages/Lux/L2VO7/src/layers/containers.jl:0 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRulesCore.Thunk{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [27] Chain
    @ ~/.julia/packages/Lux/L2VO7/src/layers/containers.jl:509 [inlined]
 [28] AbstractLossFunction
    @ ~/.julia/packages/Lux/L2VO7/src/helpers/losses.jl:190 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float32, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [30] #1
    @ ~/.julia/packages/Lux/L2VO7/ext/LuxZygoteExt/training.jl:8 [inlined]
 [31] (::Zygote.Pullback{Tuple{LuxZygoteExt.var"#1#2"{…}, @NamedTuple{…}}, Any})(Δ::Tuple{Float32, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [32] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Any}})(Δ::Tuple{Float32, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:97
 [33] compute_gradients_impl(::AutoZygote, objective_function::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ LuxZygoteExt ~/.julia/packages/Lux/L2VO7/ext/LuxZygoteExt/training.jl:10
 [34] compute_gradients
    @ ~/.julia/packages/Lux/L2VO7/src/helpers/training.jl:202 [inlined]
 [35] single_train_step_impl!(backend::AutoZygote, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/L2VO7/src/helpers/training.jl:328
 [36] #single_train_step!#6
    @ ~/.julia/packages/Lux/L2VO7/src/helpers/training.jl:294 [inlined]
 [37] single_train_step!(backend::AutoZygote, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/L2VO7/src/helpers/training.jl:288
 [38] top-level scope
    @ ~/src/scrap2.jl:28
 [39] eval
    @ ./boot.jl:430 [inlined]
 [40] eval(x::Expr)
    @ Main ./sysimg.jl:48
 [41] top-level scope
    @ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.

Frame [8] in the stack trace is a matmul between a CuArray and a Transpose of a OneHotArray (I didn't dump the full output because it's huge, but I have verified this), so this is clearly happening somewhere in the reverse pass.

I will see if I can reproduce this a bit more directly, but I wanted to keep the Lux example here because I think we should get all of these use cases working.
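
As a sketch of what I expect that more direct reproduction to look like (this is my guess at what frame [8] reduces to, not a verified snippet; the shapes match the first Dense(4=>8) layer and the batch size of 32):

using CUDA, OneHotArrays

# a stand-in for the gradient flowing back into the first Dense layer
∂y = CUDA.randn(Float32, 8, 32)
# a OneHotMatrix whose index vector has been moved to the GPU via Adapt
x = cu(onehotbatch(rand(1:4, 32), 1:4))
# presumably there is no specialized method for this combination, so mul!
# falls back to LinearAlgebra's _generic_matmatmul! and hits scalar indexing
∂W = ∂y * transpose(x)

If that does reproduce it, a plausible stopgap is to materialize the one-hot matrix as a dense Float32 array before it reaches the DataLoader, so the backward matmul only ever sees plain CuArrays, but the OneHotMatrix path itself should really work.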
