@@ -169,7 +169,7 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
169169end
170170
171171function simple_walk (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
172- callback = (@nospecialize (pi ), @nospecialize (idx)) -> false )
172+ callback = (@nospecialize (x ), @nospecialize (idx)) -> false )
173173 while true
174174 if isa (defssa, OldSSAValue)
175175 if already_inserted (compact, defssa)
@@ -337,10 +337,29 @@ struct LiftedValue
337337end
338338const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}
339339
340+ # NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
341+ # which can be very large sometimes, and program counters in question are often very sparse
342+ const SPCSet = IdSet{Int}
343+
344+ mutable struct NestedLoads
345+ maybe:: Union{Nothing,SPCSet}
346+ NestedLoads () = new (nothing )
347+ end
348+ function record_nested_load! (nested_loads:: NestedLoads , pc:: Int )
349+ maybe = nested_loads. maybe
350+ maybe === nothing && (maybe = nested_loads. maybe = SPCSet ())
351+ push! (maybe:: SPCSet , pc)
352+ end
353+ function is_nested_load (nested_loads:: NestedLoads , pc:: Int )
354+ maybe = nested_loads. maybe
355+ maybe === nothing && return false
356+ return pc in maybe:: SPCSet
357+ end
358+
340359# try to compute lifted values that can replace `getfield(x, field)` call
341360# where `x` is an immutable struct that are defined at any of `leaves`
342- function lift_leaves (compact:: IncrementalCompact ,
343- @nospecialize (result_t), field:: Int , leaves :: Vector{Any} )
361+ function lift_leaves! (compact:: IncrementalCompact , leaves :: Vector{Any} ,
362+ @nospecialize (result_t), field:: Int , nested_loads :: NestedLoads )
344363 # For every leaf, the lifted value
345364 lifted_leaves = LiftedLeaves ()
346365 maybe_undef = false
@@ -390,11 +409,19 @@ function lift_leaves(compact::IncrementalCompact,
390409 ocleaf = simple_walk (compact, ocleaf)
391410 end
392411 ocdef, _ = walk_to_def (compact, ocleaf)
393- if isexpr (ocdef, :new_opaque_closure ) && isa (field, Int) && 1 ≤ field ≤ length (ocdef. args)- 5
412+ if isexpr (ocdef, :new_opaque_closure ) && 1 ≤ field ≤ length (ocdef. args)- 5
394413 lift_arg! (compact, leaf, cache_key, ocdef, 5 + field, lifted_leaves)
395414 continue
396415 end
397416 return nothing
417+ elseif is_known_call (def, getfield, compact)
418+ if isa (leaf, SSAValue)
419+ struct_typ = unwrap_unionall (widenconst (argextype (def. args[2 ], compact)))
420+ if ismutabletype (struct_typ)
421+ record_nested_load! (nested_loads, leaf. id)
422+ end
423+ end
424+ return nothing
398425 else
399426 typ = argextype (leaf, compact)
400427 if ! isa (typ, Const)
@@ -588,7 +615,7 @@ function perform_lifting!(compact::IncrementalCompact,
588615 end
589616 val = lifted_val. x
590617 if isa (val, AnySSAValue)
591- callback = (@nospecialize (pi ), @nospecialize (idx)) -> true
618+ callback = (@nospecialize (x ), @nospecialize (idx)) -> true
592619 val = simple_walk (compact, val, callback)
593620 end
594621 push! (new_node. values, val)
@@ -619,10 +646,6 @@ function perform_lifting!(compact::IncrementalCompact,
619646 return stmt_val # N.B. should never happen
620647end
621648
622- # NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
623- # which can be very large sometimes, and program counters in question are often very sparse
624- const SPCSet = IdSet{Int}
625-
626649"""
627650 sroa_pass!(ir::IRCode) -> newir::IRCode
628651
@@ -641,10 +664,11 @@ its argument).
641664In a case when all usages are fully eliminated, `struct` allocation may also be erased as
642665a result of succeeding dead code elimination.
643666"""
644- function sroa_pass! (ir:: IRCode )
667+ function sroa_pass! (ir:: IRCode , optional_opts :: Bool = true )
645668 compact = IncrementalCompact (ir)
646669 defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
647670 lifting_cache = IdDict {Pair{AnySSAValue, Any}, AnySSAValue} ()
671+ nested_loads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
648672 for ((_, idx), stmt) in compact
649673 # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
650674 isa (stmt, Expr) || continue
@@ -672,7 +696,7 @@ function sroa_pass!(ir::IRCode)
672696 preserved_arg = stmt. args[pidx]
673697 isa (preserved_arg, SSAValue) || continue
674698 let intermediaries = SPCSet ()
675- callback = function (@nospecialize (pi ), @nospecialize (ssa))
699+ callback = function (@nospecialize (x ), @nospecialize (ssa))
676700 push! (intermediaries, ssa. id)
677701 return false
678702 end
@@ -700,7 +724,9 @@ function sroa_pass!(ir::IRCode)
700724 if defuses === nothing
701725 defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
702726 end
703- mid, defuse = get! (defuses, defidx, (SPCSet (), SSADefUse ()))
727+ mid, defuse = get! (defuses, defidx) do
728+ SPCSet (), SSADefUse ()
729+ end
704730 push! (defuse. ccall_preserve_uses, idx)
705731 union! (mid, intermediaries)
706732 end
@@ -710,16 +736,17 @@ function sroa_pass!(ir::IRCode)
710736 compact[idx] = form_new_preserves (stmt, preserved, new_preserves)
711737 end
712738 continue
713- # TODO : This isn't the best place to put these
714- elseif is_known_call (stmt, typeassert, compact)
715- canonicalize_typeassert! (compact, idx, stmt)
716- continue
717- elseif is_known_call (stmt, (=== ), compact)
718- lift_comparison! (compact, idx, stmt, lifting_cache)
719- continue
720- # elseif is_known_call(stmt, isa, compact)
721- # TODO do a similar optimization as `lift_comparison!` for `===`
722739 else
740+ if optional_opts
741+ # TODO : This isn't the best place to put these
742+ if is_known_call (stmt, typeassert, compact)
743+ canonicalize_typeassert! (compact, idx, stmt)
744+ elseif is_known_call (stmt, (=== ), compact)
745+ lift_comparison! (compact, idx, stmt, lifting_cache)
746+ # elseif is_known_call(stmt, isa, compact)
747+ # TODO do a similar optimization as `lift_comparison!` for `===`
748+ end
749+ end
723750 continue
724751 end
725752
@@ -745,7 +772,7 @@ function sroa_pass!(ir::IRCode)
745772 if ismutabletype (struct_typ)
746773 isa (val, SSAValue) || continue
747774 let intermediaries = SPCSet ()
748- callback = function (@nospecialize (pi ), @nospecialize (ssa))
775+ callback = function (@nospecialize (x ), @nospecialize (ssa))
749776 push! (intermediaries, ssa. id)
750777 return false
751778 end
@@ -755,7 +782,9 @@ function sroa_pass!(ir::IRCode)
755782 if defuses === nothing
756783 defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
757784 end
758- mid, defuse = get! (defuses, def. id, (SPCSet (), SSADefUse ()))
785+ mid, defuse = get! (defuses, def. id) do
786+ SPCSet (), SSADefUse ()
787+ end
759788 if is_setfield
760789 push! (defuse. defs, idx)
761790 else
@@ -777,7 +806,7 @@ function sroa_pass!(ir::IRCode)
777806 isempty (leaves) && continue
778807
779808 result_t = argextype (SSAValue (idx), compact)
780- lifted_result = lift_leaves (compact, result_t, field, leaves )
809+ lifted_result = lift_leaves! (compact, leaves, result_t, field, nested_loads )
781810 lifted_result === nothing && continue
782811 lifted_leaves, any_undef = lifted_result
783812
@@ -813,18 +842,21 @@ function sroa_pass!(ir::IRCode)
813842 used_ssas = copy (compact. used_ssas)
814843 simple_dce! (compact, (x:: SSAValue ) -> used_ssas[x. id] -= 1 )
815844 ir = complete (compact)
816- sroa_mutables! (ir, defuses, used_ssas)
817- return ir
845+ return sroa_mutables! (ir, defuses, used_ssas, nested_loads)
818846 else
819847 simple_dce! (compact)
820848 return complete (compact)
821849 end
822850end
823851
824- function sroa_mutables! (ir:: IRCode , defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} )
825- # initialization of domtree is delayed to avoid the expensive computation in many cases
826- local domtree = nothing
827- for (idx, (intermediaries, defuse)) in defuses
852+ function sroa_mutables! (ir:: IRCode ,
853+ defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} ,
854+ nested_loads:: NestedLoads )
855+ domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases
856+ nested_mloads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
857+ any_eliminated = false
858+ # NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
859+ for (idx, (intermediaries, defuse)) in sort! (collect (defuses); by= first, rev= true )
828860 intermediaries = collect (intermediaries)
829861 # Check if there are any uses we did not account for. If so, the variable
830862 # escapes and we cannot eliminate the allocation. This works, because we're guaranteed
@@ -839,7 +871,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
839871 nleaves == nuses_total || continue
840872 # Find the type for this allocation
841873 defexpr = ir[SSAValue (idx)]
842- isexpr (defexpr, :new ) || continue
874+ isa (defexpr, Expr) || continue
875+ if ! isexpr (defexpr, :new )
876+ if is_known_call (defexpr, getfield, ir)
877+ val = defexpr. args[2 ]
878+ if isa (val, SSAValue)
879+ struct_typ = unwrap_unionall (widenconst (argextype (val, ir)))
880+ if ismutabletype (struct_typ)
881+ record_nested_load! (nested_mloads, idx)
882+ end
883+ end
884+ end
885+ continue
886+ end
843887 newidx = idx
844888 typ = ir. stmts[newidx][:type ]
845889 if isa (typ, UnionAll)
@@ -919,6 +963,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
919963 # Now go through all uses and rewrite them
920964 for stmt in du. uses
921965 ir[SSAValue (stmt)] = compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, stmt)
966+ if ! any_eliminated
967+ any_eliminated |= (is_nested_load (nested_loads, stmt) ||
968+ is_nested_load (nested_mloads, stmt))
969+ end
922970 end
923971 if ! isbitstype (ftyp)
924972 if preserve_uses != = nothing
@@ -955,6 +1003,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
9551003
9561004 @label skip
9571005 end
1006+ return any_eliminated ? sroa_pass! (compact! (ir), false ) : ir
9581007end
9591008
9601009function form_new_preserves (origex:: Expr , intermediates:: Vector{Int} , new_preserves:: Vector{Any} )
0 commit comments