Skip to content
69 changes: 50 additions & 19 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ end
# Base case of the label search: an empty tuple of labels cannot contain `val`,
# so the lookup yields `nothing` whatever the position counter `i` is.
function _findval(val, labels::Tuple{}, i::Integer)
  return nothing
end

"""
onehotbatch(xs, labels, [default])
onehotbatch(xs, labels, [default]; dims::Val{D}=Val{1})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
onehotbatch(xs, labels, [default]; dims::Val{D}=Val{1})
onehotbatch(xs, labels, [default]; dims = Val(1))


Returns a [`OneHotMatrix`](@ref) where the `k`th column of the matrix is [`onehot(xs[k], labels)`](@ref onehot).
This is a sparse matrix, which stores just a `Vector{UInt32}` containing the indices of the
Expand All @@ -64,6 +64,8 @@ i.e. `result[:, k...] == onehot(xs[k...], labels)`.
Note that `xs` can be any iterable, such as a string. And that using a tuple
for `labels` will often speed up construction, certainly for less than 32 classes.

If dims keyword is given, the onehot vectors lie on the [dims] dimension rather than the first one.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If dims keyword is given, the onehot vectors lie on the [dims] dimension rather than the first one.
If the `dims` keyword is given, the onehot vectors lie on the `dims` dimension rather than the first one. `dims` should be provided as a `Val` to guarantee type stability (but a plain integer is valid as well).


# Examples
```jldoctest
julia> oh = onehotbatch("abracadabra", 'a':'e', 'e')
Expand All @@ -79,44 +81,73 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
1 4 13 1 7 1 10 1 4 13 1
2 5 14 2 8 2 11 2 5 14 2
3 6 15 3 9 3 12 3 6 15 3

# One hot vectors on the second axis
julia> onehotbatch([0, 0, 7], 0:9; dims=Val(2))
3×10 PermutedDimsArray(OneHotMatrix(::Vector{UInt32}), (2, 1)) with eltype Bool:
1 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 1 0 0
```
"""
function onehotbatch(data, labels, default...)
  # Label sets with fewer than 32 entries are converted to a `Tuple` before the
  # lookup; larger ones are passed through unchanged.
  lookup = length(labels) < 32 ? Tuple(labels) : labels
  return _onehotbatch(data, lookup, default...)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suggestion for how to write this would be this. Add the keyword dims but leave the basic path as close to untouched as you can, like so:

onehotbatch(data, labels, default...; dims=Val(1)) = _onehotbatch(dims, data, length(labels) < 32 ? Tuple(labels) : labels, default...)

function _onehotbatch(::Val{1}, data, labels)
  # as before
   return OneHotArray(indices, length(labels))
end
function _onehotbatch(::Val{1}, data, labels, default)
  # as before
  return OneHotArray(indices, length(labels))
end

In particular, this does not call collect(data): that shouldn't be necessary, since we can just iterate over the data.

Readers uninterested in permutations can stop there. But to handle them, make it obvious that we call the same path, and then permute it.

_onehotbatch(dims::Integer, data, labels, default...) = _onehotbatch(Val(dims), data, labels, default...)
_onehotbatch(dims::Val, data, labels, default...) = _permute(dims, _onehotbatch(Val(1), data, labels, default...))

_permute(::Val{2}, array::OneHotArray{<:Any, 1, 2}) = transpose(array)
function _permute(::Val{d}, array::OneHotArray{<:Any, N,M}) where {d, N, M}
  # this is where you compute perm, can use N or M, I forget...
  PermutedDimsArray(array, perm)
end

I made a special case for transpose, as I think that's always preferable to PermutedDimsArray.


# Build a `OneHotArray` from `data` without a default label.
#
# Each element is mapped to its position in `labels` via `_findval`; elements that
# are not found are provisionally recorded as 0. Only when such a 0 is present is
# `data` walked a second time, to raise an error naming the first missing value.
function _onehotbatch(data, labels)
  positions = UInt32[something(_findval(item, labels), 0) for item in data]
  if any(iszero, positions)
    for x in data
      if isnothing(_findval(x, labels))
        error("Value $x not found in labels")
      end
    end
  end
  return OneHotArray(positions, length(labels))
end
# Developer note:
# `onehotbatch` is the public API and includes bounds checks.
# `_onehotbatch` is the implementation and includes membership checks.
# `_onehotbatch_fast` is the same as `_onehotbatch` but without the membership checks, which would be slow on GPU.

# NOTE(review): lines from both sides of the diff are interleaved here. The first
# header and its four body lines are the previous three-argument `_onehotbatch`;
# the `onehotbatch(data::String, ...)` method that follows is added code, and the
# single `end` below appears to be shared by the two in this rendering.
#
# Previous `_onehotbatch(data, labels, default)`: looks up `default` in `labels`
# (erroring when it is absent) and uses its index for every element of `data`
# that `_findval` cannot locate.
function _onehotbatch(data, labels, default)
default_index = _findval(default, labels)
isnothing(default_index) && error("Default value $default is not in labels")
indices = UInt32[something(_findval(i, labels), default_index) for i in data]
return OneHotArray(indices, length(labels))
# Added `onehotbatch` method for `String` data: forwards `dims` and any `default`
# to `_onehotbatch`, converting `labels` to a `Tuple` when it has fewer than 32 entries.
# NOTE(review): the `dims::Val{D}` annotation only accepts a `Val`, so a plain
# integer `dims` is rejected by this signature.
function onehotbatch(data::String, labels, default...; dims::Val{D} = Val(1)) where D
_onehotbatch(dims, data, length(labels) < 32 ? Tuple(labels) : labels, default...)
end

# NOTE(review): the previous and the new signature line of this method are shown
# back to back, as are the previous and new final expressions
# (`return OneHotArray(...)` followed by the `_onehotbatch(dims, ...)` call).
#
# Path for integer data with a unit-range `labels`: validates the data with a
# single `extrema` pass, then shifts each value so that `first(labels)` maps to 1.
function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}, default...; dims::Val{D} = Val(1)) where D
# NOTE(review): `default...` is accepted by the new signature but not used below — confirm intent.
lo, hi = extrema(data)
lo < first(labels) && error("Value $lo not found in labels")
hi > last(labels) && error("Value $hi not found in labels")
offset = 1 - first(labels)
indices = UInt32.(data .+ offset)
return OneHotArray(indices, length(labels))
# NOTE(review): `indices` already holds 1-based positions, but the `_onehotbatch(::Val{1}, ...)`
# methods look each element up in `labels` with `_findval`; for ranges that do not start
# at 1 this looks like a second, unwanted mapping. The GPU method hands its indices to
# `_onehotbatch_fast`, which wraps them directly — confirm which is intended here.
_onehotbatch(dims, indices, length(labels) < 32 ? Tuple(labels) : labels)
end

# That bounds check with extrema synchronises on GPU, much slower than rest of the function,
# hence add a special method, with a less helpful error message:
# NOTE(review): the previous and the new signature line of this method are shown
# back to back.
#
# GPU variant for integer data with a unit-range `labels`.
function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer}, default...; dims::Val{D} = Val(1)) where D
# NOTE(review): `default...` is accepted by the new signature but not used below — confirm intent.
offset = 1 - first(labels)
# Shift each element so that `first(labels)` maps to 1, bounds-checking every element
# inside the `map` rather than with a separate pass over the data.
indices = map(data) do datum
i = UInt32(datum + offset)
checkbounds(labels, i)
i
end
# `_onehotbatch_fast` wraps the indices without a membership check and applies `dims`.
_onehotbatch_fast(dims, indices, length(labels) < 32 ? Tuple(labels) : labels)
end
# `_onehotbatch_fast` takes ready-made indices and skips the membership checks done by
# `_onehotbatch` (they would slow down the GPU path), while still supporting `dims`.

# Generic `dims`: build the array with the one-hot axis first, then permute it.
function _onehotbatch_fast(dims::Val{D}, indices, labels) where {D}
  firstdim = _onehotbatch_fast(Val(1), indices, labels)
  return _permute(dims, firstdim)
end

# `dims = Val(1)`: wrap the indices directly, with `length(labels)` classes.
function _onehotbatch_fast(::Val{1}, indices, labels)
  return OneHotArray(indices, length(labels))
end

# Generic `dims`: build the array through the `Val(1)` method (one-hot axis first),
# then move that axis to the requested dimension with `_permute`.
function _onehotbatch(dims::Val, data, labels, default...)
  firstdim = _onehotbatch(Val(1), data, labels, default...)
  return _permute(dims, firstdim)
end

# Special case for `dims = Val(2)` on a `OneHotArray{<:Any, 1, 2}`: a lazy `transpose`
# is used instead of a `PermutedDimsArray`.
function _permute(::Val{2}, array::OneHotArray{<:Any, 1, 2})
  return transpose(array)
end
# Move the one-hot axis of `array` from dimension 1 to dimension `d` by lazily swapping
# those two dimensions; every other dimension keeps its position.
#
# `d` arrives as a `Val` so that the permutation is a compile-time constant. `M` is the
# third type parameter of `OneHotArray`, used here as the number of dimensions to permute
# (presumably `ndims(array)` — confirm against the type definition).
#
# Throws an `ArgumentError` when `d` is not an integer in `1:M`.
function _permute(::Val{d}, array::OneHotArray{<:Any, N, M}) where {d, N, M}
  (d isa Integer && 1 <= d <= M) ||
    throw(ArgumentError("dims must be an integer between 1 and $M, got $d"))
  target = Int(d)
  # Position 1 and position `target` trade places; all other positions map to themselves.
  perm = ntuple(i -> i == target ? 1 : (i == 1 ? target : i), Val(M))
  # Swapping two dimensions is its own inverse, so the inverse permutation equals `perm`.
  iperm = perm
  # The fully parameterised constructor is used rather than `PermutedDimsArray(array, perm)`
  # so that the permutation is carried in the type parameters and the result type is stable.
  return PermutedDimsArray{eltype(array), M, perm, iperm, typeof(array)}(array)
end

# `dims = Val(1)` method without a default label.
#
# Each element of `data` is mapped to its position in `labels` via `_findval`; elements
# that are not found are provisionally recorded as 0. Only when such a 0 is present is
# `data` walked a second time, to raise an error naming the first missing value.
function _onehotbatch(::Val{1}, data, labels)
  positions = UInt32[something(_findval(item, labels), 0) for item in data]
  if any(iszero, positions)
    for x in data
      if isnothing(_findval(x, labels))
        error("Value $x not found in labels")
      end
    end
  end
  return OneHotArray(positions, length(labels))
end

# `dims = Val(1)` method with a default label.
#
# `default` must itself be one of `labels` (an error is raised otherwise); its position
# is substituted for every element of `data` that `_findval` cannot locate.
function _onehotbatch(::Val{1}, data, labels, default)
  fallback = _findval(default, labels)
  if isnothing(fallback)
    error("Default value $default is not in labels")
  end
  positions = UInt32[something(_findval(item, labels), fallback) for item in data]
  return OneHotArray(positions, length(labels))
end

Expand Down
14 changes: 14 additions & 0 deletions test/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,17 @@ end
@test y[:,1] isa OneHotVector
@test y[:,:] isa OneHotMatrix
end

@testset "onehotbatch dims" begin
  # Fixtures shared by the cases below: with `dims=Val(2)` each row is one sample
  # and each column one label.
  swapped = Bool[0 1 0; 1 0 0]
  grid = reshape(collect(1:12), 3, 4)

  # Basic cases, without and with a default label.
  @test onehotbatch([20, 10], 10:10:30; dims=Val(2)) == swapped
  @test onehotbatch([10, 20], [30, 40, 50], 30; dims=Val(2)) == Bool[1 0 0; 1 0 0]

  # Higher-dimensional input: the one-hot axis is inserted as the second dimension,
  # and every slice along it contains exactly one `true`.
  @test size(onehotbatch(grid, 1:12; dims=Val(2))) == (3, 12, 4)
  @test vec(sum(onehotbatch(grid, 1:12; dims=Val(2)), dims=2)) == ones(12)

  # String input.
  @test onehotbatch("ba", 'a':'c'; dims=Val(2)) == swapped

  # The result type must be inferable when `dims` is given as a `Val`.
  @test @inferred(onehotbatch([20, 10], 10:10:30; dims=Val(2))) == swapped
  @test @inferred(onehotbatch([40, 10], (10,20,30), 20; dims=Val(2))) == swapped
end