@@ -23,15 +23,15 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])
2323
2424compute_live_ins (cfg:: CFG , du:: SSADefUse ) = compute_live_ins (cfg, du. defs, du. uses)
2525
26- function try_compute_field_stmt (compact :: IncrementalCompact , stmt:: Expr )
26+ function try_compute_field_stmt (ir :: Union{ IncrementalCompact,IRCode} , stmt:: Expr )
2727 field = stmt. args[3 ]
2828 # fields are usually literals, handle them manually
2929 if isa (field, QuoteNode)
3030 field = field. value
3131 elseif isa (field, Int)
3232 # try to resolve other constants, e.g. global reference
3333 else
34- field = compact_exprtype (compact , field)
34+ field = isa (ir, IncrementalCompact) ? compact_exprtype (ir , field) : argextype (field, ir )
3535 if isa (field, Const)
3636 field = field. val
3737 else
@@ -42,8 +42,8 @@ function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
4242 return field
4343end
4444
45- function try_compute_fieldidx_stmt (compact :: IncrementalCompact , stmt:: Expr , typ:: DataType )
46- field = try_compute_field_stmt (compact , stmt)
45+ function try_compute_fieldidx_stmt (ir :: Union{ IncrementalCompact,IRCode} , stmt:: Expr , typ:: DataType )
46+ field = try_compute_field_stmt (ir , stmt)
4747 return try_compute_fieldidx (typ, field)
4848end
4949
@@ -112,6 +112,13 @@ function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
112112 return def, stmtblock, curblock
113113end
114114
115+ function collect_leaves (compact:: IncrementalCompact , @nospecialize (val), @nospecialize (typeconstraint))
116+ if isa (val, Union{OldSSAValue, SSAValue})
117+ val, typeconstraint = simple_walk_constraint (compact, val, typeconstraint)
118+ end
119+ return walk_to_defs (compact, val, typeconstraint)
120+ end
121+
115122function simple_walk (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
116123 callback = (@nospecialize (pi ), @nospecialize (idx)) -> false )
117124 while true
@@ -152,7 +159,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
152159end
153160
154161function simple_walk_constraint (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
155- @nospecialize (typeconstraint = types (compact)[defssa] ))
162+ @nospecialize (typeconstraint))
156163 callback = function (@nospecialize (pi ), @nospecialize (idx))
157164 if isa (pi , PiNode)
158165 typeconstraint = typeintersect (typeconstraint, widenconst (pi . typ))
@@ -164,20 +171,16 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss
164171end
165172
166173"""
167- walk_to_defs(compact, val, intermediaries )
174+ walk_to_defs(compact, val, typeconstraint )
168175
169- Starting at `val` walk use-def chains to get all the leaves feeding into
170- this val (pruning those leaves rules out by path conditions).
176+ Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
177+ (pruning those leaves rules out by path conditions).
171178"""
172- function walk_to_defs (compact:: IncrementalCompact , @nospecialize (defssa), @nospecialize (typeconstraint), visited_phinodes:: Vector{AnySSAValue} = AnySSAValue[])
173- isa (defssa, AnySSAValue) || return Any[defssa]
179+ function walk_to_defs (compact:: IncrementalCompact , @nospecialize (defssa), @nospecialize (typeconstraint))
180+ visited_phinodes = AnySSAValue[]
181+ isa (defssa, AnySSAValue) || return Any[defssa], visited_phinodes
174182 def = compact[defssa]
175- isa (def, PhiNode) || return Any[defssa]
176- # Step 2: Figure out what the struct is defined as
177- # # Track definitions through PiNode/PhiNode
178- found_def = false
179- # # Track which PhiNodes, SSAValue intermediaries
180- # # we forwarded through.
183+ isa (def, PhiNode) || return Any[defssa], visited_phinodes
181184 visited_constraints = IdDict {AnySSAValue, Any} ()
182185 worklist_defs = AnySSAValue[]
183186 worklist_constraints = Any[]
@@ -239,10 +242,10 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
239242 push! (leaves, defssa)
240243 end
241244 end
242- leaves
245+ return leaves, visited_phinodes
243246end
244247
245- function process_immutable_preserve (new_preserves:: Vector{Any} , compact:: IncrementalCompact , def:: Expr )
248+ function process_immutable_preserve! (new_preserves:: Vector{Any} , compact:: IncrementalCompact , def:: Expr )
246249 for arg in (isexpr (def, :new ) ? def. args : def. args[2 : end ])
247250 if ! isbitstype (widenconst (compact_exprtype (compact, arg)))
248251 push! (new_preserves, arg)
@@ -449,13 +452,10 @@ function lift_comparison!(compact::IncrementalCompact,
449452 return
450453 end
451454
452- if isa (val, Union{OldSSAValue, SSAValue})
453- val, typeconstraint = simple_walk_constraint (compact, val, typeconstraint)
454- end
455-
456- visited_phinodes = AnySSAValue[]
457- leaves = walk_to_defs (compact, val, typeconstraint, visited_phinodes)
455+ valtyp = widenconst (compact_exprtype (compact, val))
456+ isa (valtyp, Union) || return # bail out if there won't be a good chance for lifting
458457
458+ leaves, visited_phinodes = collect_leaves (compact, val, valtyp)
459459 length (leaves) ≤ 1 && return # bail out if we don't have multiple leaves
460460
461461 # Let's check if we evaluate the comparison for each one of the leaves
@@ -476,10 +476,6 @@ function lift_comparison!(compact::IncrementalCompact,
476476 visited_phinodes, cmp, lifting_cache, Bool,
477477 lifted_leaves:: IdDict{Any, Union{Nothing,LiftedValue}} , val):: LiftedValue
478478
479- # global assertion_counter
480- # assertion_counter::Int += 1
481- # insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true)
482- # return
483479 compact[idx] = lifted_val. x
484480end
485481
@@ -576,6 +572,10 @@ function perform_lifting!(compact::IncrementalCompact,
576572 return stmt_val # N.B. should never happen
577573end
578574
575+ # NOTE we use `IdSet{Int}` instead of `BitSet` for `sroa_pass!` since it works on IR after inlining,
576+ # which can be very large sometimes, and analyzed program counters are often very sparse
577+ const SPCSet = IdSet{Int}
578+
579579"""
580580 sroa_pass!(ir::IRCode) -> newir::IRCode
581581
@@ -596,17 +596,16 @@ a result of succeeding dead code elimination.
596596"""
597597function sroa_pass! (ir:: IRCode )
598598 compact = IncrementalCompact (ir)
599- defuses = IdDict {Int, Tuple{IdSet{Int}, SSADefUse}} ()
599+ defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
600600 lifting_cache = IdDict {Pair{AnySSAValue, Any}, AnySSAValue} ()
601601 for ((_, idx), stmt) in compact
602+ # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
602603 isa (stmt, Expr) || continue
603- result_t = compact_exprtype (compact, SSAValue (idx))
604604 is_setfield = false
605605 field_ordering = :unspecified
606- # Step 1: Check whether the statement we're looking at is a getfield/setfield!
607606 if is_known_call (stmt, setfield!, compact)
608- is_setfield = true
609607 4 <= length (stmt. args) <= 5 || continue
608+ is_setfield = true
610609 if length (stmt. args) == 5
611610 field_ordering = compact_exprtype (compact, stmt. args[5 ])
612611 end
@@ -624,7 +623,7 @@ function sroa_pass!(ir::IRCode)
624623 old_preserves = stmt. args[(6 + nccallargs): end ]
625624 for (pidx, preserved_arg) in enumerate (old_preserves)
626625 isa (preserved_arg, SSAValue) || continue
627- let intermediaries = IdSet {Int} ()
626+ let intermediaries = SPCSet ()
628627 callback = function (@nospecialize (pi ), @nospecialize (ssa))
629628 push! (intermediaries, ssa. id)
630629 return false
@@ -634,7 +633,7 @@ function sroa_pass!(ir::IRCode)
634633 defidx = def. id
635634 def = compact[defidx]
636635 if is_tuple_call (compact, def)
637- process_immutable_preserve (new_preserves, compact, def)
636+ process_immutable_preserve! (new_preserves, compact, def)
638637 old_preserves[pidx] = nothing
639638 continue
640639 elseif isexpr (def, :new )
@@ -643,14 +642,17 @@ function sroa_pass!(ir::IRCode)
643642 typ = unwrap_unionall (typ)
644643 end
645644 if typ isa DataType && ! ismutabletype (typ)
646- process_immutable_preserve (new_preserves, compact, def)
645+ process_immutable_preserve! (new_preserves, compact, def)
647646 old_preserves[pidx] = nothing
648647 continue
649648 end
650649 else
651650 continue
652651 end
653- mid, defuse = get! (defuses, defidx, (IdSet {Int} (), SSADefUse ()))
652+ if defuses === nothing
653+ defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
654+ end
655+ mid, defuse = get! (defuses, defidx, (SPCSet (), SSADefUse ()))
654656 push! (defuse. ccall_preserve_uses, idx)
655657 union! (mid, intermediaries)
656658 end
@@ -675,10 +677,15 @@ function sroa_pass!(ir::IRCode)
675677 else
676678 continue
677679 end
680+
681+ # analyze this `getfield` / `setfield!` call
682+
678683 field = try_compute_field_stmt (compact, stmt)
679684 field === nothing && continue
680685
681- struct_typ = unwrap_unionall (widenconst (compact_exprtype (compact, stmt. args[2 ])))
686+ val = stmt. args[2 ]
687+
688+ struct_typ = unwrap_unionall (widenconst (compact_exprtype (compact, val)))
682689 if isa (struct_typ, Union) && struct_typ <: Tuple
683690 struct_typ = unswitchtupleunion (struct_typ)
684691 end
@@ -689,19 +696,21 @@ function sroa_pass!(ir::IRCode)
689696 continue
690697 end
691698
692- def, typeconstraint = stmt. args[2 ], struct_typ
693-
699+ # analyze this mutable struct here for the later pass
694700 if ismutabletype (struct_typ)
695- isa (def , SSAValue) || continue
696- let intermediaries = IdSet {Int} ()
701+ isa (val , SSAValue) || continue
702+ let intermediaries = SPCSet ()
697703 callback = function (@nospecialize (pi ), @nospecialize (ssa))
698704 push! (intermediaries, ssa. id)
699705 return false
700706 end
701- def = simple_walk (compact, def , callback)
707+ def = simple_walk (compact, val , callback)
702708 # Mutable stuff here
703709 isa (def, SSAValue) || continue
704- mid, defuse = get! (defuses, def. id, (IdSet {Int} (), SSADefUse ()))
710+ if defuses === nothing
711+ defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
712+ end
713+ mid, defuse = get! (defuses, def. id, (SPCSet (), SSADefUse ()))
705714 if is_setfield
706715 push! (defuse. defs, idx)
707716 else
@@ -711,32 +720,28 @@ function sroa_pass!(ir::IRCode)
711720 end
712721 continue
713722 elseif is_setfield
714- continue
723+ continue # invalid `setfield!` call, but just ignore here
715724 end
716725
717726 # perform SROA on immutable structs here on
718727
719- if isa (def, Union{OldSSAValue, SSAValue})
720- def, typeconstraint = simple_walk_constraint (compact, def, typeconstraint)
721- end
722-
723- visited_phinodes = AnySSAValue[]
724- leaves = walk_to_defs (compact, def, typeconstraint, visited_phinodes)
725-
726- isempty (leaves) && continue
727-
728728 field = try_compute_fieldidx (struct_typ, field)
729729 field === nothing && continue
730730
731- r = lift_leaves (compact, result_t, field, leaves)
732- r === nothing && continue
733- lifted_leaves, any_undef = r
731+ leaves, visited_phinodes = collect_leaves (compact, val, struct_typ)
732+ isempty (leaves) && continue
733+
734+ result_t = compact_exprtype (compact, SSAValue (idx))
735+ lifted_result = lift_leaves (compact, result_t, field, leaves)
736+ lifted_result === nothing && continue
737+ lifted_leaves, any_undef = lifted_result
734738
735739 if any_undef
736740 result_t = make_MaybeUndef (result_t)
737741 end
738742
739- val = perform_lifting! (compact, visited_phinodes, field, lifting_cache, result_t, lifted_leaves, stmt. args[2 ])
743+ val = perform_lifting! (compact,
744+ visited_phinodes, field, lifting_cache, result_t, lifted_leaves, val)
740745
741746 # Insert the undef check if necessary
742747 if any_undef
@@ -750,28 +755,32 @@ function sroa_pass!(ir::IRCode)
750755 @assert val != = nothing
751756 end
752757
753- # global assertion_counter
754- # assertion_counter::Int += 1
755- # insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
756- # continue
757758 compact[idx] = val === nothing ? nothing : val. x
758759 end
759760
760761 non_dce_finish! (compact)
761- # Copy the use count, `simple_dce!` may modify it and for our predicate
762- # below we need it consistent with the state of the IR here (after tracking
763- # phi node arguments, but before dce).
764- used_ssas = copy (compact. used_ssas)
765- simple_dce! (compact)
766- ir = complete (compact)
767-
768- # Compute domtree, needed below, now that we have finished compacting the
769- # IR. This needs to be after we iterate through the IR with
770- # `IncrementalCompact` because removing dead blocks can invalidate the
771- # domtree.
762+ if defuses != = nothing
763+ # now go through analyzed mutable structs and see which ones we can eliminate
764+ # NOTE copy the use count here, because `simple_dce!` may modify it and we need it
765+ # consistent with the state of the IR here (after tracking `PhiNode` arguments,
766+ # but before the DCE) for our predicate within `sroa_mutables!`
767+ used_ssas = copy (compact. used_ssas)
768+ simple_dce! (compact)
769+ ir = complete (compact)
770+ sroa_mutables! (ir, defuses, used_ssas)
771+ return ir
772+ else
773+ simple_dce! (compact)
774+ return complete (compact)
775+ end
776+ end
777+
778+ function sroa_mutables! (ir:: IRCode , defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} )
779+ # Compute domtree, needed below, now that we have finished compacting the IR.
780+ # This needs to be after we iterate through the IR with `IncrementalCompact`
781+ # because removing dead blocks can invalidate the domtree.
772782 @timeit " domtree 2" domtree = construct_domtree (ir. cfg. blocks)
773783
774- # Now go through any mutable structs and see which ones we can eliminate
775784 for (idx, (intermediaries, defuse)) in defuses
776785 intermediaries = collect (intermediaries)
777786 # Check if there are any uses we did not account for. If so, the variable
@@ -806,12 +815,12 @@ function sroa_pass!(ir::IRCode)
806815 # it would have been deleted. That's fine, just ignore
807816 # the use in that case.
808817 stmt === nothing && continue
809- field = try_compute_fieldidx_stmt (compact , stmt:: Expr , typ)
818+ field = try_compute_fieldidx_stmt (ir , stmt:: Expr , typ)
810819 field === nothing && @goto skip
811820 push! (fielddefuse[field]. uses, use)
812821 end
813822 for use in defuse. defs
814- field = try_compute_fieldidx_stmt (compact , ir[SSAValue (use)]:: Expr , typ)
823+ field = try_compute_fieldidx_stmt (ir , ir[SSAValue (use)]:: Expr , typ)
815824 field === nothing && @goto skip
816825 push! (fielddefuse[field]. defs, use)
817826 end
@@ -846,8 +855,9 @@ function sroa_pass!(ir::IRCode)
846855 end
847856 end
848857 end
849- preserve_uses = IdDict {Int, Vector{Any}} ((idx=> Any[] for idx in IdSet {Int} (defuse. ccall_preserve_uses)))
850858 # Everything accounted for. Go field by field and perform idf
859+ preserve_uses = isempty (defuse. ccall_preserve_uses) ? nothing :
860+ IdDict {Int, Vector{Any}} ((idx=> Any[] for idx in SPCSet (defuse. ccall_preserve_uses)))
851861 for fidx in 1 : ndefuse
852862 du = fielddefuse[fidx]
853863 ftyp = fieldtype (typ, fidx)
@@ -863,8 +873,10 @@ function sroa_pass!(ir::IRCode)
863873 ir[SSAValue (stmt)] = compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, stmt)
864874 end
865875 if ! isbitstype (ftyp)
866- for (use, list) in preserve_uses
867- push! (list, compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, use))
876+ if preserve_uses != = nothing
877+ for (use, list) in preserve_uses
878+ push! (list, compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, use))
879+ end
868880 end
869881 end
870882 for b in phiblocks
@@ -881,7 +893,7 @@ function sroa_pass!(ir::IRCode)
881893 ir[SSAValue (stmt)] = nothing
882894 end
883895 end
884- isempty (defuse . ccall_preserve_uses) && continue
896+ preserve_uses === nothing && continue
885897 push! (intermediaries, newidx)
886898 # Insert the new preserves
887899 for (use, new_preserves) in preserve_uses
@@ -897,10 +909,7 @@ function sroa_pass!(ir::IRCode)
897909
898910 @label skip
899911 end
900-
901- return ir
902912end
903- # assertion_counter = 0
904913
905914"""
906915 canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::Expr)
0 commit comments