@@ -23,15 +23,15 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])
2323
2424compute_live_ins (cfg:: CFG , du:: SSADefUse ) = compute_live_ins (cfg, du. defs, du. uses)
2525
26- function try_compute_field_stmt (compact :: IncrementalCompact , stmt:: Expr )
26+ function try_compute_field_stmt (ir :: Union{ IncrementalCompact,IRCode} , stmt:: Expr )
2727 field = stmt. args[3 ]
2828 # fields are usually literals, handle them manually
2929 if isa (field, QuoteNode)
3030 field = field. value
3131 elseif isa (field, Int)
3232 # try to resolve other constants, e.g. global reference
3333 else
34- field = compact_exprtype (compact , field)
34+ field = isa (ir, IncrementalCompact) ? compact_exprtype (ir , field) : argextype (field, ir )
3535 if isa (field, Const)
3636 field = field. val
3737 else
@@ -42,8 +42,8 @@ function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
4242 return field
4343end
4444
45- function try_compute_fieldidx_stmt (compact :: IncrementalCompact , stmt:: Expr , typ:: DataType )
46- field = try_compute_field_stmt (compact , stmt)
45+ function try_compute_fieldidx_stmt (ir :: Union{ IncrementalCompact,IRCode} , stmt:: Expr , typ:: DataType )
46+ field = try_compute_field_stmt (ir , stmt)
4747 return try_compute_fieldidx (typ, field)
4848end
4949
@@ -112,6 +112,13 @@ function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
112112 return def, stmtblock, curblock
113113end
114114
115+ function collect_leaves (compact:: IncrementalCompact , @nospecialize (val), @nospecialize (typeconstraint))
116+ if isa (val, Union{OldSSAValue, SSAValue})
117+ val, typeconstraint = simple_walk_constraint (compact, val, typeconstraint)
118+ end
119+ return walk_to_defs (compact, val, typeconstraint)
120+ end
121+
115122function simple_walk (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
116123 callback = (@nospecialize (pi ), @nospecialize (idx)) -> false )
117124 while true
@@ -152,7 +159,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
152159end
153160
154161function simple_walk_constraint (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
155- @nospecialize (typeconstraint = types (compact)[defssa] ))
162+ @nospecialize (typeconstraint))
156163 callback = function (@nospecialize (pi ), @nospecialize (idx))
157164 if isa (pi , PiNode)
158165 typeconstraint = typeintersect (typeconstraint, widenconst (pi . typ))
@@ -164,20 +171,16 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss
164171end
165172
166173"""
167- walk_to_defs(compact, val, intermediaries )
174+ walk_to_defs(compact, val, typeconstraint )
168175
169- Starting at `val` walk use-def chains to get all the leaves feeding into
170- this val (pruning those leaves rules out by path conditions).
176+ Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
177+ (pruning those leaves rules out by path conditions).
171178"""
172- function walk_to_defs (compact:: IncrementalCompact , @nospecialize (defssa), @nospecialize (typeconstraint), visited_phinodes:: Vector{AnySSAValue} = AnySSAValue[])
173- isa (defssa, AnySSAValue) || return Any[defssa]
179+ function walk_to_defs (compact:: IncrementalCompact , @nospecialize (defssa), @nospecialize (typeconstraint))
180+ visited_phinodes = AnySSAValue[]
181+ isa (defssa, AnySSAValue) || return Any[defssa], visited_phinodes
174182 def = compact[defssa]
175- isa (def, PhiNode) || return Any[defssa]
176- # Step 2: Figure out what the struct is defined as
177- # # Track definitions through PiNode/PhiNode
178- found_def = false
179- # # Track which PhiNodes, SSAValue intermediaries
180- # # we forwarded through.
183+ isa (def, PhiNode) || return Any[defssa], visited_phinodes
181184 visited_constraints = IdDict {AnySSAValue, Any} ()
182185 worklist_defs = AnySSAValue[]
183186 worklist_constraints = Any[]
@@ -239,10 +242,10 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
239242 push! (leaves, defssa)
240243 end
241244 end
242- leaves
245+ return leaves, visited_phinodes
243246end
244247
245- function process_immutable_preserve (new_preserves:: Vector{Any} , compact:: IncrementalCompact , def:: Expr )
248+ function process_immutable_preserve! (new_preserves:: Vector{Any} , compact:: IncrementalCompact , def:: Expr )
246249 for arg in (isexpr (def, :new ) ? def. args : def. args[2 : end ])
247250 if ! isbitstype (widenconst (compact_exprtype (compact, arg)))
248251 push! (new_preserves, arg)
@@ -449,13 +452,10 @@ function lift_comparison!(compact::IncrementalCompact,
449452 return
450453 end
451454
452- if isa (val, Union{OldSSAValue, SSAValue})
453- val, typeconstraint = simple_walk_constraint (compact, val, typeconstraint)
454- end
455-
456- visited_phinodes = AnySSAValue[]
457- leaves = walk_to_defs (compact, val, typeconstraint, visited_phinodes)
455+ valtyp = widenconst (compact_exprtype (compact, val))
456+ isa (valtyp, Union) || return # bail out if there won't be a good chance for lifting
458457
458+ leaves, visited_phinodes = collect_leaves (compact, val, valtyp)
459459 length (leaves) ≤ 1 && return # bail out if we don't have multiple leaves
460460
461461 # Let's check if we evaluate the comparison for each one of the leaves
@@ -596,17 +596,16 @@ a result of succeeding dead code elimination.
596596"""
597597function sroa_pass! (ir:: IRCode )
598598 compact = IncrementalCompact (ir)
599- defuses = IdDict {Int, Tuple{IdSet{Int}, SSADefUse}} ()
599+ defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
600600 lifting_cache = IdDict {Pair{AnySSAValue, Any}, AnySSAValue} ()
601601 for ((_, idx), stmt) in compact
602+ # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
602603 isa (stmt, Expr) || continue
603- result_t = compact_exprtype (compact, SSAValue (idx))
604604 is_setfield = false
605605 field_ordering = :unspecified
606- # Step 1: Check whether the statement we're looking at is a getfield/setfield!
607606 if is_known_call (stmt, setfield!, compact)
608- is_setfield = true
609607 4 <= length (stmt. args) <= 5 || continue
608+ is_setfield = true
610609 if length (stmt. args) == 5
611610 field_ordering = compact_exprtype (compact, stmt. args[5 ])
612611 end
@@ -624,7 +623,7 @@ function sroa_pass!(ir::IRCode)
624623 old_preserves = stmt. args[(6 + nccallargs): end ]
625624 for (pidx, preserved_arg) in enumerate (old_preserves)
626625 isa (preserved_arg, SSAValue) || continue
627- let intermediaries = IdSet {Int} ()
626+ let intermediaries = BitSet ()
628627 callback = function (@nospecialize (pi ), @nospecialize (ssa))
629628 push! (intermediaries, ssa. id)
630629 return false
@@ -634,7 +633,7 @@ function sroa_pass!(ir::IRCode)
634633 defidx = def. id
635634 def = compact[defidx]
636635 if is_tuple_call (compact, def)
637- process_immutable_preserve (new_preserves, compact, def)
636+ process_immutable_preserve! (new_preserves, compact, def)
638637 old_preserves[pidx] = nothing
639638 continue
640639 elseif isexpr (def, :new )
@@ -643,14 +642,17 @@ function sroa_pass!(ir::IRCode)
643642 typ = unwrap_unionall (typ)
644643 end
645644 if typ isa DataType && ! ismutabletype (typ)
646- process_immutable_preserve (new_preserves, compact, def)
645+ process_immutable_preserve! (new_preserves, compact, def)
647646 old_preserves[pidx] = nothing
648647 continue
649648 end
650649 else
651650 continue
652651 end
653- mid, defuse = get! (defuses, defidx, (IdSet {Int} (), SSADefUse ()))
652+ if defuses === nothing
653+ defuses = IdDict {Int, Tuple{BitSet, SSADefUse}} ()
654+ end
655+ mid, defuse = get! (defuses, defidx, (BitSet (), SSADefUse ()))
654656 push! (defuse. ccall_preserve_uses, idx)
655657 union! (mid, intermediaries)
656658 end
@@ -675,6 +677,9 @@ function sroa_pass!(ir::IRCode)
675677 else
676678 continue
677679 end
680+
681+ # analyze this `getfield` / `setfield!` call
682+
678683 field = try_compute_field_stmt (compact, stmt)
679684 field === nothing && continue
680685
@@ -689,19 +694,23 @@ function sroa_pass!(ir::IRCode)
689694 continue
690695 end
691696
692- def, typeconstraint = stmt. args[2 ], struct_typ
697+ val = stmt. args[2 ]
693698
699+ # analyze this mutable struct here for the later pass
694700 if ismutabletype (struct_typ)
695- isa (def , SSAValue) || continue
696- let intermediaries = IdSet {Int} ()
701+ isa (val , SSAValue) || continue
702+ let intermediaries = BitSet ()
697703 callback = function (@nospecialize (pi ), @nospecialize (ssa))
698704 push! (intermediaries, ssa. id)
699705 return false
700706 end
701- def = simple_walk (compact, def , callback)
707+ def = simple_walk (compact, val , callback)
702708 # Mutable stuff here
703709 isa (def, SSAValue) || continue
704- mid, defuse = get! (defuses, def. id, (IdSet {Int} (), SSADefUse ()))
710+ if defuses === nothing
711+ defuses = IdDict {Int, Tuple{BitSet, SSADefUse}} ()
712+ end
713+ mid, defuse = get! (defuses, def. id, (BitSet (), SSADefUse ()))
705714 if is_setfield
706715 push! (defuse. defs, idx)
707716 else
@@ -711,26 +720,21 @@ function sroa_pass!(ir::IRCode)
711720 end
712721 continue
713722 elseif is_setfield
714- continue
723+ continue # invalid `setfield!` call, but just ignore here
715724 end
716725
717726 # perform SROA on immutable structs here on
718727
719- if isa (def, Union{OldSSAValue, SSAValue})
720- def, typeconstraint = simple_walk_constraint (compact, def, typeconstraint)
721- end
722-
723- visited_phinodes = AnySSAValue[]
724- leaves = walk_to_defs (compact, def, typeconstraint, visited_phinodes)
725-
726- isempty (leaves) && continue
727-
728728 field = try_compute_fieldidx (struct_typ, field)
729729 field === nothing && continue
730730
731- r = lift_leaves (compact, result_t, field, leaves)
732- r === nothing && continue
733- lifted_leaves, any_undef = r
731+ leaves, visited_phinodes = collect_leaves (compact, val, struct_typ)
732+ isempty (leaves) && continue
733+
734+ result_t = compact_exprtype (compact, SSAValue (idx))
735+ lifted_result = lift_leaves (compact, result_t, field, leaves)
736+ lifted_result === nothing && continue
737+ lifted_leaves, any_undef = lifted_result
734738
735739 if any_undef
736740 result_t = make_MaybeUndef (result_t)
@@ -750,28 +754,32 @@ function sroa_pass!(ir::IRCode)
750754 @assert val != = nothing
751755 end
752756
753- # global assertion_counter
754- # assertion_counter::Int += 1
755- # insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
756- # continue
757757 compact[idx] = val === nothing ? nothing : val. x
758758 end
759759
760760 non_dce_finish! (compact)
761- # Copy the use count, `simple_dce!` may modify it and for our predicate
762- # below we need it consistent with the state of the IR here (after tracking
763- # phi node arguments, but before dce).
764- used_ssas = copy (compact. used_ssas)
765- simple_dce! (compact)
766- ir = complete (compact)
767-
768- # Compute domtree, needed below, now that we have finished compacting the
769- # IR. This needs to be after we iterate through the IR with
770- # `IncrementalCompact` because removing dead blocks can invalidate the
771- # domtree.
761+ if defuses != = nothing
762+ # now go through analyzed mutable structs and see which ones we can eliminate
763+ # NOTE copy the use count here, because `simple_dce!` may modify it and we need it
764+ # consistent with the state of the IR here (after tracking `PhiNode` arguments,
765+ # but before the DCE) for our predicate within `sroa_mutables!`
766+ used_ssas = copy (compact. used_ssas)
767+ simple_dce! (compact)
768+ ir = complete (compact)
769+ sroa_mutables! (ir, defuses, used_ssas)
770+ return ir
771+ else
772+ simple_dce! (compact)
773+ return complete (compact)
774+ end
775+ end
776+
777+ function sroa_mutables! (ir:: IRCode , defuses:: IdDict{Int, Tuple{BitSet, SSADefUse}} , used_ssas:: Vector{Int} )
778+ # Compute domtree, needed below, now that we have finished compacting the IR.
779+ # This needs to be after we iterate through the IR with `IncrementalCompact`
780+ # because removing dead blocks can invalidate the domtree.
772781 @timeit " domtree 2" domtree = construct_domtree (ir. cfg. blocks)
773782
774- # Now go through any mutable structs and see which ones we can eliminate
775783 for (idx, (intermediaries, defuse)) in defuses
776784 intermediaries = collect (intermediaries)
777785 # Check if there are any uses we did not account for. If so, the variable
@@ -806,12 +814,12 @@ function sroa_pass!(ir::IRCode)
806814 # it would have been deleted. That's fine, just ignore
807815 # the use in that case.
808816 stmt === nothing && continue
809- field = try_compute_fieldidx_stmt (compact , stmt:: Expr , typ)
817+ field = try_compute_fieldidx_stmt (ir , stmt:: Expr , typ)
810818 field === nothing && @goto skip
811819 push! (fielddefuse[field]. uses, use)
812820 end
813821 for use in defuse. defs
814- field = try_compute_fieldidx_stmt (compact , ir[SSAValue (use)]:: Expr , typ)
822+ field = try_compute_fieldidx_stmt (ir , ir[SSAValue (use)]:: Expr , typ)
815823 field === nothing && @goto skip
816824 push! (fielddefuse[field]. defs, use)
817825 end
@@ -846,8 +854,9 @@ function sroa_pass!(ir::IRCode)
846854 end
847855 end
848856 end
849- preserve_uses = IdDict {Int, Vector{Any}} ((idx=> Any[] for idx in IdSet {Int} (defuse. ccall_preserve_uses)))
850857 # Everything accounted for. Go field by field and perform idf
858+ preserve_uses = isempty (defuse. ccall_preserve_uses) ? nothing :
859+ IdDict {Int, Vector{Any}} ((idx=> Any[] for idx in BitSet (defuse. ccall_preserve_uses)))
851860 for fidx in 1 : ndefuse
852861 du = fielddefuse[fidx]
853862 ftyp = fieldtype (typ, fidx)
@@ -863,8 +872,10 @@ function sroa_pass!(ir::IRCode)
863872 ir[SSAValue (stmt)] = compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, stmt)
864873 end
865874 if ! isbitstype (ftyp)
866- for (use, list) in preserve_uses
867- push! (list, compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, use))
875+ if preserve_uses != = nothing
876+ for (use, list) in preserve_uses
877+ push! (list, compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, use))
878+ end
868879 end
869880 end
870881 for b in phiblocks
@@ -881,7 +892,7 @@ function sroa_pass!(ir::IRCode)
881892 ir[SSAValue (stmt)] = nothing
882893 end
883894 end
884- isempty (defuse . ccall_preserve_uses) && continue
895+ preserve_uses === nothing && continue
885896 push! (intermediaries, newidx)
886897 # Insert the new preserves
887898 for (use, new_preserves) in preserve_uses
@@ -897,10 +908,7 @@ function sroa_pass!(ir::IRCode)
897908
898909 @label skip
899910 end
900-
901- return ir
902911end
903- # assertion_counter = 0
904912
905913"""
906914 canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::Expr)
0 commit comments