Skip to content

Commit 0f2cca1

Browse files
authored
Merge pull request #5 from TLipede/file-rearrange
Moving code into separate files
2 parents 8b85adb + df7f974 commit 0f2cca1

File tree

8 files changed

+468
-460
lines changed

8 files changed

+468
-460
lines changed

src/OneHotArrays.jl

Lines changed: 3 additions & 269 deletions
Original file line numberDiff line numberDiff line change
@@ -10,274 +10,8 @@ using NNlib
1010
export onehot, onehotbatch, onecold, OneHotArray,
1111
OneHotVector, OneHotMatrix, OneHotLike
1212

13-
"""
14-
OneHotArray{T,L,N,M,I} <: AbstractArray{Bool,M}
15-
16-
These are constructed by [`onehot`](@ref) and [`onehotbatch`](@ref).
17-
Parameter `I` is the type of the underlying storage, and `T` its eltype.
18-
"""
19-
struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"}
20-
indices::I
21-
end
22-
OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices)
23-
OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, 1, T}(indices)
24-
OneHotArray(indices::I, L::Integer) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, L, N, N+1, I}(indices)
25-
26-
_indices(x::OneHotArray) = x.indices
27-
_indices(x::Base.ReshapedArray{<: Any, <: Any, <: OneHotArray}) =
28-
reshape(parent(x).indices, x.dims[2:end])
29-
30-
const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
31-
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}
32-
33-
@doc @doc(OneHotArray)
34-
OneHotVector(idx, L) = OneHotArray(idx, L)
35-
@doc @doc(OneHotArray)
36-
OneHotMatrix(indices, L) = OneHotArray(indices, L)
37-
38-
# use this type so reshaped arrays hit fast paths
39-
# e.g. argmax
40-
const OneHotLike{T, L, N, var"N+1", I} =
41-
Union{OneHotArray{T, L, N, var"N+1", I},
42-
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}
43-
44-
_isonehot(x::OneHotArray) = true
45-
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
46-
47-
Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)
48-
49-
_onehotindex(x, i) = (x == i)
50-
51-
Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
52-
Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x
53-
54-
Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i)
55-
Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L)
56-
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
57-
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]
58-
59-
function Base.showarg(io::IO, x::OneHotArray, toplevel)
60-
print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(")
61-
Base.showarg(io, x.indices, false)
62-
print(io, ')')
63-
toplevel && print(io, " with eltype Bool")
64-
return nothing
65-
end
66-
67-
# this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots:
68-
function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString)
69-
x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
70-
end
71-
72-
# copy CuArray versions back before trying to print them:
73-
Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
74-
Base.print_array(io, adapt(Array, X))
75-
Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} =
76-
Base.print_array(io, adapt(Array, X))
77-
78-
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
79-
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
80-
81-
function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L
82-
if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
83-
return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
84-
else
85-
return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
86-
end
87-
end
88-
89-
Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
90-
Base.vcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 1)
91-
92-
# optimized concatenation for matrices and vectors of same parameters
93-
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} =
94-
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
95-
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} =
96-
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
97-
98-
MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatrix(_indices.(xs), L)
99-
100-
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
101-
102-
Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}()
103-
104-
Base.map(f, x::OneHotLike) = Base.broadcast(f, x)
105-
106-
Base.argmax(x::OneHotLike; dims = Colon()) =
107-
(_isonehot(x) && dims == 1) ?
108-
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
109-
invoke(argmax, Tuple{AbstractArray}, x; dims = dims)
110-
111-
"""
112-
onehot(x, labels, [default])
113-
114-
Return a `OneHotVector` which is roughly a sparse representation of `x .== labels`.
115-
116-
Instead of storing say `Vector{Bool}`, it stores the index of the first occurrence
117-
of `x` in `labels`. If `x` is not found in labels, then it either returns `onehot(default, labels)`,
118-
or gives an error if no default is given.
119-
120-
See also [`onehotbatch`](@ref) to apply this to many `x`s,
121-
and [`onecold`](@ref) to reverse either of these, as well as to generalise `argmax`.
122-
123-
# Examples
124-
```jldoctest
125-
julia> β = Flux.onehot(:b, (:a, :b, :c))
126-
3-element OneHotVector(::UInt32) with eltype Bool:
127-
128-
1
129-
130-
131-
julia> αβγ = (Flux.onehot(0, 0:2), β, Flux.onehot(:z, [:a, :b, :c], :c)) # uses default
132-
(Bool[1, 0, 0], Bool[0, 1, 0], Bool[0, 0, 1])
133-
134-
julia> hcat(αβγ...) # preserves sparsity
135-
3×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
136-
1 ⋅ ⋅
137-
⋅ 1 ⋅
138-
⋅ ⋅ 1
139-
```
140-
"""
141-
function onehot(x, labels)
142-
i = _findval(x, labels)
143-
isnothing(i) && error("Value $x is not in labels")
144-
OneHotVector{UInt32, length(labels)}(i)
145-
end
146-
147-
function onehot(x, labels, default)
148-
i = _findval(x, labels)
149-
isnothing(i) && return onehot(default, labels)
150-
OneHotVector{UInt32, length(labels)}(i)
151-
end
152-
153-
_findval(val, labels) = findfirst(isequal(val), labels)
154-
# Fast unrolled method for tuples:
155-
function _findval(val, labels::Tuple, i::Integer=1)
156-
ifelse(isequal(val, first(labels)), i, _findval(val, Base.tail(labels), i+1))
157-
end
158-
_findval(val, labels::Tuple{}, i::Integer) = nothing
159-
160-
"""
161-
onehotbatch(xs, labels, [default])
162-
163-
Returns a `OneHotMatrix` where `k`th column of the matrix is [`onehot(xs[k], labels)`](@ref onehot).
164-
This is a sparse matrix, which stores just a `Vector{UInt32}` containing the indices of the
165-
nonzero elements.
166-
167-
If one of the inputs in `xs` is not found in `labels`, that column is `onehot(default, labels)`
168-
if `default` is given, else an error.
169-
170-
If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an
171-
`AbstractArray{Bool, M+1}` which is one-hot along the first dimension,
172-
i.e. `result[:, k...] == onehot(xs[k...], labels)`.
173-
174-
Note that `xs` can be any iterable, such as a string. And that using a tuple
175-
for `labels` will often speed up construction, certainly for less than 32 classes.
176-
177-
# Examples
178-
```jldoctest
179-
julia> oh = Flux.onehotbatch("abracadabra", 'a':'e', 'e')
180-
5×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
181-
1 ⋅ ⋅ 1 ⋅ 1 ⋅ 1 ⋅ ⋅ 1
182-
⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅
183-
⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
184-
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅
185-
⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅
186-
187-
julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently
188-
3×11 Matrix{Int64}:
189-
1 4 13 1 7 1 10 1 4 13 1
190-
2 5 14 2 8 2 11 2 5 14 2
191-
3 6 15 3 9 3 12 3 6 15 3
192-
```
193-
"""
194-
onehotbatch(ls, labels, default...) = _onehotbatch(ls, length(labels) < 32 ? Tuple(labels) : labels, default...)
195-
# NB function barier:
196-
_onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l in ls])
197-
198-
"""
199-
onecold(y::AbstractArray, labels = 1:size(y,1))
200-
201-
Roughly the inverse operation of [`onehot`](@ref) or [`onehotbatch`](@ref):
202-
This finds the index of the largest element of `y`, or each column of `y`,
203-
and looks them up in `labels`.
204-
205-
If `labels` are not specified, the default is integers `1:size(y,1)` --
206-
the same operation as `argmax(y, dims=1)` but sometimes a different return type.
207-
208-
# Examples
209-
```jldoctest
210-
julia> Flux.onecold([false, true, false])
211-
2
212-
213-
julia> Flux.onecold([0.3, 0.2, 0.5], (:a, :b, :c))
214-
:c
215-
216-
julia> Flux.onecold([ 1 0 0 1 0 1 0 1 0 0 1
217-
0 1 0 0 0 0 0 0 1 0 0
218-
0 0 0 0 1 0 0 0 0 0 0
219-
0 0 0 0 0 0 1 0 0 0 0
220-
0 0 1 0 0 0 0 0 0 1 0 ], 'a':'e') |> String
221-
"abeacadabea"
222-
```
223-
"""
224-
onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)]
225-
function onecold(y::AbstractArray, labels = 1:size(y, 1))
226-
indices = _fast_argmax(y)
227-
xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA
228-
229-
return map(xi -> labels[xi[1]], xs)
230-
end
231-
232-
_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
233-
function _fast_argmax(x::OneHotLike)
234-
if _isonehot(x)
235-
return _indices(x)
236-
else
237-
return _fast_argmax(convert(_onehot_bool_type(x), x))
238-
end
239-
end
240-
241-
ChainRulesCore.@non_differentiable onehot(::Any...)
242-
ChainRulesCore.@non_differentiable onehotbatch(::Any...)
243-
ChainRulesCore.@non_differentiable onecold(::Any...)
244-
245-
ChainRulesCore.@non_differentiable (::Type{<:OneHotArray})(indices::Any, L::Integer)
246-
247-
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
248-
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
249-
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
250-
return A[:, onecold(B)]
251-
end
252-
253-
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 1}) where L
254-
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
255-
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
256-
return NNlib.gather(A, _indices(B))
257-
end
258-
259-
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
260-
B_dim = length(_indices(parent(B)))
261-
size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
262-
return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2)))
263-
end
264-
265-
for wrapper in [:Adjoint, :Transpose]
266-
@eval begin
267-
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{<:Any, L}) where {L, T}
268-
size(A, 2) == L ||
269-
throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
270-
271-
return A[:, onecold(b)]
272-
end
273-
274-
function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector{<:Any, L}) where {L, T}
275-
size(A, 2) == L ||
276-
throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
277-
278-
return A[onecold(b)]
279-
end
280-
end
281-
end
13+
include("array.jl")
14+
include("onehot.jl")
15+
include("linalg.jl")
28216

28317
end

src/array.jl

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
2+
OneHotArray{T,L,N,M,I} <: AbstractArray{Bool,M}
3+
4+
These are constructed by [`onehot`](@ref) and [`onehotbatch`](@ref).
5+
Parameter `I` is the type of the underlying storage, and `T` its eltype.
6+
"""
7+
struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"}
8+
indices::I
9+
end
10+
OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices)
11+
OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, 1, T}(indices)
12+
OneHotArray(indices::I, L::Integer) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, L, N, N+1, I}(indices)
13+
14+
_indices(x::OneHotArray) = x.indices
15+
_indices(x::Base.ReshapedArray{<: Any, <: Any, <: OneHotArray}) =
16+
reshape(parent(x).indices, x.dims[2:end])
17+
18+
const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
19+
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}
20+
21+
@doc @doc(OneHotArray)
22+
OneHotVector(idx, L) = OneHotArray(idx, L)
23+
@doc @doc(OneHotArray)
24+
OneHotMatrix(indices, L) = OneHotArray(indices, L)
25+
26+
# use this type so reshaped arrays hit fast paths
27+
# e.g. argmax
28+
const OneHotLike{T, L, N, var"N+1", I} =
29+
Union{OneHotArray{T, L, N, var"N+1", I},
30+
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}
31+
32+
_isonehot(x::OneHotArray) = true
33+
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
34+
35+
Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)
36+
37+
_onehotindex(x, i) = (x == i)
38+
39+
Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
40+
Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x
41+
42+
Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i)
43+
Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L)
44+
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
45+
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]
46+
47+
function Base.showarg(io::IO, x::OneHotArray, toplevel)
48+
print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(")
49+
Base.showarg(io, x.indices, false)
50+
print(io, ')')
51+
toplevel && print(io, " with eltype Bool")
52+
return nothing
53+
end
54+
55+
# this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots:
56+
function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString)
57+
x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
58+
end
59+
60+
# copy CuArray versions back before trying to print them:
61+
Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
62+
Base.print_array(io, adapt(Array, X))
63+
Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} =
64+
Base.print_array(io, adapt(Array, X))
65+
66+
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
67+
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
68+
69+
function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L
70+
if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
71+
return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
72+
else
73+
return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
74+
end
75+
end
76+
77+
Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
78+
Base.vcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 1)
79+
80+
# optimized concatenation for matrices and vectors of same parameters
81+
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} =
82+
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
83+
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} =
84+
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
85+
86+
MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatrix(_indices.(xs), L)
87+
88+
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
89+
90+
Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}()
91+
92+
Base.map(f, x::OneHotLike) = Base.broadcast(f, x)
93+
94+
Base.argmax(x::OneHotLike; dims = Colon()) =
95+
(_isonehot(x) && dims == 1) ?
96+
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
97+
invoke(argmax, Tuple{AbstractArray}, x; dims = dims)

0 commit comments

Comments
 (0)