Skip to content

Commit 9c43cb7

Browse files
authored
batch -> stack (#26)
1 parent 3059520 commit 9c43cb7

File tree

5 files changed

+21
-3
lines changed

5 files changed

+21
-3
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.2.1"
55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
89
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -14,8 +15,8 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1415
Adapt = "3.0"
1516
CUDA = "3.8"
1617
ChainRulesCore = "1.13"
18+
Compat = "4.2"
1719
GPUArraysCore = "0.1.0"
18-
MLUtils = "0.2, 0.3"
1920
NNlib = "0.8"
2021
Zygote = "0.6.35"
2122
julia = "1.6"

src/OneHotArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Adapt
44
using ChainRulesCore
55
using GPUArraysCore
66
using LinearAlgebra
7-
using MLUtils
7+
using Compat: Compat
88
using NNlib
99

1010
export onehot, onehotbatch, onecold,

src/array.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,16 @@ Base.hcat(x::OneHotMatrix, xs::OneHotMatrix...) =
129129
Base.hcat(x::OneHotVector, xs::OneHotVector...) =
130130
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), _nlabels(x, xs...))
131131

132-
MLUtils.batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(_indices.(xs), _nlabels(xs...))
132+
if isdefined(Base, :stack)
133+
import Base: _stack
134+
else
135+
import Compat: _stack
136+
end
137+
function _stack(::Colon, xs::AbstractArray{<:OneHotArray})
138+
n = _nlabels(first(xs))
139+
all(x -> _nlabels(x)==n, xs) || throw(DimensionMismatch("The number of labels are not the same for all one-hot arrays."))
140+
OneHotArray(Compat.stack(_indices, xs), n)
141+
end
133142

134143
Adapt.adapt_structure(T, x::OneHotArray) = OneHotArray(adapt(T, _indices(x)), x.nlabels)
135144

test/array.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ end
6464
@test cat(oa, oa; dims = 3) isa OneHotArray
6565
@test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1)
6666

67+
# stack
68+
@test stack([ov, ov]) == hcat(ov, ov)
69+
@test stack([ov, ov, ov]) isa OneHotMatrix
70+
@test stack([om, om]) == cat(om, om; dims = 3)
71+
@test stack([om, om]) isa OneHotArray
72+
@test stack([oa, oa, oa, oa]) isa OneHotArray
73+
6774
# proper error handling of inconsistent sizes
6875
@test_throws DimensionMismatch hcat(ov, ov2)
6976
@test_throws DimensionMismatch hcat(om, om2)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using OneHotArrays
22
using Test
3+
using Compat: stack
34

45
@testset "OneHotArray" begin
56
include("array.jl")

0 commit comments

Comments
 (0)