Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Compiler/Runtime improvements
* Julia-level SROA (Scalar Replacement of Aggregates) has been improved, i.e. allowing elimination of
`getfield` call with constant global field ([#42355]), enabling elimination of mutable struct with
uninitialized fields ([#43208]), improving performance ([#43232]), handling more nested `getfield`
calls ([#43239]).
calls ([#43239], [#43267]).
* Abstract callsite can now be inlined or statically resolved as long as the callsite has a single
matching method ([#43113]).

Expand Down
111 changes: 80 additions & 31 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
callback = (@nospecialize(x), @nospecialize(idx)) -> false)
while true
if isa(defssa, OldSSAValue)
if already_inserted(compact, defssa)
Expand Down Expand Up @@ -337,10 +337,29 @@ struct LiftedValue
end
const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}

# NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}

# Holder for the program counters of nested `getfield` loads.
# `maybe` starts out as `nothing` and is only replaced by an actual `SPCSet`
# once a load is recorded (see `record_nested_load!`).
mutable struct NestedLoads
    maybe::Union{Nothing,SPCSet}
    function NestedLoads()
        return new(nothing)
    end
end
# Record `pc` in `nested_loads`, creating the underlying `SPCSet` on first use.
# Returns the set that `pc` was pushed into.
function record_nested_load!(nested_loads::NestedLoads, pc::Int)
    pcs = nested_loads.maybe
    if pcs === nothing
        # first recorded load: allocate the set now and store it back
        pcs = SPCSet()
        nested_loads.maybe = pcs
    end
    return push!(pcs::SPCSet, pc)
end
# Return `true` if `pc` has been recorded in `nested_loads`;
# an instance whose set was never created contains nothing and yields `false`.
function is_nested_load(nested_loads::NestedLoads, pc::Int)
    pcs = nested_loads.maybe
    return pcs !== nothing && pc in pcs::SPCSet
end

# try to compute lifted values that can replace `getfield(x, field)` call
# where `x` is an immutable struct that are defined at any of `leaves`
function lift_leaves(compact::IncrementalCompact,
@nospecialize(result_t), field::Int, leaves::Vector{Any})
function lift_leaves!(compact::IncrementalCompact, leaves::Vector{Any},
@nospecialize(result_t), field::Int, nested_loads::NestedLoads)
# For every leaf, the lifted value
lifted_leaves = LiftedLeaves()
maybe_undef = false
Expand Down Expand Up @@ -390,11 +409,19 @@ function lift_leaves(compact::IncrementalCompact,
ocleaf = simple_walk(compact, ocleaf)
end
ocdef, _ = walk_to_def(compact, ocleaf)
if isexpr(ocdef, :new_opaque_closure) && isa(field, Int) && 1 ≤ field ≤ length(ocdef.args)-5
if isexpr(ocdef, :new_opaque_closure) && 1 ≤ field ≤ length(ocdef.args)-5
lift_arg!(compact, leaf, cache_key, ocdef, 5+field, lifted_leaves)
continue
end
return nothing
elseif is_known_call(def, getfield, compact)
if isa(leaf, SSAValue)
struct_typ = unwrap_unionall(widenconst(argextype(def.args[2], compact)))
if ismutabletype(struct_typ)
record_nested_load!(nested_loads, leaf.id)
end
end
return nothing
else
typ = argextype(leaf, compact)
if !isa(typ, Const)
Expand Down Expand Up @@ -588,7 +615,7 @@ function perform_lifting!(compact::IncrementalCompact,
end
val = lifted_val.x
if isa(val, AnySSAValue)
callback = (@nospecialize(pi), @nospecialize(idx)) -> true
callback = (@nospecialize(x), @nospecialize(idx)) -> true
val = simple_walk(compact, val, callback)
end
push!(new_node.values, val)
Expand Down Expand Up @@ -619,10 +646,6 @@ function perform_lifting!(compact::IncrementalCompact,
return stmt_val # N.B. should never happen
end

# NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}

"""
sroa_pass!(ir::IRCode) -> newir::IRCode

Expand All @@ -641,10 +664,11 @@ its argument).
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
a result of succeeding dead code elimination.
"""
function sroa_pass!(ir::IRCode)
function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
compact = IncrementalCompact(ir)
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
nested_loads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
for ((_, idx), stmt) in compact
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
isa(stmt, Expr) || continue
Expand Down Expand Up @@ -672,7 +696,7 @@ function sroa_pass!(ir::IRCode)
preserved_arg = stmt.args[pidx]
isa(preserved_arg, SSAValue) || continue
let intermediaries = SPCSet()
callback = function (@nospecialize(pi), @nospecialize(ssa))
callback = function (@nospecialize(x), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
Expand Down Expand Up @@ -700,7 +724,9 @@ function sroa_pass!(ir::IRCode)
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
end
mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse()))
mid, defuse = get!(defuses, defidx) do
SPCSet(), SSADefUse()
end
push!(defuse.ccall_preserve_uses, idx)
union!(mid, intermediaries)
end
Expand All @@ -710,16 +736,17 @@ function sroa_pass!(ir::IRCode)
compact[idx] = form_new_preserves(stmt, preserved, new_preserves)
end
continue
# TODO: This isn't the best place to put these
elseif is_known_call(stmt, typeassert, compact)
canonicalize_typeassert!(compact, idx, stmt)
continue
elseif is_known_call(stmt, (===), compact)
lift_comparison!(compact, idx, stmt, lifting_cache)
continue
# elseif is_known_call(stmt, isa, compact)
# TODO do a similar optimization as `lift_comparison!` for `===`
else
if optional_opts
# TODO: This isn't the best place to put these
if is_known_call(stmt, typeassert, compact)
canonicalize_typeassert!(compact, idx, stmt)
elseif is_known_call(stmt, (===), compact)
lift_comparison!(compact, idx, stmt, lifting_cache)
# elseif is_known_call(stmt, isa, compact)
# TODO do a similar optimization as `lift_comparison!` for `===`
end
end
continue
end

Expand All @@ -745,7 +772,7 @@ function sroa_pass!(ir::IRCode)
if ismutabletype(struct_typ)
isa(val, SSAValue) || continue
let intermediaries = SPCSet()
callback = function (@nospecialize(pi), @nospecialize(ssa))
callback = function (@nospecialize(x), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
Expand All @@ -755,7 +782,9 @@ function sroa_pass!(ir::IRCode)
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
end
mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse()))
mid, defuse = get!(defuses, def.id) do
SPCSet(), SSADefUse()
end
if is_setfield
push!(defuse.defs, idx)
else
Expand All @@ -777,7 +806,7 @@ function sroa_pass!(ir::IRCode)
isempty(leaves) && continue

result_t = argextype(SSAValue(idx), compact)
lifted_result = lift_leaves(compact, result_t, field, leaves)
lifted_result = lift_leaves!(compact, leaves, result_t, field, nested_loads)
lifted_result === nothing && continue
lifted_leaves, any_undef = lifted_result

Expand Down Expand Up @@ -813,18 +842,21 @@ function sroa_pass!(ir::IRCode)
used_ssas = copy(compact.used_ssas)
simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1)
ir = complete(compact)
sroa_mutables!(ir, defuses, used_ssas)
return ir
return sroa_mutables!(ir, defuses, used_ssas, nested_loads)
else
simple_dce!(compact)
return complete(compact)
end
end

function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int})
# initialization of domtree is delayed to avoid the expensive computation in many cases
local domtree = nothing
for (idx, (intermediaries, defuse)) in defuses
function sroa_mutables!(ir::IRCode,
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
nested_loads::NestedLoads)
domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
any_eliminated = false
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true)
intermediaries = collect(intermediaries)
# Check if there are any uses we did not account for. If so, the variable
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
Expand All @@ -839,7 +871,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
nleaves == nuses_total || continue
# Find the type for this allocation
defexpr = ir[SSAValue(idx)]
isexpr(defexpr, :new) || continue
isa(defexpr, Expr) || continue
if !isexpr(defexpr, :new)
if is_known_call(defexpr, getfield, ir)
val = defexpr.args[2]
if isa(val, SSAValue)
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
if ismutabletype(struct_typ)
record_nested_load!(nested_mloads, idx)
end
end
end
continue
end
newidx = idx
typ = ir.stmts[newidx][:type]
if isa(typ, UnionAll)
Expand Down Expand Up @@ -919,6 +963,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
# Now go through all uses and rewrite them
for stmt in du.uses
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
if !any_eliminated
any_eliminated |= (is_nested_load(nested_loads, stmt) ||
is_nested_load(nested_mloads, stmt))
end
end
if !isbitstype(ftyp)
if preserve_uses !== nothing
Expand Down Expand Up @@ -955,6 +1003,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse

@label skip
end
return any_eliminated ? sroa_pass!(compact!(ir), false) : ir
end

function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})
Expand Down
51 changes: 41 additions & 10 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[
struct ImmutableXYZ; x; y; z; end
mutable struct MutableXYZ; x; y; z; end

struct ImmutableOuter{T}; x::T; y::T; z::T; end
mutable struct MutableOuter{T}; x::T; y::T; z::T; end

# should optimize away very basic cases
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
Expand Down Expand Up @@ -198,9 +201,8 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y
@test any(isnew, src.code)
end

# should include a simple alias analysis
struct ImmutableOuter{T}; x::T; y::T; z::T; end
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
# alias analysis
# --------------
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
outer = ImmutableOuter(xyz, xyz, xyz)
Expand All @@ -227,9 +229,11 @@ let src = code_typed1((Any,Any,Any)) do x, y, z
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
end
end

# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well
# OK: mutable(immutable(...)) case
# FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until
# no nested mutable `getfield` call can be eliminated any further:
# it's probably not the most efficient option and we may want to introduce some sort of
# alias analysis that eliminates all the loads at once.
# mutable(immutable(...)) case
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
t = (xyz,)
Expand Down Expand Up @@ -260,21 +264,48 @@ let # this is a simple end to end test case, which demonstrates allocation elimi
# compiled code for `simple_sroa`, otherwise everything can be folded even without SROA
@test @allocated(simple_sroa(s)) == 0
end
# FIXME: immutable(mutable(...)) case
# immutable(mutable(...)) case
# An immutable `ImmutableXYZ` is stored in all three fields of a `MutableOuter`, and each
# field is then loaded back through a nested `getfield`.
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
outer = MutableOuter(xyz, xyz, xyz)
outer.x.x, outer.y.y, outer.z.z
end
@test_broken !any(isnew, src.code)
@test !any(isnew, src.code)
# the returned tuple is expected to be built directly from the three arguments
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
end
end
# FIXME: mutable(mutable(...)) case
# mutable(mutable(...)) case
# A mutable `MutableXYZ` is stored in all three fields of a `MutableOuter`, and each
# field is then loaded back through a nested `getfield`.
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
outer = MutableOuter(xyz, xyz, xyz)
outer.x.x, outer.y.y, outer.z.z
end
@test_broken !any(isnew, src.code)
@test !any(isnew, src.code)
# the returned tuple is expected to be built directly from the three arguments
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
end
end
# Three levels of nesting: `MutableXYZ` inside `MutableOuter` inside `MutableOuter`,
# read back through chains of three `getfield` calls.
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
inner = MutableOuter(xyz, xyz, xyz)
outer = MutableOuter(inner, inner, inner)
outer.x.x.x, outer.y.y.y, outer.z.z.z
end
# no statement matching `isnew` may remain in the optimized code
@test !any(isnew, src.code)
# the returned tuple is expected to be built directly from the three arguments
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
end
end
let # NOTE `sroa_mutables!` eliminates from the innermost definitions, so it should be able
# to fully eliminate this deeply nested example (ten `Ref`s wrapped around `x`,
# unwrapped again by ten `[]` loads)
src = code_typed1((Int,)) do x
(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x))))))))))))[][][][][][][][][][]
end
# no statement matching `isnew` may remain in the optimized code
@test !any(isnew, src.code)
end

# should work nicely with inlining to optimize away a complicated case
Expand Down