Skip to content

Commit fb5e715

Browse files
committed
optimizer: run SROA multiple times to handle more nested loads
1 parent 30fe8cc commit fb5e715

File tree

2 files changed

+114
-35
lines changed

2 files changed

+114
-35
lines changed

base/compiler/ssair/passes.jl

Lines changed: 80 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,32 @@ struct LiftedValue
329329
end
330330
const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}
331331

332+
# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
333+
# which can be very large sometimes, and program counters in question are often very sparse
334+
const SPCSet = IdSet{Int}
335+
336+
mutable struct NestedLoads
337+
maybe::Union{Nothing,SPCSet}
338+
NestedLoads() = new(nothing)
339+
end
340+
function record_nested_load!(nested_loads::NestedLoads, pc::Int)
341+
maybe = nested_loads.maybe
342+
maybe === nothing && (maybe = nested_loads.maybe = SPCSet())
343+
push!(maybe::SPCSet, pc)
344+
end
345+
function is_nested_load(nested_loads::NestedLoads, pc::Int)
346+
maybe = nested_loads.maybe
347+
maybe === nothing && return false
348+
return pc in maybe::SPCSet
349+
end
350+
332351
# try to compute lifted values that can replace `getfield(x, field)` call
333352
# where `x` is an immutable struct that are defined at any of `leaves`
334-
function lift_leaves(compact::IncrementalCompact,
335-
@nospecialize(result_t), field::Int, leaves::Vector{Any})
353+
function lift_leaves!(compact::IncrementalCompact, leaves::Vector{Any},
354+
@nospecialize(result_t), field::Int, nested_loads::NestedLoads)
336355
# For every leaf, the lifted value
337356
lifted_leaves = LiftedLeaves()
338-
maybe_undef = false
357+
local maybe_undef = false
339358
for leaf in leaves
340359
cache_key = leaf
341360
if isa(leaf, AnySSAValue)
@@ -382,11 +401,19 @@ function lift_leaves(compact::IncrementalCompact,
382401
ocleaf = simple_walk(compact, ocleaf)
383402
end
384403
ocdef, _ = walk_to_def(compact, ocleaf)
385-
if isexpr(ocdef, :new_opaque_closure) && isa(field, Int) && 1 field length(ocdef.args)-5
404+
if isexpr(ocdef, :new_opaque_closure) && 1 field length(ocdef.args)-5
386405
lift_arg!(compact, leaf, cache_key, ocdef, 5+field, lifted_leaves)
387406
continue
388407
end
389408
return nothing
409+
elseif isa(def, Expr) && is_known_call(def, getfield, compact)
410+
if isa(leaf, SSAValue)
411+
struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, def.args[2])))
412+
if ismutabletype(struct_typ)
413+
record_nested_load!(nested_loads, leaf.id)
414+
end
415+
end
416+
return nothing
390417
else
391418
typ = compact_exprtype(compact, leaf)
392419
if !isa(typ, Const)
@@ -611,10 +638,6 @@ function perform_lifting!(compact::IncrementalCompact,
611638
return stmt_val # N.B. should never happen
612639
end
613640

614-
# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
615-
# which can be very large sometimes, and program counters in question are often very sparse
616-
const SPCSet = IdSet{Int}
617-
618641
"""
619642
sroa_pass!(ir::IRCode) -> newir::IRCode
620643
@@ -633,10 +656,11 @@ its argument).
633656
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
634657
a result of succeeding dead code elimination.
635658
"""
636-
function sroa_pass!(ir::IRCode)
659+
function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
637660
compact = IncrementalCompact(ir)
638661
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
639662
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
663+
nested_loads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
640664
for ((_, idx), stmt) in compact
641665
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
642666
isa(stmt, Expr) || continue
@@ -691,7 +715,9 @@ function sroa_pass!(ir::IRCode)
691715
if defuses === nothing
692716
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
693717
end
694-
mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse()))
718+
mid, defuse = get!(defuses, defidx) do
719+
SPCSet(), SSADefUse()
720+
end
695721
push!(defuse.ccall_preserve_uses, idx)
696722
union!(mid, intermediaries)
697723
end
@@ -704,16 +730,17 @@ function sroa_pass!(ir::IRCode)
704730
compact[idx] = new_expr
705731
end
706732
continue
707-
# TODO: This isn't the best place to put these
708-
elseif is_known_call(stmt, typeassert, compact)
709-
canonicalize_typeassert!(compact, idx, stmt)
710-
continue
711-
elseif is_known_call(stmt, (===), compact)
712-
lift_comparison!(compact, idx, stmt, lifting_cache)
713-
continue
714-
# elseif is_known_call(stmt, isa, compact)
715-
# TODO do a similar optimization as `lift_comparison!` for `===`
716733
else
734+
if optional_opts
735+
# TODO: This isn't the best place to put these
736+
if is_known_call(stmt, typeassert, compact)
737+
canonicalize_typeassert!(compact, idx, stmt)
738+
elseif is_known_call(stmt, (===), compact)
739+
lift_comparison!(compact, idx, stmt, lifting_cache)
740+
# elseif is_known_call(stmt, isa, compact)
741+
# TODO do a similar optimization as `lift_comparison!` for `===`
742+
end
743+
end
717744
continue
718745
end
719746

@@ -749,7 +776,9 @@ function sroa_pass!(ir::IRCode)
749776
if defuses === nothing
750777
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
751778
end
752-
mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse()))
779+
mid, defuse = get!(defuses, def.id) do
780+
SPCSet(), SSADefUse()
781+
end
753782
if is_setfield
754783
push!(defuse.defs, idx)
755784
else
@@ -771,7 +800,7 @@ function sroa_pass!(ir::IRCode)
771800
isempty(leaves) && continue
772801

773802
result_t = compact_exprtype(compact, SSAValue(idx))
774-
lifted_result = lift_leaves(compact, result_t, field, leaves)
803+
lifted_result = lift_leaves!(compact, leaves, result_t, field, nested_loads)
775804
lifted_result === nothing && continue
776805
lifted_leaves, any_undef = lifted_result
777806

@@ -807,20 +836,23 @@ function sroa_pass!(ir::IRCode)
807836
used_ssas = copy(compact.used_ssas)
808837
simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1)
809838
ir = complete(compact)
810-
sroa_mutables!(ir, defuses, used_ssas)
811-
return ir
839+
return sroa_mutables!(ir, defuses, used_ssas, nested_loads)
812840
else
813841
simple_dce!(compact)
814842
return complete(compact)
815843
end
816844
end
817845

818-
function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int})
846+
function sroa_mutables!(ir::IRCode,
847+
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
848+
nested_loads::NestedLoads)
819849
# Compute domtree, needed below, now that we have finished compacting the IR.
820850
# This needs to be after we iterate through the IR with `IncrementalCompact`
821851
# because removing dead blocks can invalidate the domtree.
822852
@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)
823853

854+
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
855+
local any_eliminated = any_meliminated = false
824856
for (idx, (intermediaries, defuse)) in defuses
825857
intermediaries = collect(intermediaries)
826858
# Check if there are any uses we did not account for. If so, the variable
@@ -836,7 +868,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
836868
nleaves == nuses_total || continue
837869
# Find the type for this allocation
838870
defexpr = ir[SSAValue(idx)]
839-
isexpr(defexpr, :new) || continue
871+
isa(defexpr, Expr) || continue
872+
if !isexpr(defexpr, :new)
873+
if is_known_call(defexpr, getfield, ir)
874+
val = defexpr.args[2]
875+
if isa(val, SSAValue)
876+
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
877+
if ismutabletype(struct_typ)
878+
record_nested_load!(nested_mloads, idx)
879+
end
880+
end
881+
end
882+
continue
883+
end
840884
newidx = idx
841885
typ = ir.stmts[newidx][:type]
842886
if isa(typ, UnionAll)
@@ -900,6 +944,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
900944
# Now go through all uses and rewrite them
901945
for stmt in du.uses
902946
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
947+
if !any_eliminated
948+
any_eliminated |= is_nested_load(nested_loads, stmt)
949+
end
950+
if !any_meliminated
951+
any_meliminated |= is_nested_load(nested_mloads, stmt)
952+
end
903953
end
904954
if !isbitstype(ftyp)
905955
if preserve_uses !== nothing
@@ -938,6 +988,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
938988

939989
@label skip
940990
end
991+
if any_eliminated || any_meliminated
992+
return sroa_pass!(compact!(ir), false)
993+
else
994+
return ir
995+
end
941996
end
942997

943998
"""

test/compiler/irpasses.jl

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[
9090
struct ImmutableXYZ; x; y; z; end
9191
mutable struct MutableXYZ; x; y; z; end
9292

93+
struct ImmutableOuter{T}; x::T; y::T; z::T; end
94+
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
95+
9396
# should optimize away very basic cases
9497
let src = code_typed1((Any,Any,Any)) do x, y, z
9598
xyz = ImmutableXYZ(x, y, z)
@@ -198,9 +201,8 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y
198201
@test any(isnew, src.code)
199202
end
200203

201-
# should include a simple alias analysis
202-
struct ImmutableOuter{T}; x::T; y::T; z::T; end
203-
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
204+
# alias analysis
205+
# --------------
204206
let src = code_typed1((Any,Any,Any)) do x, y, z
205207
xyz = ImmutableXYZ(x, y, z)
206208
outer = ImmutableOuter(xyz, xyz, xyz)
@@ -227,9 +229,11 @@ let src = code_typed1((Any,Any,Any)) do x, y, z
227229
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
228230
end
229231
end
230-
231-
# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well
232-
# OK: mutable(immutable(...)) case
232+
# FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until
233+
# any nested mutable `getfield` calls become no longer eliminatable:
234+
# it's probably not the most efficient option and we may want to introduce some sort of
235+
# alias analysis and eliminates all the loads at once.
236+
# mutable(immutable(...)) case
233237
let src = code_typed1((Any,Any,Any)) do x, y, z
234238
xyz = MutableXYZ(x, y, z)
235239
t = (xyz,)
@@ -260,21 +264,41 @@ let # this is a simple end to end test case, which demonstrates allocation elimi
260264
# compiled code for `simple_sroa`, otherwise everything can be folded even without SROA
261265
@test @allocated(simple_sroa(s)) == 0
262266
end
263-
# FIXME: immutable(mutable(...)) case
267+
# immutable(mutable(...)) case
264268
let src = code_typed1((Any,Any,Any)) do x, y, z
265269
xyz = ImmutableXYZ(x, y, z)
266270
outer = MutableOuter(xyz, xyz, xyz)
267271
outer.x.x, outer.y.y, outer.z.z
268272
end
269-
@test_broken !any(isnew, src.code)
273+
@test !any(isnew, src.code)
274+
@test any(src.code) do @nospecialize x
275+
iscall((src, tuple), x) &&
276+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
277+
end
270278
end
271-
# FIXME: mutable(mutable(...)) case
279+
# mutable(mutable(...)) case
272280
let src = code_typed1((Any,Any,Any)) do x, y, z
273281
xyz = MutableXYZ(x, y, z)
274282
outer = MutableOuter(xyz, xyz, xyz)
275283
outer.x.x, outer.y.y, outer.z.z
276284
end
277-
@test_broken !any(isnew, src.code)
285+
@test !any(isnew, src.code)
286+
@test any(src.code) do @nospecialize x
287+
iscall((src, tuple), x) &&
288+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
289+
end
290+
end
291+
let src = code_typed1((Any,Any,Any)) do x, y, z
292+
xyz = MutableXYZ(x, y, z)
293+
inner = MutableOuter(xyz, xyz, xyz)
294+
outer = MutableOuter(inner, inner, inner)
295+
outer.x.x.x, outer.y.y.y, outer.z.z.z
296+
end
297+
@test !any(isnew, src.code)
298+
@test any(src.code) do @nospecialize x
299+
iscall((src, tuple), x) &&
300+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
301+
end
278302
end
279303

280304
# should work nicely with inlining to optimize away a complicated case

0 commit comments

Comments
 (0)