Skip to content

Commit cc1daed

Browse files
aviateskLilithHafner
authored andcommitted
inference: form PartialStruct for extra type information propagation (JuliaLang#42831)
* inference: form `PartialStruct` for extra type information propagation This commit forms `PartialStruct` whenever there is any type-level refinement available about a field, even if it's not "constant" information. In Julia "definitions" are allowed to be abstract whereas "usages" (i.e. callsites) are often concrete. The basic idea is to allow inference to make more use of such precise callsite type information by encoding it as `PartialStruct`. This may increase optimization possibilities of "unidiomatic" Julia code, which may contain poorly-typed definitions, like this very contrived example: ```julia struct Problem n; s; c; t end function main(args...) prob = Problem(args...) s = 0 for i in 1:prob.n m = mod(i, 3) s += m == 0 ? sin(prob.s) : m == 1 ? cos(prob.c) : tan(prob.t) end return prob, s end main(10000, 1, 2, 3) ``` One of the obvious limitation is that this extra type information can be propagated inter-procedurally only as a const-propagation. I'm not sure this kind of "just a type-level" refinement can often make constant-prop' successful (i.e. shape-up a method body and allow it to be inlined, encoding the extra type information into the generated code), thus I didn't not modify any part of const-prop' heuristics. So the improvements from this change might not be very useful for general inter-procedural analysis currently, but they should definitely improve the accuracy of local analysis and very simple inter-procedural analysis.
1 parent 4f42099 commit cc1daed

File tree

4 files changed

+81
-45
lines changed

4 files changed

+81
-45
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,9 @@ function from_interconditional(@nospecialize(typ), (; fargs, argtypes)::ArgInfo,
345345
else
346346
elsetype = tmeet(elsetype, widenconst(new_elsetype))
347347
end
348-
if (slot > 0 || condval !== false) && !(old vtype) # essentially vtype ⋤ old
348+
if (slot > 0 || condval !== false) && vtype old
349349
slot = id
350-
elseif (slot > 0 || condval !== true) && !(old elsetype) # essentially elsetype ⋤ old
350+
elseif (slot > 0 || condval !== true) && elsetype old
351351
slot = id
352352
else # reset: no new useful information for this slot
353353
vtype = elsetype = Any
@@ -1598,36 +1598,35 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15981598
elseif ehead === :new
15991599
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
16001600
if isconcretetype(t) && !ismutabletype(t)
1601-
args = Vector{Any}(undef, length(e.args)-1)
1602-
ats = Vector{Any}(undef, length(e.args)-1)
1603-
anyconst = false
1604-
allconst = true
1601+
nargs = length(e.args) - 1
1602+
ats = Vector{Any}(undef, nargs)
1603+
local anyrefine = false
1604+
local allconst = true
16051605
for i = 2:length(e.args)
16061606
at = widenconditional(abstract_eval_value(interp, e.args[i], vtypes, sv))
1607-
if !anyconst
1608-
anyconst = has_nontrivial_const_info(at)
1609-
end
1610-
ats[i-1] = at
1607+
ft = fieldtype(t, i-1)
1608+
at = tmeet(at, ft)
16111609
if at === Bottom
16121610
t = Bottom
1613-
allconst = anyconst = false
1614-
break
1615-
elseif at isa Const
1616-
if !(at.val isa fieldtype(t, i - 1))
1617-
t = Bottom
1618-
allconst = anyconst = false
1619-
break
1620-
end
1621-
args[i-1] = at.val
1622-
else
1611+
@goto t_computed
1612+
elseif !isa(at, Const)
16231613
allconst = false
16241614
end
1615+
if !anyrefine
1616+
anyrefine = has_nontrivial_const_info(at) || # constant information
1617+
at ft # just a type-level information, but more precise than the declared type
1618+
end
1619+
ats[i-1] = at
16251620
end
16261621
# For now, don't allow partially initialized Const/PartialStruct
1627-
if t !== Bottom && fieldcount(t) == length(ats)
1622+
if fieldcount(t) == nargs
16281623
if allconst
1629-
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args)))
1630-
elseif anyconst
1624+
argvals = Vector{Any}(undef, nargs)
1625+
for j in 1:nargs
1626+
argvals[j] = (ats[j]::Const).val
1627+
end
1628+
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, argvals, nargs))
1629+
elseif anyrefine
16311630
t = PartialStruct(t, ats)
16321631
end
16331632
end
@@ -1638,7 +1637,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
16381637
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
16391638
n = fieldcount(t)
16401639
if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val::Tuple) &&
1641-
let t = t; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end
1640+
let t = t, at = at; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end
16421641
t = Const(ccall(:jl_new_structt, Any, (Any, Any), t, at.val))
16431642
elseif isa(at, PartialStruct) && at Tuple && n == length(at.fields::Vector{Any}) &&
16441643
let t = t, at = at; _all(i->(at.fields::Vector{Any})[i] fieldtype(t, i), 1:n); end
@@ -1718,6 +1717,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
17181717
else
17191718
t = abstract_eval_value_expr(interp, e, vtypes, sv)
17201719
end
1720+
@label t_computed
17211721
@assert !isa(t, TypeVar) "unhandled TypeVar"
17221722
if isa(t, DataType) && isdefined(t, :instance)
17231723
# replace singleton types with their equivalent Const object
@@ -1801,17 +1801,18 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, s
18011801
isa(rt, Type) && return rt
18021802
if isa(rt, PartialStruct)
18031803
fields = copy(rt.fields)
1804-
haveconst = false
1804+
local anyrefine = false
18051805
for i in 1:length(fields)
18061806
a = fields[i]
18071807
a = isvarargtype(a) ? a : widenreturn(a, bestguess, nslots, slottypes, changes)
1808-
if !haveconst && has_const_info(a)
1808+
if !anyrefine
18091809
# TODO: consider adding && const_prop_profitable(a) here?
1810-
haveconst = true
1810+
anyrefine = has_const_info(a) ||
1811+
a fieldtype(rt.typ, i)
18111812
end
18121813
fields[i] = a
18131814
end
1814-
haveconst && return PartialStruct(rt.typ, fields)
1815+
anyrefine && return PartialStruct(rt.typ, fields)
18151816
end
18161817
if isa(rt, PartialOpaque)
18171818
return rt # XXX: this case was missed in #39512

base/compiler/typelattice.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,12 @@ function maybe_extract_const_bool(c::AnyConditional)
140140
end
141141
maybe_extract_const_bool(@nospecialize c) = nothing
142142

143-
function (@nospecialize(a), @nospecialize(b))
143+
"""
144+
a ⊑ b -> Bool
145+
146+
The non-strict partial order over the type inference lattice.
147+
"""
148+
@nospecialize(a) @nospecialize(b) = begin
144149
if isa(b, LimitedAccuracy)
145150
if !isa(a, LimitedAccuracy)
146151
return false
@@ -232,6 +237,22 @@ function ⊑(@nospecialize(a), @nospecialize(b))
232237
end
233238
end
234239

240+
"""
241+
a ⊏ b -> Bool
242+
243+
The strict partial order over the type inference lattice.
244+
This is defined as the irreflexive kernel of `⊑`.
245+
"""
246+
@nospecialize(a) @nospecialize(b) = a b && !(b, a)
247+
248+
"""
249+
a ⋤ b -> Bool
250+
251+
This order could be used as a slightly more efficient version of the strict order `⊏`,
252+
where we can safely assume `a ⊑ b` holds.
253+
"""
254+
@nospecialize(a) @nospecialize(b) = !(b, a)
255+
235256
# Check if two lattice elements are partial order equivalent. This is basically
236257
# `a ⊑ b && b ⊑ a` but with extra performance optimizations.
237258
function is_lattice_equal(@nospecialize(a), @nospecialize(b))

test/compiler/inference.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3669,3 +3669,26 @@ end
36693669

36703670
# issue #42646
36713671
@test only(Base.return_types(getindex, (Array{undef}, Int))) >: Union{} # check that it does not throw
3672+
3673+
# form PartialStruct for extra type information propagation
3674+
struct FieldTypeRefinement{S,T}
3675+
s::S
3676+
t::T
3677+
end
3678+
@test Base.return_types((Int,)) do s
3679+
o = FieldTypeRefinement{Any,Int}(s, s)
3680+
o.s
3681+
end |> only == Int
3682+
@test Base.return_types((Int,)) do s
3683+
o = FieldTypeRefinement{Int,Any}(s, s)
3684+
o.t
3685+
end |> only == Int
3686+
@test Base.return_types((Int,)) do s
3687+
o = FieldTypeRefinement{Any,Any}(s, s)
3688+
o.s, o.t
3689+
end |> only == Tuple{Int,Int}
3690+
@test Base.return_types((Int,)) do a
3691+
s1 = Some{Any}(a)
3692+
s2 = Some{Any}(s1)
3693+
s2.value.value
3694+
end |> only == Int

test/compiler/irpasses.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -426,31 +426,22 @@ let # `getfield_elim_pass!` should work with constant globals
426426
end
427427
end
428428

429-
let # `typeassert_elim_pass!`
429+
let
430+
# `typeassert` elimination after SROA
431+
# NOTE we can remove this optimization once inference is able to reason about memory-effects
430432
src = @eval Module() begin
431-
struct Foo; x; end
433+
mutable struct Foo; x; end
432434

433435
code_typed((Int,)) do a
434436
x1 = Foo(a)
435437
x2 = Foo(x1)
436-
x3 = Foo(x2)
437-
438-
r1 = (x2.x::Foo).x
439-
r2 = (x2.x::Foo).x::Int
440-
r3 = (x2.x::Foo).x::Integer
441-
r4 = ((x3.x::Foo).x::Foo).x
442-
443-
return r1, r2, r3, r4
438+
return typeassert(x2.x, Foo).x
444439
end |> only |> first
445440
end
446-
# eliminate `typeassert(f2.a, Foo)`
447-
@test all(src.code) do @nospecialize(stmt)
441+
# eliminate `typeassert(x2.x, Foo)`
442+
@test all(src.code) do @nospecialize stmt
448443
Meta.isexpr(stmt, :call) || return true
449444
ft = Core.Compiler.argextype(stmt.args[1], src, Any[], src.slottypes)
450445
return Core.Compiler.widenconst(ft) !== typeof(typeassert)
451446
end
452-
# succeeding simple DCE will eliminate `Foo(a)`
453-
@test all(src.code) do @nospecialize(stmt)
454-
return !Meta.isexpr(stmt, :new)
455-
end
456447
end

0 commit comments

Comments
 (0)