diff --git a/base/abstractarray.jl b/base/abstractarray.jl index db9096e67e20c..5123aa647985f 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -2136,6 +2136,7 @@ _hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::Number...) = _typed_h _hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::AbstractArray...) = _typed_hvncat(promote_eltype(xs...), dimsshape, row_first, xs...) _hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::AbstractArray{T}...) where T = _typed_hvncat(T, dimsshape, row_first, xs...) + typed_hvncat(T::Type, dimsshape::Tuple, row_first::Bool, xs...) = _typed_hvncat(T, dimsshape, row_first, xs...) typed_hvncat(T::Type, dim::Int, xs...) = _typed_hvncat(T, Val(dim), xs...) @@ -2152,9 +2153,9 @@ _typed_hvncat(::Type, ::Val{0}, ::AbstractArray...) = _typed_hvncat_0d_only_one( _typed_hvncat_0d_only_one() = throw(ArgumentError("a 0-dimensional array may only contain exactly one element")) -_typed_hvncat(::Type{T}, ::Val{N}) where {T, N} = Array{T, N}(undef, ntuple(x -> 0, Val(N))) - -function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, xs::Number...) where {T, N} +function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, xs::Number...) where {T, N} + all(>(0), dims) || + throw(ArgumentError("`dims` argument must contain positive integers")) A = Array{T, N}(undef, dims...) lengtha = length(A) # Necessary to store result because throw blocks are being deoptimized right now, which leads to excessive allocations lengthx = length(xs) # Cuts from 3 allocations to 1. @@ -2191,9 +2192,28 @@ function hvncat_fill!(A::Array, row_first::Bool, xs::Tuple) end _typed_hvncat(T::Type, dim::Int, ::Bool, xs...) = _typed_hvncat(T, Val(dim), xs...) # catches from _hvncat type promoters + +function _typed_hvncat(::Type{T}, ::Val{N}) where {T, N} + N < 0 && + throw(ArgumentError("concatenation dimension must be nonnegative")) + return Array{T, N}(undef, ntuple(x -> 0, Val(N))) +end + +function _typed_hvncat(T::Type, ::Val{N}, xs::Number...) where N + N < 0 && + throw(ArgumentError("concatenation dimension must be nonnegative")) + A = cat_similar(xs[1], T, (ntuple(x -> 1, Val(N - 1))..., length(xs))) + hvncat_fill!(A, false, xs) + return A +end + function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N} # optimization for arrays that can be concatenated by copying them linearly into the destination - # conditions: the elements must all have 1- or 0-length dimensions above N + # conditions: the elements must all have 1-length dimensions above N + length(as) > 0 || + throw(ArgumentError("must have at least one element")) + N < 0 && + throw(ArgumentError("concatenation dimension must be nonnegative")) for a ∈ as ndims(a) <= N || all(x -> size(a, x) == 1, (N + 1):ndims(a)) || return _typed_hvncat(T, (ntuple(x -> 1, N - 1)..., length(as), 1), false, as...) @@ -2203,10 +2223,13 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N} nd = max(N, ndims(as[1])) Ndim = 0 - for i ∈ 1:lastindex(as) - Ndim += cat_size(as[i], N) - for d ∈ 1:N - 1 - cat_size(as[1], d) == cat_size(as[i], d) || throw(ArgumentError("mismatched size along axis $d in element $i")) + for i ∈ eachindex(as) + a = as[i] + Ndim += size(a, N) + nd = max(nd, ndims(a)) + for d ∈ 1:N-1 + size(a, d) == size(as[1], d) || + throw(ArgumentError("all dimensions of element $i other than $N must be of length 1")) end end @@ -2222,17 +2245,20 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N} end function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N} - # optimization for scalars and 1-length arrays that can be concatenated by copying them linearly - # into the destination + length(as) > 0 || + throw(ArgumentError("must have at least one element")) + N < 0 && + throw(ArgumentError("concatenation dimension must be nonnegative")) nd = N Ndim = 0 - for a ∈ as - if a isa AbstractArray - cat_size(a, N) == length(a) || - throw(ArgumentError("all dimensions of elements other than $N must be of length 1")) - nd = max(nd, cat_ndims(a)) - end + for i ∈ eachindex(as) + a = as[i] Ndim += cat_size(a, N) + nd = max(nd, cat_ndims(a)) + for d ∈ 1:N-1 + cat_size(a, d) == 1 || + throw(ArgumentError("all dimensions of element $i other than $N must be of length 1")) + end end A = Array{T, nd}(undef, ntuple(x -> 1, N - 1)..., Ndim, ntuple(x -> 1, nd - N)...) @@ -2276,7 +2302,12 @@ function _typed_hvncat_1d(::Type{T}, ds::Int, ::Val{row_first}, as...) where {T, end end -function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, as...) where {T, N} +function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as...) where {T, N} + length(as) > 0 || + throw(ArgumentError("must have at least one element")) + all(>(0), dims) || + throw(ArgumentError("`dims` argument must contain positive integers")) + d1 = row_first ? 2 : 1 d2 = row_first ? 1 : 2 @@ -2291,7 +2322,9 @@ function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, currentdims = zeros(Int, nd) blockcount = 0 + elementcount = 0 for i ∈ eachindex(as) + elementcount += cat_length(as[i]) currentdims[d1] += cat_size(as[i], d1) if currentdims[d1] == outdims[d1] currentdims[d1] = 0 @@ -2321,14 +2354,9 @@ function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, end end - # calling sum() leads to 3 extra allocations - len = 0 - for a ∈ as - len += cat_length(a) - end outlen = prod(outdims) - outlen == 0 && throw(ArgumentError("too few elements in arguments, unable to infer dimensions")) - len == outlen || throw(ArgumentError("too many elements in arguments; expected $(outlen), got $(len)")) + elementcount == outlen || + throw(ArgumentError("mismatched number of elements; expected $(outlen), got $(elementcount)")) # copy into final array A = cat_similar(as[1], T, outdims) @@ -2347,14 +2375,22 @@ function _typed_hvncat(T::Type, shape::Tuple{Tuple}, row_first::Bool, xs...) return _typed_hvncat_1d(T, shape[1][1], Val(row_first), xs...) end -function _typed_hvncat(T::Type, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {N} +function _typed_hvncat(::Type{T}, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {T, N} + length(as) > 0 || + throw(ArgumentError("must have at least one element")) + all(>(0), tuple((shape...)...)) || + throw(ArgumentError("`shape` argument must consist of positive integers")) + d1 = row_first ? 2 : 1 d2 = row_first ? 1 : 2 - shape = collect(shape) # saves allocations later - shapelength = shape[end][1] + shapev = collect(shape) # saves allocations later + all(!isempty, shapev) || + throw(ArgumentError("each level of `shape` argument must have at least one value")) + length(shapev[end]) == 1 || + throw(ArgumentError("last level of shape must contain only one integer")) + shapelength = shapev[end][1] lengthas = length(as) shapelength == lengthas || throw(ArgumentError("number of elements does not match shape; expected $(shapelength), got $lengthas)")) - # discover dimensions nd = max(N, cat_ndims(as[1])) outdims = zeros(Int, nd) @@ -2362,7 +2398,9 @@ function _typed_hvncat(T::Type, shape::NTuple{N, Tuple}, row_first::Bool, as...) blockcounts = zeros(Int, nd) shapepos = ones(Int, nd) + elementcount = 0 for i ∈ eachindex(as) + elementcount += cat_length(as[i]) wasstartblock = false for d ∈ 1:N ad = (d < 3 && row_first) ? (d == 1 ? 2 : 1) : d @@ -2372,27 +2410,34 @@ function _typed_hvncat(T::Type, shape::NTuple{N, Tuple}, row_first::Bool, as...) if d == 1 || i == 1 || wasstartblock currentdims[d] += dsize elseif dsize != cat_size(as[i - 1], ad) - throw(ArgumentError("""argument $i has a mismatched number of elements along axis $ad; \ - expected $(cat_size(as[i - 1], ad)), got $dsize""")) + throw(ArgumentError("argument $i has a mismatched number of elements along axis $ad; \ + expected $(cat_size(as[i - 1], ad)), got $dsize")) end wasstartblock = blockcounts[d] == 1 # remember for next dimension - isendblock = blockcounts[d] == shape[d][shapepos[d]] + isendblock = blockcounts[d] == shapev[d][shapepos[d]] if isendblock if outdims[d] == 0 outdims[d] = currentdims[d] elseif outdims[d] != currentdims[d] - throw(ArgumentError("""argument $i has a mismatched number of elements along axis $ad; \ - expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize""")) + throw(ArgumentError("argument $i has a mismatched number of elements along axis $ad; \ + expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize")) end currentdims[d] = 0 blockcounts[d] = 0 shapepos[d] += 1 + d > 1 && (blockcounts[d - 1] == 0 || + throw(ArgumentError("shape in level $d is inconsistent; level counts must nest \ + evenly into each other"))) end end end + outlen = prod(outdims) + elementcount == outlen || + throw(ArgumentError("mismatched number of elements; expected $(outlen), got $(elementcount)")) + if row_first outdims[1], outdims[2] = outdims[2], outdims[1] end diff --git a/test/abstractarray.jl b/test/abstractarray.jl index a1c6dd1b22ce7..05f93805953dd 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1390,6 +1390,69 @@ using Base: typed_hvncat @test [v v;;; fill(v, 1, 2)] == fill(v, 1, 2, 2) end + # dims form + for v ∈ ((), (1,), ([1],), (1, [1]), ([1], 1), ([1], [1])) + # reject dimension < 0 + @test_throws ArgumentError hvncat(-1, v...) + + # reject shape tuple with no elements + @test_throws ArgumentError hvncat(((),), true, v...) + end + + # reject dims or shape with negative or zero values + for v1 ∈ (-1, 0, 1) + for v2 ∈ (-1, 0, 1) + v1 == v2 == 1 && continue + for v3 ∈ ((), (1,), ([1],), (1, [1]), ([1], 1), ([1], [1])) + @test_throws ArgumentError hvncat((v1, v2), true, v3...) + @test_throws ArgumentError hvncat(((v1,), (v2,)), true, v3...) + end + end + end + + for v ∈ ((1, [1]), ([1], 1), ([1], [1])) + # reject shape with more than one end value + @test_throws ArgumentError hvncat(((1, 1),), true, v...) + end + + for v ∈ ((1, 2, 3), (1, 2, [3]), ([1], [2], [3])) + # reject shape with more values in later level + @test_throws ArgumentError hvncat(((2, 1), (1, 1, 1)), true, v...) + end + + # reject shapes that don't nest evenly between levels (e.g. 1 + 2 does not fit into 2) + @test_throws ArgumentError hvncat(((1, 2, 1), (2, 2), (4,)), true, [1 2], [3], [4], [1 2; 3 4]) + + # zero-length arrays are handled appropriately + @test [zeros(Int, 1, 2, 0) ;;; 1 3] == [1 3;;;] + @test [[] ;;; [] ;;; []] == Array{Any}(undef, 0, 1, 3) + @test [[] ; 1 ;;; 2 ; []] == [1 ;;; 2] + @test [[] ; [] ;;; [] ; []] == Array{Any}(undef, 0, 1, 2) + @test [[] ; 1 ;;; 2] == [1 ;;; 2] + @test [[] ; [] ;;; [] ;;; []] == Array{Any}(undef, 0, 1, 3) + z = zeros(Int, 0, 0, 0) + [z z ; z ;;; z ;;; z] == Array{Int}(undef, 0, 0, 0) + + for v1 ∈ (zeros(Int, 0, 0), zeros(Int, 0, 0, 0, 0), zeros(Int, 0, 0, 0, 0, 0, 0, 0)) + for v2 ∈ (1, [1]) + for v3 ∈ (2, [2]) + @test_throws ArgumentError [v1 ;;; v2] + @test_throws ArgumentError [v1 ;;; v2 v3] + @test_throws ArgumentError [v1 v1 ;;; v2 v3] + end + end + end + v1 = zeros(Int, 0, 0, 0) + for v2 ∈ (1, [1]) + for v3 ∈ (2, [2]) + # current behavior, not potentially dangerous. + # should throw error like above loop + @test [v1 ;;; v2 v3] == [v2 v3;;;] + @test_throws ArgumentError [v1 ;;; v2] + @test_throws ArgumentError [v1 v1 ;;; v2 v3] + end + end + # 0-dimension behaviors # exactly one argument, placed in an array # if already an array, copy, with type conversion as necessary