More details to come, but here is a fairly minimal example showing how the most basic use case in Lux fails on the GPU.
```julia
using Lux, LuxCUDA, Enzyme, Optimisers, Random, Static
using MLUtils, ADTypes
using Zygote
using OneHotArrays

_main() = quote
    device = gpu_device()
    rng = Xoshiro(999)
    # one-hot input batch; this is what ends up scalar-indexed on the GPU
    Xdat = onehotbatch(rand(rng, 1:4, 128), 1:4)
    ydat = randn(rng, 4, 128)
    data = DataLoader((Xdat, ydat), batchsize=32, partial=false) |> device
    f = Chain(
        Dense(4 => 8, relu), Dense(8 => 8, relu), Dense(8 => 4),
    )
    (θ, ψ) = Lux.setup(rng, f) |> device
    (x, y) = first(data)
    (ŷ, _) = Lux.apply(f, x, θ, ψ)  # the forward pass alone works fine
    opt = Adam(0.001f0)
    s = Training.TrainState(f, θ, ψ, opt)
    (_, ℓ, _, s) = Training.single_train_step!(AutoZygote(), MSELoss(), (x, y), s)
end

eval(_main())  # the failure occurs inside single_train_step!
```
This fails with:

```
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.
If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] errorscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:151
[3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:124
[4] assertscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:112
[5] getindex
@ ~/.julia/packages/GPUArrays/uiVyU/src/host/indexing.jl:50 [inlined]
[6] getindex
@ ~/.julia/dev/OneHotArrays/src/array.jl:67 [inlined]
[7] getindex
@ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/adjtrans.jl:334 [inlined]
[8] _generic_matmatmul!(C::CuArray{…}, A::CuArray{…}, B::LinearAlgebra.Transpose{…}, _add::LinearAlgebra.MulAddMul{…})
@ LinearAlgebra ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:894
[9] generic_matmatmul!
@ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:868 [inlined]
[10] _mul!
@ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
[11] mul!
@ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
[12] mul!
@ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
[13] matmul!
@ ~/.julia/packages/LuxLib/Kj0os/src/impl/matmul.jl:119 [inlined]
[14] matmul!
@ ~/.julia/packages/LuxLib/Kj0os/src/impl/matmul.jl:112 [inlined]
[15] matmul
@ ~/.julia/packages/LuxLib/Kj0os/src/impl/matmul.jl:57 [inlined]
[16] matmul
@ ~/.julia/packages/LuxLib/Kj0os/src/impl/matmul.jl:51 [inlined]
[17] ∇matmul_bias
@ ~/.julia/packages/LuxLib/Kj0os/src/impl/dense.jl:287 [inlined]
[18] ∇matmul_bias(∂y::CuArray{…}, weight::CuArray{…}, x::OneHotMatrix{…}, bias::CuArray{…})
@ LuxLib.Impl ~/.julia/packages/LuxLib/Kj0os/src/impl/dense.jl:286
[19] (::LuxLib.Impl.var"#96#99"{…})(Δ::CuArray{…})
@ LuxLib.Impl ~/.julia/packages/LuxLib/Kj0os/src/impl/dense.jl:87
[20] ZBack
@ ~/.julia/packages/Zygote/wfLOG/src/compiler/chainrules.jl:222 [inlined]
[21] fused_dense
@ ~/.julia/packages/LuxLib/Kj0os/src/impl/dense.jl:16 [inlined]
[22] fused_dense_bias_activation
@ ~/.julia/packages/LuxLib/Kj0os/src/api/dense.jl:36 [inlined]
[23] Dense
@ ~/.julia/packages/Lux/L2VO7/src/layers/basic.jl:357 [inlined]
[24] apply
@ ~/.julia/packages/LuxCore/Av7WJ/src/LuxCore.jl:155 [inlined]
[25] applychain
@ ~/.julia/packages/Lux/L2VO7/src/layers/containers.jl:0 [inlined]
[26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRulesCore.Thunk{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[27] Chain
@ ~/.julia/packages/Lux/L2VO7/src/layers/containers.jl:509 [inlined]
[28] AbstractLossFunction
@ ~/.julia/packages/Lux/L2VO7/src/helpers/losses.jl:190 [inlined]
[29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float32, Nothing, Nothing})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[30] #1
@ ~/.julia/packages/Lux/L2VO7/ext/LuxZygoteExt/training.jl:8 [inlined]
[31] (::Zygote.Pullback{Tuple{LuxZygoteExt.var"#1#2"{…}, @NamedTuple{…}}, Any})(Δ::Tuple{Float32, Nothing, Nothing})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[32] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Any}})(Δ::Tuple{Float32, Nothing, Nothing})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:97
[33] compute_gradients_impl(::AutoZygote, objective_function::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ LuxZygoteExt ~/.julia/packages/Lux/L2VO7/ext/LuxZygoteExt/training.jl:10
[34] compute_gradients
@ ~/.julia/packages/Lux/L2VO7/src/helpers/training.jl:202 [inlined]
[35] single_train_step_impl!(backend::AutoZygote, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ Lux.Training ~/.julia/packages/Lux/L2VO7/src/helpers/training.jl:328
[36] #single_train_step!#6
@ ~/.julia/packages/Lux/L2VO7/src/helpers/training.jl:294 [inlined]
[37] single_train_step!(backend::AutoZygote, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ Lux.Training ~/.julia/packages/Lux/L2VO7/src/helpers/training.jl:288
[38] top-level scope
@ ~/src/scrap2.jl:28
[39] eval
@ ./boot.jl:430 [inlined]
[40] eval(x::Expr)
@ Main ./sysimg.jl:48
[41] top-level scope
@ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
Frame `[8]` in the stack trace is a matmul between a `CuArray` and a `Transpose` of a `OneHotArray` (I didn't dump the full output because it's huge, but I have verified this). So this clearly happens somewhere in reverse-mode differentiation.
I will see if I can reproduce this a bit more directly (a rough sketch of what I expect that to look like is below), but I wanted to keep the Lux example up because I think we should get all of those use cases working.
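For reference, here is an untested sketch of the direct reproduction I have in mind. It assumes the `Dense` pullback boils down to `∂W = ∂y * xᵀ` with `x` the one-hot input batch (consistent with the `∇matmul_bias` frames above); the shapes are made up to match the first `Dense(4 => 8)` layer, and using `cu` to move the `OneHotMatrix`'s underlying indices to the GPU is an assumption based on the `getindex` frames.

```julia
# Untested sketch: trigger the CuArray × Transpose{…, OneHotMatrix} product
# directly, outside of Lux/Zygote. Shapes match the first Dense(4 => 8) layer.
using CUDA, OneHotArrays
using LinearAlgebra

∂y = CUDA.rand(Float32, 8, 32)            # cotangent of the layer output
x  = cu(onehotbatch(rand(1:4, 32), 1:4))  # OneHotMatrix backed by GPU indices

# Expectation: no specialized method covers this mixed product, so it falls
# through to LinearAlgebra's _generic_matmatmul! and hits scalar indexing:
∂W = ∂y * transpose(x)
```

If that does reproduce it, the fix presumably lives in a specialized `*`/`mul!` method for this combination (or in materializing the one-hot matrix before the matmul), but I haven't verified either yet.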