Skip to content

Commit c402d09

Browse files
committed
cat: ensure vararg is more inferrable
Ensures that Union{} is not part of the output possibilities after type-piracy of Base.cat methods. Refs JuliaLang/julia#50550
1 parent 2c4f870 commit c402d09

File tree

1 file changed

+29
-19
lines changed

1 file changed

+29
-19
lines changed

src/sparsevector.jl

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,13 +1194,14 @@ anysparse() = false
11941194
anysparse(X) = X isa AbstractArray && issparse(X)
11951195
anysparse(X, Xs...) = anysparse(X) || anysparse(Xs...)
11961196

1197-
function hcat(X::Union{Vector, AbstractSparseVector}...)
1197+
const _SparseVecConcatGroup = Union{Vector, AbstractSparseVector}
1198+
function hcat(X::_SparseVecConcatGroup...)
11981199
if anysparse(X...)
11991200
X = map(sparse, X)
12001201
end
12011202
return cat(X...; dims=Val(2))
12021203
end
1203-
function vcat(X::Union{Vector, AbstractSparseVector}...)
1204+
function vcat(X::_SparseVecConcatGroup...)
12041205
if anysparse(X...)
12051206
X = map(sparse, X)
12061207
end
@@ -1213,30 +1214,30 @@ end
12131214
const _SparseConcatGroup = Union{AbstractVecOrMat{<:Number},Number}
12141215

12151216
# `@constprop :aggressive` allows `dims` to be propagated as constant improving return type inference
1216-
Base.@constprop :aggressive function Base._cat(dims, X::_SparseConcatGroup...)
1217-
T = promote_eltype(X...)
1218-
if anysparse(X...)
1219-
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
1217+
Base.@constprop :aggressive function Base._cat(dims, X1::_SparseConcatGroup, X::_SparseConcatGroup...)
1218+
T = promote_eltype(X1, X...)
1219+
if anysparse(X1) || anysparse(X...)
1220+
X1, X = _sparse(X1), map(_makesparse, X)
12201221
end
1221-
return Base._cat_t(dims, T, X...)
1222+
return Base._cat_t(dims, T, X1, X...)
12221223
end
1223-
function hcat(X::_SparseConcatGroup...)
1224-
if anysparse(X...)
1225-
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
1224+
function hcat(X1::_SparseConcatGroup, X::_SparseConcatGroup...)
1225+
if anysparse(X1) || anysparse(X...)
1226+
X1, X = _sparse(X1), map(_makesparse, X)
12261227
end
1227-
return cat(X..., dims=Val(2))
1228+
return cat(X1, X..., dims=Val(2))
12281229
end
1229-
function vcat(X::_SparseConcatGroup...)
1230-
if anysparse(X...)
1231-
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
1230+
function vcat(X1::_SparseConcatGroup, X::_SparseConcatGroup...)
1231+
if anysparse(X1) || anysparse(X...)
1232+
X1, X = _sparse(X1), map(_makesparse, X)
12321233
end
1233-
return cat(X..., dims=Val(1))
1234+
return cat(X1, X..., dims=Val(1))
12341235
end
1235-
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
1236-
if anysparse(X...)
1237-
vcat(_hvcat_rows(rows, X...)...)
1236+
function hvcat(rows::Tuple{Vararg{Int}}, X1::_SparseConcatGroup, X::_SparseConcatGroup...)
1237+
if anysparse(X1) || anysparse(X...)
1238+
vcat(_hvcat_rows(rows, X1, X...)...)
12381239
else
1239-
Base.typed_hvcat(Base.promote_eltypeof(X...), rows, X...)
1240+
Base.typed_hvcat(Base.promote_eltypeof(X1, X...), rows, X1, X...)
12401241
end
12411242
end
12421243
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
@@ -1254,6 +1255,15 @@ function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup.
12541255
end
12551256
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()
12561257

1258+
# disambiguation for type-piracy problems created above
1259+
hcat(n1::Number, ns::Vararg{Number}) = invoke(hcat, Tuple{Vararg{Number}}, n1, ns...)
1260+
vcat(n1::Number, ns::Vararg{Number}) = invoke(vcat, Tuple{Vararg{Number}}, n1, ns...)
1261+
hcat(n1::Type{N}, ns::Vararg{N}) where {N<:Number} = invoke(hcat, Tuple{Vararg{Number}}, n1, ns...)
1262+
vcat(n1::Type{N}, ns::Vararg{N}) where {N<:Number} = invoke(vcat, Tuple{Vararg{Number}}, n1, ns...)
1263+
hvcat(rows::Tuple{Vararg{Int}}, n1::Number, ns::Vararg{Number}) = invoke(hvcat, Tuple{typeof(rows), Vararg{Number}}, rows, n1, ns...)
1264+
hvcat(rows::Tuple{Vararg{Int}}, n1::N, ns::Vararg{N}) where {N<:Number} = invoke(hvcat, Tuple{typeof(rows), Vararg{N}}, rows, n1, ns...)
1265+
1266+
12571267
# make sure UniformScaling objects are converted to sparse matrices for concatenation
12581268
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = anysparse(A...) ? SparseMatrixCSC : Matrix
12591269
promote_to_arrays_(n::Int, ::Type{SparseMatrixCSC}, J::UniformScaling) = sparse(J, n, n)

0 commit comments

Comments
 (0)