Skip to content

Commit 5c312a0

Browse files
authored
Merge pull request #7 from TLipede/indexing-changes
2 parents 0f2cca1 + c022b9c commit 5c312a0

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1212

1313
[compat]
1414
Adapt = "3.0"
15-
ChainRulesCore = "1.13"
1615
CUDA = "3.8"
16+
ChainRulesCore = "1.13"
1717
MLUtils = "0.2"
1818
NNlib = "0.8"
1919

src/array.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,18 @@ _isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L
3434

3535
Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)
3636

37-
_onehotindex(x, i) = (x == i)
37+
function Base.getindex(x::OneHotArray{<:Any, <:Any, N}, i::Integer, I::Vararg{Any, N}) where N
38+
@boundscheck checkbounds(x, i, I...)
39+
return x.indices[I...] .== i
40+
end
3841

39-
Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
40-
Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x
42+
function Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L
43+
@boundscheck checkbounds(x, :, I...)
44+
return OneHotArray(x.indices[I...], L)
45+
end
4146

42-
Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i)
43-
Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L)
44-
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
45-
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]
47+
Base.getindex(x::OneHotArray, ::Colon) = BitVector(reshape(x, :))
48+
Base.getindex(x::OneHotArray{<:Any, <:Any, N}, ::Colon, ::Vararg{Colon, N}) where N = x
4649

4750
function Base.showarg(io::IO, x::OneHotArray, toplevel)
4851
print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(")

test/array.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ end
2121
@test om[:, 3] == OneHotVector(om.indices[3], 10)
2222
@test om[3, :] == (om.indices .== 3)
2323
@test om[:, :] == om
24+
@test om[:] == reshape(om, :)
2425

2526
# array indexing
2627
@test oa[3, 3, 3] == (oa.indices[3, 3] == 3)
@@ -29,9 +30,20 @@ end
2930
@test oa[3, :, :] == (oa.indices .== 3)
3031
@test oa[:, 3, :] == OneHotMatrix(oa.indices[3, :], 10)
3132
@test oa[:, :, :] == oa
33+
@test oa[:] == reshape(oa, :)
3234

3335
# cartesian indexing
3436
@test oa[CartesianIndex(3, 3, 3)] == oa[3, 3, 3]
37+
38+
# linear indexing
39+
@test om[11] == om[1, 2]
40+
@test oa[52] == oa[2, 1, 2]
41+
42+
# bounds checks
43+
@test_throws BoundsError ov[0]
44+
@test_throws BoundsError om[2, -1]
45+
@test_throws BoundsError oa[11, 5, 5]
46+
@test_throws BoundsError oa[:, :]
3547
end
3648

3749
@testset "Concatenating" begin

0 commit comments

Comments
 (0)