@@ -10,274 +10,8 @@ using NNlib
export onehot, onehotbatch, onecold, OneHotArray,
  OneHotVector, OneHotMatrix, OneHotLike

- """
-     OneHotArray{T,L,N,M,I} <: AbstractArray{Bool,M}
-
- These are constructed by [`onehot`](@ref) and [`onehotbatch`](@ref).
- Parameter `I` is the type of the underlying storage, and `T` its eltype.
- """
- struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"}
-   indices::I
- end
- OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices)
- OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, 1, T}(indices)
- OneHotArray(indices::I, L::Integer) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, L, N, N+1, I}(indices)
-
- _indices(x::OneHotArray) = x.indices
- _indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) =
-   reshape(parent(x).indices, x.dims[2:end])
-
- const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
- const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}
-
- @doc @doc(OneHotArray)
- OneHotVector(idx, L) = OneHotArray(idx, L)
- @doc @doc(OneHotArray)
- OneHotMatrix(indices, L) = OneHotArray(indices, L)
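The point of the struct above is that only the integer index (or array of indices) is stored, while the `Bool`-array behaviour is synthesized on demand. A rough sketch of what that means at the REPL, assuming the definitions above are loaded (output abridged):

```julia
julia> v = OneHotVector(2, 3)   # stores just the integer 2, behaves as a 3-vector of Bool
3-element OneHotVector(::Int64) with eltype Bool:
 ⋅
 1
 ⋅

julia> v.indices                # the single stored field: one machine integer, not three Bools
2
```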
-
- # use this type so reshaped arrays hit fast paths
- # e.g. argmax
- const OneHotLike{T, L, N, var"N+1", I} =
-   Union{OneHotArray{T, L, N, var"N+1", I},
-         Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}
-
- _isonehot(x::OneHotArray) = true
- _isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
-
- Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)
-
- _onehotindex(x, i) = (x == i)
-
- Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
- Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x
-
- Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i)
- Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L)
- Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
- Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]
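Note which `getindex` methods keep the wrapper: a full colon in the first (label) dimension only slices the stored indices, while fixing the first index has to materialise `Bool`s. A hedged sketch of the two cases (assumed session):

```julia
julia> m = OneHotMatrix([1, 2, 3], 3);

julia> m[:, 2:3]   # leading colon: still one-hot, only the index vector is sliced
3×2 OneHotMatrix(::Vector{Int64}) with eltype Bool:
 ⋅  ⋅
 1  ⋅
 ⋅  1

julia> m[2, :]     # fixed first index: broadcasts _onehotindex, yielding plain Bools
3-element BitVector:
 0
 1
 0
```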
-
- function Base.showarg(io::IO, x::OneHotArray, toplevel)
-   print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(")
-   Base.showarg(io, x.indices, false)
-   print(io, ')')
-   toplevel && print(io, " with eltype Bool")
-   return nothing
- end
-
- # this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots:
- function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString)
-   x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
- end
-
- # copy CuArray versions back before trying to print them:
- Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
-   Base.print_array(io, adapt(Array, X))
- Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} =
-   Base.print_array(io, adapt(Array, X))
-
- _onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
- _onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
-
- function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L
-   if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
-     return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
-   else
-     return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
-   end
- end
-
- Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
- Base.vcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 1)
-
- # optimized concatenation for matrices and vectors of same parameters
- Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} =
-   OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
- Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} =
-   OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
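The branch in `cat` is the key design point: concatenating along the label dimension (`dims = 1`) cannot stay one-hot, so it falls back to dense `Bool` storage; along any other dimension only the stored indices are concatenated. A sketch of both paths (assumed session):

```julia
julia> a, b = OneHotVector(1, 3), OneHotVector(3, 3);

julia> hcat(a, b)   # dims = 2: index vectors concatenate, wrapper is kept
3×2 OneHotMatrix(::Vector{Int64}) with eltype Bool:
 1  ⋅
 ⋅  ⋅
 ⋅  1

julia> vcat(a, b)   # dims = 1: result is no longer one-hot, falls back to dense Bools
6-element Vector{Bool}:
 1
 0
 0
 0
 0
 1
```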
-
- MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatrix(_indices.(xs), L)
-
- Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
-
- Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()
-
- Base.map(f, x::OneHotLike) = Base.broadcast(f, x)
-
- Base.argmax(x::OneHotLike; dims = Colon()) =
-   (_isonehot(x) && dims == 1) ?
-     reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
-     invoke(argmax, Tuple{AbstractArray}, x; dims = dims)
-
- """
-     onehot(x, labels, [default])
-
- Return a `OneHotVector` which is roughly a sparse representation of `x .== labels`.
-
- Instead of storing, say, a `Vector{Bool}`, it stores the index of the first occurrence
- of `x` in `labels`. If `x` is not found in `labels`, it either returns `onehot(default, labels)`,
- or gives an error if no default is given.
-
- See also [`onehotbatch`](@ref) to apply this to many `x`s,
- and [`onecold`](@ref) to reverse either of these, as well as to generalise `argmax`.
-
- # Examples
- ```jldoctest
- julia> β = Flux.onehot(:b, (:a, :b, :c))
- 3-element OneHotVector(::UInt32) with eltype Bool:
-  ⋅
-  1
-  ⋅
-
- julia> αβγ = (Flux.onehot(0, 0:2), β, Flux.onehot(:z, [:a, :b, :c], :c)) # uses default
- (Bool[1, 0, 0], Bool[0, 1, 0], Bool[0, 0, 1])
-
- julia> hcat(αβγ...) # preserves sparsity
- 3×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
-  1  ⋅  ⋅
-  ⋅  1  ⋅
-  ⋅  ⋅  1
- ```
- """
- function onehot(x, labels)
-   i = _findval(x, labels)
-   isnothing(i) && error("Value $x is not in labels")
-   OneHotVector{UInt32, length(labels)}(i)
- end
-
- function onehot(x, labels, default)
-   i = _findval(x, labels)
-   isnothing(i) && return onehot(default, labels)
-   OneHotVector{UInt32, length(labels)}(i)
- end
-
- _findval(val, labels) = findfirst(isequal(val), labels)
- # Fast unrolled method for tuples:
- function _findval(val, labels::Tuple, i::Integer = 1)
-   ifelse(isequal(val, first(labels)), i, _findval(val, Base.tail(labels), i + 1))
- end
- _findval(val, labels::Tuple{}, i::Integer) = nothing
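A note on the tuple method of `_findval`: `Base.tail` shrinks the tuple type at each call, so for a fixed-length tuple the compiler can unroll the whole recursion, and `ifelse` keeps it branch-free since both arms are side-effect-free. A hypothetical hand-expansion for three labels (not code from this file):

```julia
julia> val = :b;

julia> ifelse(isequal(val, :a), 1,          # what _findval(val, (:a, :b, :c)) unrolls to
           ifelse(isequal(val, :b), 2,
               ifelse(isequal(val, :c), 3, nothing)))
2
```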
-
- """
-     onehotbatch(xs, labels, [default])
-
- Returns a `OneHotMatrix` where the `k`th column of the matrix is [`onehot(xs[k], labels)`](@ref onehot).
- This is a sparse matrix, which stores just a `Vector{UInt32}` containing the indices of the
- nonzero elements.
-
- If one of the inputs in `xs` is not found in `labels`, that column is `onehot(default, labels)`
- if `default` is given; otherwise an error is thrown.
-
- If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an
- `AbstractArray{Bool, M+1}` which is one-hot along the first dimension,
- i.e. `result[:, k...] == onehot(xs[k...], labels)`.
-
- Note that `xs` can be any iterable, such as a string, and that using a tuple
- for `labels` will often speed up construction, certainly for fewer than 32 classes.
-
- # Examples
- ```jldoctest
- julia> oh = Flux.onehotbatch("abracadabra", 'a':'e', 'e')
- 5×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
-  1  ⋅  ⋅  1  ⋅  1  ⋅  1  ⋅  ⋅  1
-  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅
-  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
-  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅
-  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅
-
- julia> reshape(1:15, 3, 5) * oh  # this matrix multiplication is done efficiently
- 3×11 Matrix{Int64}:
-  1  4  13  1  7  1  10  1  4  13  1
-  2  5  14  2  8  2  11  2  5  14  2
-  3  6  15  3  9  3  12  3  6  15  3
- ```
- """
- onehotbatch(ls, labels, default...) = _onehotbatch(ls, length(labels) < 32 ? Tuple(labels) : labels, default...)
- # NB: function barrier
- _onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l in ls])
-
- """
-     onecold(y::AbstractArray, labels = 1:size(y, 1))
-
- Roughly the inverse operation of [`onehot`](@ref) or [`onehotbatch`](@ref):
- this finds the index of the largest element of `y`, or of each column of `y`,
- and looks them up in `labels`.
-
- If `labels` are not specified, the default is the integers `1:size(y, 1)` --
- the same operation as `argmax(y, dims = 1)`, but sometimes with a different return type.
-
- # Examples
- ```jldoctest
- julia> Flux.onecold([false, true, false])
- 2
-
- julia> Flux.onecold([0.3, 0.2, 0.5], (:a, :b, :c))
- :c
-
- julia> Flux.onecold([ 1 0 0 1 0 1 0 1 0 0 1
-                       0 1 0 0 0 0 0 0 1 0 0
-                       0 0 0 0 1 0 0 0 0 0 0
-                       0 0 0 0 0 0 1 0 0 0 0
-                       0 0 1 0 0 0 0 0 0 1 0 ], 'a':'e') |> String
- "abeacadabea"
- ```
- """
- onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)]
- function onecold(y::AbstractArray, labels = 1:size(y, 1))
-   indices = _fast_argmax(y)
-   xs = isbits(labels) ? indices : collect(indices)  # non-bits labels cannot be handled by CUDA
-
-   return map(xi -> labels[xi[1]], xs)
- end
-
- _fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
- function _fast_argmax(x::OneHotLike)
-   if _isonehot(x)
-     return _indices(x)
-   else
-     return _fast_argmax(convert(_onehot_bool_type(x), x))
-   end
- end
-
- ChainRulesCore.@non_differentiable onehot(::Any...)
- ChainRulesCore.@non_differentiable onehotbatch(::Any...)
- ChainRulesCore.@non_differentiable onecold(::Any...)
-
- ChainRulesCore.@non_differentiable (::Type{<:OneHotArray})(indices::Any, L::Integer)
-
- function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
-   _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
-   size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
-   return A[:, onecold(B)]
- end
-
- function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 1}) where L
-   _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
-   size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
-   return NNlib.gather(A, _indices(B))
- end
-
- function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
-   B_dim = length(_indices(parent(B)))
-   size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
-   return NNlib.scatter(+, A, _indices(parent(B)), dstsize = (size(A, 1), size(B, 2)))
- end
-
- for wrapper in [:Adjoint, :Transpose]
-   @eval begin
-     function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{<:Any, L}) where {L, T}
-       size(A, 2) == L ||
-         throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
-
-       return A[:, onecold(b)]
-     end
-
-     function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector{<:Any, L}) where {L, T}
-       size(A, 2) == L ||
-         throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
-
-       return A[onecold(b)]
-     end
-   end
- end
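These overloads replace dense multiplication with indexing: `A * B` for a one-hot `B` is a column lookup (via `NNlib.gather` when `B` stores a vector of indices), and `A * B'` scatters columns of `A` onto their label positions with `NNlib.scatter(+, ...)`. A rough sketch of both fast paths, under an assumed session:

```julia
julia> A = [1.0 3.0 5.0; 2.0 4.0 6.0];

julia> A * OneHotMatrix([2, 1], 3)    # gather: picks columns 2 and 1, no arithmetic
2×2 Matrix{Float64}:
 3.0  1.0
 4.0  2.0

julia> B = OneHotMatrix([2, 2], 3);

julia> [1.0 10.0; 2.0 20.0] * B'      # scatter(+): both columns accumulate into label 2
2×3 Matrix{Float64}:
 0.0  11.0  0.0
 0.0  22.0  0.0
```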
+ include("array.jl")
+ include("onehot.jl")
+ include("linalg.jl")

end