@@ -29,12 +29,14 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])
2929
3030compute_live_ins (cfg:: CFG , du:: SSADefUse ) =  compute_live_ins (cfg, du. defs, du. uses)
3131
32- function  try_compute_field_stmt (ir:: Union{IncrementalCompact,IRCode} , stmt:: Expr )
33-     field =  stmt. args[3 ]
32+ try_compute_field_stmt (ir:: Union{IncrementalCompact,IRCode} , stmt:: Expr ) = 
33+     try_compute_field (ir, stmt. args[3 ])
34+ 
35+ function  try_compute_field (ir:: Union{IncrementalCompact,IRCode} , @nospecialize (field))
3436    #  fields are usually literals, handle them manually
3537    if  isa (field, QuoteNode)
3638        field =  field. value
37-     elseif  isa (field, Int)
39+     elseif  isa (field, Int)  ||   isa (field, Symbol) 
3840    #  try to resolve other constants, e.g. global reference
3941    else 
4042        field =  argextype (field, ir)
@@ -44,8 +46,7 @@ function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr
4446            return  nothing 
4547        end 
4648    end 
47-     isa (field, Union{Int, Symbol}) ||  return  nothing 
48-     return  field
49+     return  isa (field, Union{Int, Symbol}) ?  field :  nothing 
4950end 
5051
5152function  try_compute_fieldidx_stmt (ir:: Union{IncrementalCompact,IRCode} , stmt:: Expr , typ:: DataType )
@@ -167,7 +168,7 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
167168end 
168169
169170function  simple_walk (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=#  ),
170-                      callback =  (@nospecialize (pi ), @nospecialize (idx)) ->  false )
171+                      callback =  (@nospecialize (x ), @nospecialize (idx)) ->  false )
171172    while  true 
172173        if  isa (defssa, OldSSAValue)
173174            if  already_inserted (compact, defssa)
@@ -335,10 +336,29 @@ struct LiftedValue
335336end 
336337const  LiftedLeaves =  IdDict{Any, Union{Nothing,LiftedValue}}
337338
339+ #  NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
340+ #  which can be very large sometimes, and program counters in question are often very sparse
341+ const  SPCSet =  IdSet{Int}
342+ 
343+ mutable struct  NestedLoads
344+     maybe:: Union{Nothing,SPCSet} 
345+     NestedLoads () =  new (nothing )
346+ end 
347+ function  record_nested_load! (nested_loads:: NestedLoads , pc:: Int )
348+     maybe =  nested_loads. maybe
349+     maybe ===  nothing  &&  (maybe =  nested_loads. maybe =  SPCSet ())
350+     push! (maybe:: SPCSet , pc)
351+ end 
352+ function  is_nested_load (nested_loads:: NestedLoads , pc:: Int )
353+     maybe =  nested_loads. maybe
354+     maybe ===  nothing  &&  return  false 
355+     return  pc in  maybe:: SPCSet 
356+ end 
357+ 
338358#  try to compute lifted values that can replace `getfield(x, field)` call
339359#  where `x` is an immutable struct that are defined at any of `leaves`
340- function  lift_leaves (compact:: IncrementalCompact ,
341-                      @nospecialize (result_t), field:: Int , leaves :: Vector{Any} )
360+ function  lift_leaves!  (compact:: IncrementalCompact , leaves :: Vector{Any} ,
361+                        @nospecialize (result_t), field:: Int , nested_loads :: NestedLoads )
342362    #  For every leaf, the lifted value
343363    lifted_leaves =  LiftedLeaves ()
344364    maybe_undef =  false 
@@ -388,11 +408,19 @@ function lift_leaves(compact::IncrementalCompact,
388408                    ocleaf =  simple_walk (compact, ocleaf)
389409                end 
390410                ocdef, _ =  walk_to_def (compact, ocleaf)
391-                 if  isexpr (ocdef, :new_opaque_closure ) &&  isa (field, Int)  &&   1  ≤  field ≤  length (ocdef. args)- 5 
411+                 if  isexpr (ocdef, :new_opaque_closure ) &&  1  ≤  field ≤  length (ocdef. args)- 5 
392412                    lift_arg! (compact, leaf, cache_key, ocdef, 5 + field, lifted_leaves)
393413                    continue 
394414                end 
395415                return  nothing 
416+             elseif  is_known_call (def, getfield, compact)
417+                 if  isa (leaf, SSAValue)
418+                     struct_typ =  unwrap_unionall (widenconst (argextype (def. args[2 ], compact)))
419+                     if  ismutabletype (struct_typ)
420+                         record_nested_load! (nested_loads, leaf. id)
421+                     end 
422+                 end 
423+                 return  nothing 
396424            else 
397425                typ =  argextype (leaf, compact)
398426                if  ! isa (typ, Const)
@@ -586,7 +614,7 @@ function perform_lifting!(compact::IncrementalCompact,
586614                end 
587615                val =  lifted_val. x
588616                if  isa (val, AnySSAValue)
589-                     callback =  (@nospecialize (pi ), @nospecialize (idx)) ->  true 
617+                     callback =  (@nospecialize (x ), @nospecialize (idx)) ->  true 
590618                    val =  simple_walk (compact, val, callback)
591619                end 
592620                push! (new_node. values, val)
@@ -617,10 +645,6 @@ function perform_lifting!(compact::IncrementalCompact,
617645    return  stmt_val #  N.B. should never happen
618646end 
619647
620- #  NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
621- #  which can be very large sometimes, and program counters in question are often very sparse
622- const  SPCSet =  IdSet{Int}
623- 
624648""" 
625649    sroa_pass!(ir::IRCode) -> newir::IRCode 
626650
@@ -639,10 +663,11 @@ its argument).
639663In a case when all usages are fully eliminated, `struct` allocation may also be erased as 
640664a result of succeeding dead code elimination. 
641665""" 
642- function  sroa_pass! (ir:: IRCode )
666+ function  sroa_pass! (ir:: IRCode , optional_opts :: Bool   =   true )
643667    compact =  IncrementalCompact (ir)
644668    defuses =  nothing  #  will be initialized once we encounter mutability in order to reduce dynamic allocations
645669    lifting_cache =  IdDict {Pair{AnySSAValue, Any}, AnySSAValue} ()
670+     nested_loads =  NestedLoads () #  tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
646671    for  ((_, idx), stmt) in  compact
647672        #  check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
648673        isa (stmt, Expr) ||  continue 
@@ -670,7 +695,7 @@ function sroa_pass!(ir::IRCode)
670695                preserved_arg =  stmt. args[pidx]
671696                isa (preserved_arg, SSAValue) ||  continue 
672697                let  intermediaries =  SPCSet ()
673-                     callback =  function  (@nospecialize (pi ), @nospecialize (ssa))
698+                     callback =  function  (@nospecialize (x ), @nospecialize (ssa))
674699                        push! (intermediaries, ssa. id)
675700                        return  false 
676701                    end 
@@ -698,7 +723,9 @@ function sroa_pass!(ir::IRCode)
698723                    if  defuses ===  nothing 
699724                        defuses =  IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
700725                    end 
701-                     mid, defuse =  get! (defuses, defidx, (SPCSet (), SSADefUse ()))
726+                     mid, defuse =  get! (defuses, defidx) do 
727+                         SPCSet (), SSADefUse ()
728+                     end 
702729                    push! (defuse. ccall_preserve_uses, idx)
703730                    union! (mid, intermediaries)
704731                end 
@@ -708,16 +735,17 @@ function sroa_pass!(ir::IRCode)
708735                compact[idx] =  form_new_preserves (stmt, preserved, new_preserves)
709736            end 
710737            continue 
711-         #  TODO : This isn't the best place to put these
712-         elseif  is_known_call (stmt, typeassert, compact)
713-             canonicalize_typeassert! (compact, idx, stmt)
714-             continue 
715-         elseif  is_known_call (stmt, (=== ), compact)
716-             lift_comparison! (compact, idx, stmt, lifting_cache)
717-             continue 
718-         #  elseif is_known_call(stmt, isa, compact)
719-             #  TODO  do a similar optimization as `lift_comparison!` for `===`
720738        else 
739+             if  optional_opts
740+                 #  TODO : This isn't the best place to put these
741+                 if  is_known_call (stmt, typeassert, compact)
742+                     canonicalize_typeassert! (compact, idx, stmt)
743+                 elseif  is_known_call (stmt, (=== ), compact)
744+                     lift_comparison! (compact, idx, stmt, lifting_cache)
745+                 #  elseif is_known_call(stmt, isa, compact)
746+                     #  TODO  do a similar optimization as `lift_comparison!` for `===`
747+                 end 
748+             end 
721749            continue 
722750        end 
723751
@@ -743,7 +771,7 @@ function sroa_pass!(ir::IRCode)
743771        if  ismutabletype (struct_typ)
744772            isa (val, SSAValue) ||  continue 
745773            let  intermediaries =  SPCSet ()
746-                 callback =  function  (@nospecialize (pi ), @nospecialize (ssa))
774+                 callback =  function  (@nospecialize (x ), @nospecialize (ssa))
747775                    push! (intermediaries, ssa. id)
748776                    return  false 
749777                end 
@@ -753,7 +781,9 @@ function sroa_pass!(ir::IRCode)
753781                if  defuses ===  nothing 
754782                    defuses =  IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
755783                end 
756-                 mid, defuse =  get! (defuses, def. id, (SPCSet (), SSADefUse ()))
784+                 mid, defuse =  get! (defuses, def. id) do 
785+                     SPCSet (), SSADefUse ()
786+                 end 
757787                if  is_setfield
758788                    push! (defuse. defs, idx)
759789                else 
@@ -775,7 +805,7 @@ function sroa_pass!(ir::IRCode)
775805        isempty (leaves) &&  continue 
776806
777807        result_t =  argextype (SSAValue (idx), compact)
778-         lifted_result =  lift_leaves (compact, result_t, field, leaves )
808+         lifted_result =  lift_leaves!  (compact, leaves,  result_t, field, nested_loads )
779809        lifted_result ===  nothing  &&  continue 
780810        lifted_leaves, any_undef =  lifted_result
781811
@@ -811,18 +841,21 @@ function sroa_pass!(ir::IRCode)
811841        used_ssas =  copy (compact. used_ssas)
812842        simple_dce! (compact, (x:: SSAValue ) ->  used_ssas[x. id] -=  1 )
813843        ir =  complete (compact)
814-         sroa_mutables! (ir, defuses, used_ssas)
815-         return  ir
844+         return  sroa_mutables! (ir, defuses, used_ssas, nested_loads)
816845    else 
817846        simple_dce! (compact)
818847        return  complete (compact)
819848    end 
820849end 
821850
822- function  sroa_mutables! (ir:: IRCode , defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} )
823-     #  initialization of domtree is delayed to avoid the expensive computation in many cases
824-     local  domtree =  nothing 
825-     for  (idx, (intermediaries, defuse)) in  defuses
851+ function  sroa_mutables! (ir:: IRCode ,
852+     defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} ,
853+     nested_loads:: NestedLoads )
854+     local  domtree =  nothing  #  initialization of domtree is delayed to avoid the expensive computation in many cases
855+     nested_mloads =  NestedLoads () #  tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
856+     local  any_eliminated =  false 
857+     #  NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
858+     for  (idx, (intermediaries, defuse)) in  sort! (collect (defuses); by= first, rev= true )
826859        intermediaries =  collect (intermediaries)
827860        #  Check if there are any uses we did not account for. If so, the variable
828861        #  escapes and we cannot eliminate the allocation. This works, because we're guaranteed
@@ -837,7 +870,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
837870        nleaves ==  nuses_total ||  continue 
838871        #  Find the type for this allocation
839872        defexpr =  ir[SSAValue (idx)]
840-         isexpr (defexpr, :new ) ||  continue 
873+         isa (defexpr, Expr) ||  continue 
874+         if  ! isexpr (defexpr, :new )
875+             if  is_known_call (defexpr, getfield, ir)
876+                 val =  defexpr. args[2 ]
877+                 if  isa (val, SSAValue)
878+                     struct_typ =  unwrap_unionall (widenconst (argextype (val, ir)))
879+                     if  ismutabletype (struct_typ)
880+                         record_nested_load! (nested_mloads, idx)
881+                     end 
882+                 end 
883+             end 
884+             continue 
885+         end 
841886        newidx =  idx
842887        typ =  ir. stmts[newidx][:type ]
843888        if  isa (typ, UnionAll)
@@ -917,6 +962,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
917962                #  Now go through all uses and rewrite them
918963                for  stmt in  du. uses
919964                    ir[SSAValue (stmt)] =  compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, stmt)
965+                     if  ! any_eliminated
966+                         any_eliminated |=  (is_nested_load (nested_loads,  stmt) || 
967+                                            is_nested_load (nested_mloads, stmt))
968+                     end 
920969                end 
921970                if  ! isbitstype (ftyp)
922971                    if  preserve_uses != =  nothing 
@@ -953,6 +1002,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
9531002
9541003        @label  skip
9551004    end 
1005+     return  any_eliminated ?  sroa_pass! (compact! (ir), false ) :  ir
9561006end 
9571007
9581008function  form_new_preserves (origex:: Expr , intermediates:: Vector{Int} , new_preserves:: Vector{Any} )
0 commit comments