@@ -329,13 +329,32 @@ struct LiftedValue
329329end
330330const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}
331331
332+ # NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
333+ # which can sometimes be very large, and the program counters in question are often very sparse
334+ const SPCSet = IdSet{Int}
335+
# Lazily-allocated collection of program counters recorded as "nested loads".
# The `maybe` field stays `nothing` until the first entry is recorded, so that no
# set is allocated when nothing is ever recorded.
mutable struct NestedLoads
    maybe::Union{Nothing,SPCSet}
    function NestedLoads()
        return new(nothing)
    end
end
# Record `pc` in `nested_loads`, allocating the underlying set on first use.
# Returns the set that `pc` was pushed into.
function record_nested_load!(nested_loads::NestedLoads, pc::Int)
    pcs = nested_loads.maybe
    if pcs === nothing
        # first record: create the set and store it back for later calls
        pcs = SPCSet()
        nested_loads.maybe = pcs
    end
    return push!(pcs::SPCSet, pc)
end
# Return `true` if `pc` has been recorded in `nested_loads`, `false` otherwise
# (including the case where nothing has been recorded yet).
function is_nested_load(nested_loads::NestedLoads, pc::Int)
    pcs = nested_loads.maybe
    return pcs === nothing ? false : (pc in pcs::SPCSet)
end
350+
332351# try to compute lifted values that can replace a `getfield(x, field)` call
333352# where `x` is an immutable struct that is defined at any of `leaves`
334- function lift_leaves (compact:: IncrementalCompact ,
335- @nospecialize (result_t), field:: Int , leaves :: Vector{Any} )
353+ function lift_leaves! (compact:: IncrementalCompact , leaves :: Vector{Any} ,
354+ @nospecialize (result_t), field:: Int , nested_loads :: NestedLoads )
336355 # For every leaf, the lifted value
337356 lifted_leaves = LiftedLeaves ()
338- maybe_undef = false
357+ local maybe_undef = false
339358 for leaf in leaves
340359 cache_key = leaf
341360 if isa (leaf, AnySSAValue)
@@ -382,11 +401,19 @@ function lift_leaves(compact::IncrementalCompact,
382401 ocleaf = simple_walk (compact, ocleaf)
383402 end
384403 ocdef, _ = walk_to_def (compact, ocleaf)
385- if isexpr (ocdef, :new_opaque_closure ) && isa (field, Int) && 1 ≤ field ≤ length (ocdef. args)- 5
404+ if isexpr (ocdef, :new_opaque_closure ) && 1 ≤ field ≤ length (ocdef. args)- 5
386405 lift_arg! (compact, leaf, cache_key, ocdef, 5 + field, lifted_leaves)
387406 continue
388407 end
389408 return nothing
409+ elseif isa (def, Expr) && is_known_call (def, getfield, compact)
410+ if isa (leaf, SSAValue)
411+ struct_typ = unwrap_unionall (widenconst (compact_exprtype (compact, def. args[2 ])))
412+ if ismutabletype (struct_typ)
413+ record_nested_load! (nested_loads, leaf. id)
414+ end
415+ end
416+ return nothing
390417 else
391418 typ = compact_exprtype (compact, leaf)
392419 if ! isa (typ, Const)
@@ -611,10 +638,6 @@ function perform_lifting!(compact::IncrementalCompact,
611638 return stmt_val # N.B. should never happen
612639end
613640
614- # NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
615- # which can sometimes be very large, and the program counters in question are often very sparse
616- const SPCSet = IdSet{Int}
617-
618641"""
619642 sroa_pass!(ir::IRCode) -> newir::IRCode
620643
@@ -633,10 +656,11 @@ its argument).
633656In a case when all usages are fully eliminated, `struct` allocation may also be erased as
634657a result of succeeding dead code elimination.
635658"""
636- function sroa_pass! (ir:: IRCode )
659+ function sroa_pass! (ir:: IRCode , optional_opts :: Bool = true )
637660 compact = IncrementalCompact (ir)
638661 defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
639662 lifting_cache = IdDict {Pair{AnySSAValue, Any}, AnySSAValue} ()
663+ nested_loads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
640664 for ((_, idx), stmt) in compact
641665 # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
642666 isa (stmt, Expr) || continue
@@ -691,7 +715,9 @@ function sroa_pass!(ir::IRCode)
691715 if defuses === nothing
692716 defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
693717 end
694- mid, defuse = get! (defuses, defidx, (SPCSet (), SSADefUse ()))
718+ mid, defuse = get! (defuses, defidx) do
719+ SPCSet (), SSADefUse ()
720+ end
695721 push! (defuse. ccall_preserve_uses, idx)
696722 union! (mid, intermediaries)
697723 end
@@ -704,16 +730,17 @@ function sroa_pass!(ir::IRCode)
704730 compact[idx] = new_expr
705731 end
706732 continue
707- # TODO : This isn't the best place to put these
708- elseif is_known_call (stmt, typeassert, compact)
709- canonicalize_typeassert! (compact, idx, stmt)
710- continue
711- elseif is_known_call (stmt, (=== ), compact)
712- lift_comparison! (compact, idx, stmt, lifting_cache)
713- continue
714- # elseif is_known_call(stmt, isa, compact)
715- # TODO do a similar optimization as `lift_comparison!` for `===`
716733 else
734+ if optional_opts
735+ # TODO : This isn't the best place to put these
736+ if is_known_call (stmt, typeassert, compact)
737+ canonicalize_typeassert! (compact, idx, stmt)
738+ elseif is_known_call (stmt, (=== ), compact)
739+ lift_comparison! (compact, idx, stmt, lifting_cache)
740+ # elseif is_known_call(stmt, isa, compact)
741+ # TODO do a similar optimization as `lift_comparison!` for `===`
742+ end
743+ end
717744 continue
718745 end
719746
@@ -749,7 +776,9 @@ function sroa_pass!(ir::IRCode)
749776 if defuses === nothing
750777 defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
751778 end
752- mid, defuse = get! (defuses, def. id, (SPCSet (), SSADefUse ()))
779+ mid, defuse = get! (defuses, def. id) do
780+ SPCSet (), SSADefUse ()
781+ end
753782 if is_setfield
754783 push! (defuse. defs, idx)
755784 else
@@ -771,7 +800,7 @@ function sroa_pass!(ir::IRCode)
771800 isempty (leaves) && continue
772801
773802 result_t = compact_exprtype (compact, SSAValue (idx))
774- lifted_result = lift_leaves (compact, result_t, field, leaves )
803+ lifted_result = lift_leaves! (compact, leaves, result_t, field, nested_loads )
775804 lifted_result === nothing && continue
776805 lifted_leaves, any_undef = lifted_result
777806
@@ -807,20 +836,23 @@ function sroa_pass!(ir::IRCode)
807836 used_ssas = copy (compact. used_ssas)
808837 simple_dce! (compact, (x:: SSAValue ) -> used_ssas[x. id] -= 1 )
809838 ir = complete (compact)
810- sroa_mutables! (ir, defuses, used_ssas)
811- return ir
839+ return sroa_mutables! (ir, defuses, used_ssas, nested_loads)
812840 else
813841 simple_dce! (compact)
814842 return complete (compact)
815843 end
816844end
817845
818- function sroa_mutables! (ir:: IRCode , defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} )
846+ function sroa_mutables! (ir:: IRCode ,
847+ defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} ,
848+ nested_loads:: NestedLoads )
819849 # Compute domtree, needed below, now that we have finished compacting the IR.
820850 # This needs to be after we iterate through the IR with `IncrementalCompact`
821851 # because removing dead blocks can invalidate the domtree.
822852 @timeit " domtree 2" domtree = construct_domtree (ir. cfg. blocks)
823853
854+ nested_mloads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
855+ local any_eliminated = any_meliminated = false
824856 for (idx, (intermediaries, defuse)) in defuses
825857 intermediaries = collect (intermediaries)
826858 # Check if there are any uses we did not account for. If so, the variable
@@ -836,7 +868,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
836868 nleaves == nuses_total || continue
837869 # Find the type for this allocation
838870 defexpr = ir[SSAValue (idx)]
839- isexpr (defexpr, :new ) || continue
871+ isa (defexpr, Expr) || continue
872+ if ! isexpr (defexpr, :new )
873+ if is_known_call (defexpr, getfield, ir)
874+ val = defexpr. args[2 ]
875+ if isa (val, SSAValue)
876+ struct_typ = unwrap_unionall (widenconst (argextype (val, ir)))
877+ if ismutabletype (struct_typ)
878+ record_nested_load! (nested_mloads, idx)
879+ end
880+ end
881+ end
882+ continue
883+ end
840884 newidx = idx
841885 typ = ir. stmts[newidx][:type ]
842886 if isa (typ, UnionAll)
@@ -900,6 +944,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
900944 # Now go through all uses and rewrite them
901945 for stmt in du. uses
902946 ir[SSAValue (stmt)] = compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, stmt)
947+ if ! any_eliminated
948+ any_eliminated |= is_nested_load (nested_loads, stmt)
949+ end
950+ if ! any_meliminated
951+ any_meliminated |= is_nested_load (nested_mloads, stmt)
952+ end
903953 end
904954 if ! isbitstype (ftyp)
905955 if preserve_uses != = nothing
@@ -938,6 +988,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
938988
939989 @label skip
940990 end
991+ if any_eliminated || any_meliminated
992+ return sroa_pass! (compact! (ir), false )
993+ else
994+ return ir
995+ end
941996end
942997
943998"""
0 commit comments