Skip to content

Commit 65b9be4

Browse files
authored
improve cat inferrability (#45028)
Make `cat` inferrable even if its arguments are not fully constant: ```julia julia> r = rand(Float32, 56, 56, 64, 1); julia> f(r) = cat(r, r, dims=(3,)) f (generic function with 1 method) julia> @inferred f(r); julia> last(@code_typed f(r)) Array{Float32, 4} ``` After descending into its call graph, I found that constant propagation is prohibited at `cat_t(::Type{T}, X...; dims)` due to the method instance heuristic, i.e. its body is considered to be too complex for successful inlining although it's explicitly annotated as `@inline`. But for this case, the constant propagation is greatly helpful both for abstract interpretation and optimization since it can improve the return type inference. Since it is not an easy task to improve the method instance heuristic, which is our primary logic for constant propagation, this commit does a quick fix by helping inference with the `@constprop` annotation. There is another issue that currently there is no good way to properly apply `@constprop`/`@inline` effects to a keyword function (as a note, this is a general issue of macro annotations on a method definition). So this commit also changes some internal helper functions of `cat` so that now they are not keyword ones: the changes are also necessary for the `@inline` annotation on `cat_t` to be effective to trick the method instance heuristic.
1 parent 45abec4 commit 65b9be4

File tree

4 files changed

+24
-32
lines changed

4 files changed

+24
-32
lines changed

base/abstractarray.jl

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,23 +1716,16 @@ end
17161716
_cs(d, a, b) = (a == b ? a : throw(DimensionMismatch(
17171717
"mismatch in dimension $d (expected $a got $b)")))
17181718

1719-
function dims2cat(::Val{dims}) where dims
1720-
if any((0), dims)
1721-
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))
1722-
end
1723-
ntuple(in(dims), maximum(dims))
1724-
end
1725-
1719+
dims2cat(::Val{dims}) where dims = dims2cat(dims)
17261720
function dims2cat(dims)
17271721
if any((0), dims)
17281722
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))
17291723
end
17301724
ntuple(in(dims), maximum(dims))
17311725
end
17321726

1733-
_cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
1727+
_cat(dims, X...) = _cat_t(dims, promote_eltypeof(X...), X...)
17341728

1735-
@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...)
17361729
@inline function _cat_t(dims, ::Type{T}, X...) where {T}
17371730
catdims = dims2cat(dims)
17381731
shape = cat_size_shape(catdims, X...)
@@ -1742,6 +1735,9 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
17421735
end
17431736
return __cat(A, shape, catdims, X...)
17441737
end
1738+
# this version of `cat_t` is not very kind for inference and so its usage should be avoided,
1739+
# nevertheless it is here just for compat after https://github.com/JuliaLang/julia/pull/45028
1740+
@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...)
17451741

17461742
# Why isn't this called `__cat!`?
17471743
__cat(A, shape, catdims, X...) = __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)
@@ -1880,8 +1876,8 @@ julia> reduce(hcat, vs)
18801876
"""
18811877
hcat(X...) = cat(X...; dims=Val(2))
18821878

1883-
typed_vcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(1))
1884-
typed_hcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(2))
1879+
typed_vcat(::Type{T}, X...) where T = _cat_t(Val(1), T, X...)
1880+
typed_hcat(::Type{T}, X...) where T = _cat_t(Val(2), T, X...)
18851881

18861882
"""
18871883
cat(A...; dims)
@@ -1917,7 +1913,8 @@ julia> cat(true, trues(2,2), trues(4)', dims=(1,2))
19171913
```
19181914
"""
19191915
@inline cat(A...; dims) = _cat(dims, A...)
1920-
_cat(catdims, A::AbstractArray{T}...) where {T} = cat_t(T, A...; dims=catdims)
1916+
# `@constprop :aggressive` allows `catdims` to be propagated as constant improving return type inference
1917+
@constprop :aggressive _cat(catdims, A::AbstractArray{T}...) where {T} = _cat_t(catdims, T, A...)
19211918

19221919
# The specializations for 1 and 2 inputs are important
19231920
# especially when running with --inline=no, see #11158
@@ -1928,12 +1925,12 @@ hcat(A::AbstractArray) = cat(A; dims=Val(2))
19281925
hcat(A::AbstractArray, B::AbstractArray) = cat(A, B; dims=Val(2))
19291926
hcat(A::AbstractArray...) = cat(A...; dims=Val(2))
19301927

1931-
typed_vcat(T::Type, A::AbstractArray) = cat_t(T, A; dims=Val(1))
1932-
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(T, A, B; dims=Val(1))
1933-
typed_vcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(1))
1934-
typed_hcat(T::Type, A::AbstractArray) = cat_t(T, A; dims=Val(2))
1935-
typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(T, A, B; dims=Val(2))
1936-
typed_hcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(2))
1928+
typed_vcat(T::Type, A::AbstractArray) = _cat_t(Val(1), T, A)
1929+
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = _cat_t(Val(1), T, A, B)
1930+
typed_vcat(T::Type, A::AbstractArray...) = _cat_t(Val(1), T, A...)
1931+
typed_hcat(T::Type, A::AbstractArray) = _cat_t(Val(2), T, A)
1932+
typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = _cat_t(Val(2), T, A, B)
1933+
typed_hcat(T::Type, A::AbstractArray...) = _cat_t(Val(2), T, A...)
19371934

19381935
# 2d horizontal and vertical concatenation
19391936

stdlib/LinearAlgebra/src/special.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,14 +414,14 @@ const _TypedDenseConcatGroup{T} = Union{Vector{T}, Adjoint{T,Vector{T}}, Transpo
414414

415415
promote_to_array_type(::Tuple{Vararg{Union{_DenseConcatGroup,UniformScaling}}}) = Matrix
416416

417-
Base._cat(dims, xs::_DenseConcatGroup...) = Base.cat_t(promote_eltype(xs...), xs...; dims=dims)
417+
Base._cat(dims, xs::_DenseConcatGroup...) = Base._cat_t(dims, promote_eltype(xs...), xs...)
418418
vcat(A::Vector...) = Base.typed_vcat(promote_eltype(A...), A...)
419419
vcat(A::_DenseConcatGroup...) = Base.typed_vcat(promote_eltype(A...), A...)
420420
hcat(A::Vector...) = Base.typed_hcat(promote_eltype(A...), A...)
421421
hcat(A::_DenseConcatGroup...) = Base.typed_hcat(promote_eltype(A...), A...)
422422
hvcat(rows::Tuple{Vararg{Int}}, xs::_DenseConcatGroup...) = Base.typed_hvcat(promote_eltype(xs...), rows, xs...)
423423
# For performance, specially handle the case where the matrices/vectors have homogeneous eltype
424-
Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.cat_t(T, xs...; dims=dims)
424+
Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base._cat_t(dims, T, xs...)
425425
vcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_vcat(T, A...)
426426
hcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hcat(T, A...)
427427
hvcat(rows::Tuple{Vararg{Int}}, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hvcat(T, rows, xs...)

test/abstractarray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,10 @@ function test_cat(::Type{TestAbstractArray})
733733
cat3v(As) = cat(As...; dims=Val(3))
734734
@test @inferred(cat3v(As)) == zeros(2, 2, 2)
735735
@test @inferred(cat(As...; dims=Val((1,2)))) == zeros(4, 4)
736+
737+
r = rand(Float32, 56, 56, 64, 1);
738+
f(r) = cat(r, r, dims=(3,))
739+
@inferred f(r);
736740
end
737741

738742
function test_ind2sub(::Type{TestAbstractArray})

test/ambiguous.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,20 +172,11 @@ using LinearAlgebra, SparseArrays, SuiteSparse
172172
# not using isempty so this prints more information when it fails
173173
@testset "detect_ambiguities" begin
174174
let ambig = Set{Any}(((m1.sig, m2.sig) for (m1, m2) in detect_ambiguities(Core, Base; recursive=true, ambiguous_bottom=false, allowed_undefineds)))
175-
@test isempty(ambig)
176-
expect = []
177175
good = true
178-
while !isempty(ambig)
179-
sigs = pop!(ambig)
180-
i = findfirst(==(sigs), expect)
181-
if i === nothing
182-
println(stderr, "push!(expect, (", sigs[1], ", ", sigs[2], "))")
183-
good = false
184-
continue
185-
end
186-
deleteat!(expect, i)
176+
for (sig1, sig2) in ambig
177+
@test sig1 === sig2 # print this ambiguity
178+
good = false
187179
end
188-
@test isempty(expect)
189180
@test good
190181
end
191182

0 commit comments

Comments
 (0)