@@ -81,9 +81,24 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
81
81
3 6 15 3 9 3 12 3 6 15 3
82
82
```
83
83
"""
84
- onehotbatch (ls, labels, default... ) = _onehotbatch (ls, length (labels) < 32 ? Tuple (labels) : labels, default... )
85
- # NB function barier:
86
- _onehotbatch (ls, labels, default... ) = batch ([onehot (l, labels, default... ) for l in ls])
84
+ onehotbatch (data, labels, default... ) = _onehotbatch (data, length (labels) < 32 ? Tuple (labels) : labels, default... )
85
+
86
+ function _onehotbatch (data, labels)
87
+ indices = UInt32[something (_findval (i, labels), 0 ) for i in data]
88
+ if 0 in indices
89
+ for x in data
90
+ isnothing (_findval (x, labels)) && error (" Value $x not found in labels" )
91
+ end
92
+ end
93
+ return OneHotArray (indices, length (labels))
94
+ end
95
+
96
+ function _onehotbatch (data, labels, default)
97
+ default_index = _findval (default, labels)
98
+ isnothing (default_index) && error (" Default value $default is not in labels" )
99
+ indices = UInt32[something (_findval (i, labels), default_index) for i in data]
100
+ return OneHotArray (indices, length (labels))
101
+ end
87
102
88
103
"""
89
104
onecold(y::AbstractArray, labels = 1:size(y,1))
0 commit comments