Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 91 additions & 82 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])

compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses)

function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr)
field = stmt.args[3]
# fields are usually literals, handle them manually
if isa(field, QuoteNode)
field = field.value
elseif isa(field, Int)
# try to resolve other constants, e.g. global reference
else
field = compact_exprtype(compact, field)
field = isa(ir, IncrementalCompact) ? compact_exprtype(ir, field) : argextype(field, ir)
if isa(field, Const)
field = field.val
else
Expand All @@ -42,8 +42,8 @@ function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
return field
end

function try_compute_fieldidx_stmt(compact::IncrementalCompact, stmt::Expr, typ::DataType)
field = try_compute_field_stmt(compact, stmt)
"""
    try_compute_fieldidx_stmt(ir, stmt, typ)

Resolve the field operand of `stmt` with `try_compute_field_stmt(ir, stmt)` and hand
the result to `try_compute_fieldidx(typ, field)`, returning whatever that call returns.

`ir` may be either an `IncrementalCompact` or an `IRCode`.
Callers in this file compare the result against `nothing` to detect failure.
"""
function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType)
    return try_compute_fieldidx(typ, try_compute_field_stmt(ir, stmt))
end

Expand Down Expand Up @@ -112,6 +112,13 @@ function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
return def, stmtblock, curblock
end

"""
    collect_leaves(compact, val, typeconstraint) -> (leaves, visited_phinodes)

Gather the leaves feeding into `val` by delegating to `walk_to_defs`.

When `val` is an `OldSSAValue` or `SSAValue`, it is first passed through
`simple_walk_constraint`, and the value / type constraint pair it returns is what
`walk_to_defs` starts from. Any other `val` is forwarded to `walk_to_defs` as is.
"""
function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint))
    # Non-SSA values need no pre-walk: forward them unchanged.
    isa(val, Union{OldSSAValue, SSAValue}) || return walk_to_defs(compact, val, typeconstraint)
    root, constraint = simple_walk_constraint(compact, val, typeconstraint)
    return walk_to_defs(compact, root, constraint)
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
while true
Expand Down Expand Up @@ -152,7 +159,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end

function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
@nospecialize(typeconstraint = types(compact)[defssa]))
@nospecialize(typeconstraint))
callback = function (@nospecialize(pi), @nospecialize(idx))
if isa(pi, PiNode)
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
Expand All @@ -164,20 +171,16 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss
end

"""
walk_to_defs(compact, val, intermediaries)
walk_to_defs(compact, val, typeconstraint)

Starting at `val` walk use-def chains to get all the leaves feeding into
this val (pruning those leaves rules out by path conditions).
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
(pruning those leaves ruled out by path conditions).
"""
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint), visited_phinodes::Vector{AnySSAValue}=AnySSAValue[])
isa(defssa, AnySSAValue) || return Any[defssa]
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
visited_phinodes = AnySSAValue[]
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
def = compact[defssa]
isa(def, PhiNode) || return Any[defssa]
# Step 2: Figure out what the struct is defined as
## Track definitions through PiNode/PhiNode
found_def = false
## Track which PhiNodes, SSAValue intermediaries
## we forwarded through.
isa(def, PhiNode) || return Any[defssa], visited_phinodes
visited_constraints = IdDict{AnySSAValue, Any}()
worklist_defs = AnySSAValue[]
worklist_constraints = Any[]
Expand Down Expand Up @@ -239,10 +242,10 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
push!(leaves, defssa)
end
end
leaves
return leaves, visited_phinodes
end

function process_immutable_preserve(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr)
function process_immutable_preserve!(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr)
for arg in (isexpr(def, :new) ? def.args : def.args[2:end])
if !isbitstype(widenconst(compact_exprtype(compact, arg)))
push!(new_preserves, arg)
Expand Down Expand Up @@ -449,13 +452,10 @@ function lift_comparison!(compact::IncrementalCompact,
return
end

if isa(val, Union{OldSSAValue, SSAValue})
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
end

visited_phinodes = AnySSAValue[]
leaves = walk_to_defs(compact, val, typeconstraint, visited_phinodes)
valtyp = widenconst(compact_exprtype(compact, val))
isa(valtyp, Union) || return # bail out if there won't be a good chance for lifting

leaves, visited_phinodes = collect_leaves(compact, val, valtyp)
length(leaves) ≤ 1 && return # bail out if we don't have multiple leaves

# Let's check if we evaluate the comparison for each one of the leaves
Expand All @@ -476,10 +476,6 @@ function lift_comparison!(compact::IncrementalCompact,
visited_phinodes, cmp, lifting_cache, Bool,
lifted_leaves::IdDict{Any, Union{Nothing,LiftedValue}}, val)::LiftedValue

# global assertion_counter
# assertion_counter::Int += 1
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true)
# return
compact[idx] = lifted_val.x
end

Expand Down Expand Up @@ -576,6 +572,10 @@ function perform_lifting!(compact::IncrementalCompact,
return stmt_val # N.B. should never happen
end

# NOTE: `sroa_pass!` uses `IdSet{Int}` rather than `BitSet` for its sets of program counters:
# the pass runs on IR after inlining, which can be very large, while the analyzed
# program counters are often very sparse.
const SPCSet = IdSet{Int}

"""
sroa_pass!(ir::IRCode) -> newir::IRCode

Expand All @@ -596,17 +596,16 @@ a result of succeeding dead code elimination.
"""
function sroa_pass!(ir::IRCode)
compact = IncrementalCompact(ir)
defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}()
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether we should have first-class support for this kind of optimization (deferring an allocation until it is first needed)?
For example: @delay_alloc defuses = IdDict{Int, Tuple{BitSet, SSADefUse}}()

lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
for ((_, idx), stmt) in compact
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
isa(stmt, Expr) || continue
result_t = compact_exprtype(compact, SSAValue(idx))
is_setfield = false
field_ordering = :unspecified
# Step 1: Check whether the statement we're looking at is a getfield/setfield!
if is_known_call(stmt, setfield!, compact)
is_setfield = true
4 <= length(stmt.args) <= 5 || continue
is_setfield = true
if length(stmt.args) == 5
field_ordering = compact_exprtype(compact, stmt.args[5])
end
Expand All @@ -624,7 +623,7 @@ function sroa_pass!(ir::IRCode)
old_preserves = stmt.args[(6+nccallargs):end]
for (pidx, preserved_arg) in enumerate(old_preserves)
isa(preserved_arg, SSAValue) || continue
let intermediaries = IdSet{Int}()
let intermediaries = SPCSet()
callback = function (@nospecialize(pi), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
Expand All @@ -634,7 +633,7 @@ function sroa_pass!(ir::IRCode)
defidx = def.id
def = compact[defidx]
if is_tuple_call(compact, def)
process_immutable_preserve(new_preserves, compact, def)
process_immutable_preserve!(new_preserves, compact, def)
old_preserves[pidx] = nothing
continue
elseif isexpr(def, :new)
Expand All @@ -643,14 +642,17 @@ function sroa_pass!(ir::IRCode)
typ = unwrap_unionall(typ)
end
if typ isa DataType && !ismutabletype(typ)
process_immutable_preserve(new_preserves, compact, def)
process_immutable_preserve!(new_preserves, compact, def)
old_preserves[pidx] = nothing
continue
end
else
continue
end
mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse()))
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
end
mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse()))
push!(defuse.ccall_preserve_uses, idx)
union!(mid, intermediaries)
end
Expand All @@ -675,10 +677,15 @@ function sroa_pass!(ir::IRCode)
else
continue
end

# analyze this `getfield` / `setfield!` call

field = try_compute_field_stmt(compact, stmt)
field === nothing && continue

struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, stmt.args[2])))
val = stmt.args[2]

struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, val)))
if isa(struct_typ, Union) && struct_typ <: Tuple
struct_typ = unswitchtupleunion(struct_typ)
end
Expand All @@ -689,19 +696,21 @@ function sroa_pass!(ir::IRCode)
continue
end

def, typeconstraint = stmt.args[2], struct_typ

# analyze this mutable struct here for the later pass
if ismutabletype(struct_typ)
isa(def, SSAValue) || continue
let intermediaries = IdSet{Int}()
isa(val, SSAValue) || continue
let intermediaries = SPCSet()
callback = function (@nospecialize(pi), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
def = simple_walk(compact, def, callback)
def = simple_walk(compact, val, callback)
# Mutable stuff here
isa(def, SSAValue) || continue
mid, defuse = get!(defuses, def.id, (IdSet{Int}(), SSADefUse()))
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
end
mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse()))
if is_setfield
push!(defuse.defs, idx)
else
Expand All @@ -711,32 +720,28 @@ function sroa_pass!(ir::IRCode)
end
continue
elseif is_setfield
continue
continue # invalid `setfield!` call, but just ignore here
end

# perform SROA on immutable structs here on

if isa(def, Union{OldSSAValue, SSAValue})
def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint)
end

visited_phinodes = AnySSAValue[]
leaves = walk_to_defs(compact, def, typeconstraint, visited_phinodes)

isempty(leaves) && continue

field = try_compute_fieldidx(struct_typ, field)
field === nothing && continue

r = lift_leaves(compact, result_t, field, leaves)
r === nothing && continue
lifted_leaves, any_undef = r
leaves, visited_phinodes = collect_leaves(compact, val, struct_typ)
isempty(leaves) && continue

result_t = compact_exprtype(compact, SSAValue(idx))
lifted_result = lift_leaves(compact, result_t, field, leaves)
lifted_result === nothing && continue
lifted_leaves, any_undef = lifted_result

if any_undef
result_t = make_MaybeUndef(result_t)
end

val = perform_lifting!(compact, visited_phinodes, field, lifting_cache, result_t, lifted_leaves, stmt.args[2])
val = perform_lifting!(compact,
visited_phinodes, field, lifting_cache, result_t, lifted_leaves, val)

# Insert the undef check if necessary
if any_undef
Expand All @@ -750,28 +755,32 @@ function sroa_pass!(ir::IRCode)
@assert val !== nothing
end

# global assertion_counter
# assertion_counter::Int += 1
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
# continue
compact[idx] = val === nothing ? nothing : val.x
end

non_dce_finish!(compact)
# Copy the use count, `simple_dce!` may modify it and for our predicate
# below we need it consistent with the state of the IR here (after tracking
# phi node arguments, but before dce).
used_ssas = copy(compact.used_ssas)
simple_dce!(compact)
ir = complete(compact)

# Compute domtree, needed below, now that we have finished compacting the
# IR. This needs to be after we iterate through the IR with
# `IncrementalCompact` because removing dead blocks can invalidate the
# domtree.
if defuses !== nothing
# now go through analyzed mutable structs and see which ones we can eliminate
# NOTE copy the use count here, because `simple_dce!` may modify it and we need it
# consistent with the state of the IR here (after tracking `PhiNode` arguments,
# but before the DCE) for our predicate within `sroa_mutables!`
used_ssas = copy(compact.used_ssas)
simple_dce!(compact)
ir = complete(compact)
sroa_mutables!(ir, defuses, used_ssas)
return ir
else
simple_dce!(compact)
return complete(compact)
end
end

function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int})
# Compute domtree, needed below, now that we have finished compacting the IR.
# This needs to be after we iterate through the IR with `IncrementalCompact`
# because removing dead blocks can invalidate the domtree.
@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)

# Now go through any mutable structs and see which ones we can eliminate
for (idx, (intermediaries, defuse)) in defuses
intermediaries = collect(intermediaries)
# Check if there are any uses we did not account for. If so, the variable
Expand Down Expand Up @@ -806,12 +815,12 @@ function sroa_pass!(ir::IRCode)
# it would have been deleted. That's fine, just ignore
# the use in that case.
stmt === nothing && continue
field = try_compute_fieldidx_stmt(compact, stmt::Expr, typ)
field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ)
field === nothing && @goto skip
push!(fielddefuse[field].uses, use)
end
for use in defuse.defs
field = try_compute_fieldidx_stmt(compact, ir[SSAValue(use)]::Expr, typ)
field = try_compute_fieldidx_stmt(ir, ir[SSAValue(use)]::Expr, typ)
field === nothing && @goto skip
push!(fielddefuse[field].defs, use)
end
Expand Down Expand Up @@ -846,8 +855,9 @@ function sroa_pass!(ir::IRCode)
end
end
end
preserve_uses = IdDict{Int, Vector{Any}}((idx=>Any[] for idx in IdSet{Int}(defuse.ccall_preserve_uses)))
# Everything accounted for. Go field by field and perform idf
preserve_uses = isempty(defuse.ccall_preserve_uses) ? nothing :
IdDict{Int, Vector{Any}}((idx=>Any[] for idx in SPCSet(defuse.ccall_preserve_uses)))
for fidx in 1:ndefuse
du = fielddefuse[fidx]
ftyp = fieldtype(typ, fidx)
Expand All @@ -863,8 +873,10 @@ function sroa_pass!(ir::IRCode)
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
end
if !isbitstype(ftyp)
for (use, list) in preserve_uses
push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
if preserve_uses !== nothing
for (use, list) in preserve_uses
push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
end
end
end
for b in phiblocks
Expand All @@ -881,7 +893,7 @@ function sroa_pass!(ir::IRCode)
ir[SSAValue(stmt)] = nothing
end
end
isempty(defuse.ccall_preserve_uses) && continue
preserve_uses === nothing && continue
push!(intermediaries, newidx)
# Insert the new preserves
for (use, new_preserves) in preserve_uses
Expand All @@ -897,10 +909,7 @@ function sroa_pass!(ir::IRCode)

@label skip
end

return ir
end
# assertion_counter = 0

"""
canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::Expr)
Expand Down