@@ -34,15 +34,18 @@ _isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L
34
34
35
35
Base. size (x:: OneHotArray{<:Any, L} ) where L = (Int (L), size (x. indices)... )
36
36
37
- _onehotindex (x, i) = (x == i)
37
+ function Base. getindex (x:: OneHotArray{<:Any, <:Any, N} , i:: Integer , I:: Vararg{Any, N} ) where N
38
+ @boundscheck checkbounds (x, i, I... )
39
+ return x. indices[I... ] .== i
40
+ end
38
41
39
- Base. getindex (x:: OneHotVector , i:: Integer ) = _onehotindex (x. indices, i)
40
- Base. getindex (x:: OneHotVector{T, L} , :: Colon ) where {T, L} = x
42
+ function Base. getindex (x:: OneHotArray{<:Any, L} , :: Colon , I... ) where L
43
+ @boundscheck checkbounds (x, :, I... )
44
+ return OneHotArray (x. indices[I... ], L)
45
+ end
41
46
42
- Base. getindex (x:: OneHotArray , i:: Integer , I... ) = _onehotindex .(x. indices[I... ], i)
43
- Base. getindex (x:: OneHotArray{<:Any, L} , :: Colon , I... ) where L = OneHotArray (x. indices[I... ], L)
44
- Base. getindex (x:: OneHotArray{<:Any, <:Any, <:Any, N} , :: Vararg{Colon, N} ) where N = x
45
- Base. getindex (x:: OneHotArray , I:: CartesianIndex{N} ) where N = x[I[1 ], Tuple (I)[2 : N]. .. ]
47
+ Base. getindex (x:: OneHotArray , :: Colon ) = BitVector (reshape (x, :))
48
+ Base. getindex (x:: OneHotArray{<:Any, <:Any, N} , :: Colon , :: Vararg{Colon, N} ) where N = x
46
49
47
50
function Base. showarg (io:: IO , x:: OneHotArray , toplevel)
48
51
print (io, ndims (x) == 1 ? " OneHotVector(" : ndims (x) == 2 ? " OneHotMatrix(" : " OneHotArray(" )
0 commit comments