Skip to content

Commit 5ee901e

Browse files
committed
Making indexing match regular arrays more closely
1 parent 0f2cca1 commit 5ee901e

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1111
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
12+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1213

1314
[compat]
1415
Adapt = "3.0"
@@ -18,6 +19,7 @@ MLUtils = "0.2"
1819
NNlib = "0.8"
1920

2021
[extras]
22+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2123
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2224

2325
[targets]

src/array.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,18 @@ _isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L
3434

3535
Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)
3636

37-
_onehotindex(x, i) = (x == i)
37+
function Base.getindex(x::OneHotArray{<:Any, <:Any, N}, i::Integer, I::Vararg{Any, N}) where N
38+
@boundscheck checkbounds(x, i, I...)
39+
return x.indices[I...] .== i
40+
end
3841

39-
Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
40-
Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x
42+
function Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L
43+
@boundscheck checkbounds(x, :, I...)
44+
return OneHotArray(x.indices[I...], L)
45+
end
4146

42-
Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i)
43-
Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L)
44-
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
45-
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]
47+
Base.getindex(x::OneHotArray, ::Colon) = BitVector(reshape(x, :))
48+
Base.getindex(x::OneHotArray{<:Any, <:Any, N}, ::Colon, ::Vararg{Colon, N}) where N = x
4649

4750
function Base.showarg(io::IO, x::OneHotArray, toplevel)
4851
print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(")

test/array.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
using Random
2+
3+
Random.seed!(0)
4+
15
ov = OneHotVector(rand(1:10), 10)
26
ov2 = OneHotVector(rand(1:11), 11)
37
om = OneHotMatrix(rand(1:10, 5), 10)
@@ -21,6 +25,7 @@ end
2125
@test om[:, 3] == OneHotVector(om.indices[3], 10)
2226
@test om[3, :] == (om.indices .== 3)
2327
@test om[:, :] == om
28+
@test om[:] == reshape(om, :)
2429

2530
# array indexing
2631
@test oa[3, 3, 3] == (oa.indices[3, 3] == 3)
@@ -29,9 +34,22 @@ end
2934
@test oa[3, :, :] == (oa.indices .== 3)
3035
@test oa[:, 3, :] == OneHotMatrix(oa.indices[3, :], 10)
3136
@test oa[:, :, :] == oa
37+
@test oa[:] == reshape(oa, :)
3238

3339
# cartesian indexing
3440
@test oa[CartesianIndex(3, 3, 3)] == oa[3, 3, 3]
41+
42+
# linear indexing
43+
@test om[9] == true
44+
@test om[10] == false
45+
@test om[11] == om[1, 2]
46+
@test oa[52] == oa[2, 1, 2]
47+
48+
# bounds checks
49+
@test_throws BoundsError ov[0]
50+
@test_throws BoundsError om[2, -1]
51+
@test_throws BoundsError oa[11, 5, 5]
52+
@test_throws BoundsError oa[:, :]
3553
end
3654

3755
@testset "Concatenating" begin

0 commit comments

Comments
 (0)