Skip to content

Commit 3059520

Browse files
mcabbottToucheSir
andauthored
Add some checks to onecold (#25)
* add sanity checks * Apply suggestions from code review Co-authored-by: Brian Chen <[email protected]> Co-authored-by: Brian Chen <[email protected]>
1 parent 2a41ca1 commit 3059520

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "OneHotArrays"
22
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
3-
version = "0.2.0"
3+
version = "0.2.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/onehot.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,18 @@ julia> onecold([ 1 0 0 1 0 1 0 1 0 0 1
126126
"abeacadabea"
127127
```
128128
"""
129-
onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)]
129+
function onecold(y::AbstractVector, labels = 1:length(y))
130+
nl = length(labels)
131+
ny = length(y)
132+
nl == ny || throw(DimensionMismatch("onecold got $nl labels for a vector of length $ny, these must agree"))
133+
ymax, i = findmax(y)
134+
ymax isa Number && isnan(ymax) && throw(ArgumentError("maximum value found by onecold is $ymax"))
135+
labels[i]
136+
end
130137
function onecold(y::AbstractArray, labels = 1:size(y, 1))
138+
nl = length(labels)
139+
ny = size(y, 1)
140+
nl == ny || throw(DimensionMismatch("onecold got $nl labels for an array with size(y, 1) == $ny, these must agree"))
131141
indices = _fast_argmax(y)
132142
xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA
133143

test/onehot.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ end
5252
cold = onecold(hot, labels)
5353

5454
@test cold == data
55+
56+
@test_throws DimensionMismatch onecold([0.3, 0.2], (:a, :b, :c))
57+
@test_throws DimensionMismatch onecold([0.3, 0.2, 0.5, 0.99], (:a, :b, :c))
58+
@test_throws ArgumentError onecold([0.3, NaN, 0.5], (:a, :b, :c))
5559
end
5660

5761
@testset "onehotbatch indexing" begin

0 commit comments

Comments
 (0)