Skip to content

Commit 35b8bc2

Browse files
committed
Speeding up onehotbatch
1 parent 0f2cca1 commit 35b8bc2

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

src/onehot.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,24 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
8181
3 6 15 3 9 3 12 3 6 15 3
8282
```
8383
"""
84-
onehotbatch(ls, labels, default...) = _onehotbatch(ls, length(labels) < 32 ? Tuple(labels) : labels, default...)
85-
# NB function barrier:
86-
_onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l in ls])
84+
function onehotbatch(data, labels, default...)
    # Label collections shorter than 32 are converted to a `Tuple` before the call below;
    # longer ones are forwarded unchanged. `default...` is passed through as given.
    candidates = length(labels) < 32 ? Tuple(labels) : labels
    return _onehotbatch(data, candidates, default...)
end
85+
86+
# Build the one-hot batch for `data` when no default is supplied: every element must be
# found in `labels`, otherwise an error is raised.
function _onehotbatch(data, labels)
    # A miss (`_findval` giving `nothing`) is recorded as 0, so the all-found case needs
    # only this single pass. NOTE(review): assumes `_findval` never returns 0 for a hit —
    # confirm against its definition.
    indices = UInt32[something(_findval(i, labels), 0) for i in data]
    if 0 in indices
        # Slow path: walk `data` again so the message can name the first offending value.
        for x in data
            isnothing(_findval(x, labels)) && error("Value $x not found in labels")
        end
        # Reached only when the second walk did not reproduce the miss (e.g. `data` is a
        # stateful iterator that the first pass already consumed). Fail by position
        # rather than return an array that still holds the 0 sentinel.
        error("Value at position $(findfirst(iszero, indices)) not found in labels")
    end
    return OneHotArray(indices, length(labels))
end
95+
96+
# Build the one-hot batch for `data`, substituting `default` for any element that is not
# found in `labels`. Raises an error when `default` itself is not in `labels`.
function _onehotbatch(data, labels, default)
    # Resolve the fallback position once, before touching `data`.
    fallback = _findval(default, labels)
    if isnothing(fallback)
        error("Default value $default is not in labels")
    end
    # Every miss (`_findval` giving `nothing`) is replaced by the fallback position.
    positions = UInt32[something(_findval(item, labels), fallback) for item in data]
    return OneHotArray(positions, length(labels))
end
87102

88103
"""
89104
onecold(y::AbstractArray, labels = 1:size(y,1))

test/onehot.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
@test onehotbatch("abc", 'a':'c') == Bool[1 0 0; 0 1 0; 0 0 1]
1616
@test onehotbatch("zbc", ('a', 'b', 'c'), 'a') == Bool[1 0 0; 0 1 0; 0 0 1]
1717

18+
# Neither 10 nor 20 is among the labels, so both columns fall back to the default 30 (row 1).
@test onehotbatch([10, 20], [30, 40, 50], 30) == Bool[1 1; 0 0; 0 0]
19+
1820
@test_throws Exception onehotbatch([:a, :d], [:a, :b, :c])
1921
@test_throws Exception onehotbatch([:a, :d], (:a, :b, :c))
2022
@test_throws Exception onehotbatch([:a, :d], [:a, :b, :c], :e)

0 commit comments

Comments
 (0)