Skip to content

Commit bad3e39

Browse files
author
Ian Atol
authored
optimizer: use count checking framework (#44794)
1 parent eb4c757 commit bad3e39

File tree

6 files changed

+159
-78
lines changed

6 files changed

+159
-78
lines changed

base/compiler/ssair/inlining.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,9 +1379,7 @@ function inline_const_if_inlineable!(inst::Instruction)
13791379
end
13801380

13811381
function assemble_inline_todo!(ir::IRCode, state::InliningState)
1382-
# todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie)
13831382
todo = Pair{Int, Any}[]
1384-
et = state.et
13851383

13861384
for idx in 1:length(ir.stmts)
13871385
simpleres = process_simple!(ir, idx, state, todo)
@@ -1586,6 +1584,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
15861584
end
15871585
end
15881586
end
1587+
isa(val, Union{SSAValue, NewSSAValue}) && return val # avoid infinite loop
15891588
urs = userefs(val)
15901589
for op in urs
15911590
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck)

base/compiler/ssair/ir.jl

Lines changed: 41 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,9 @@ struct UndefToken end; const UNDEF_TOKEN = UndefToken()
381381
isdefined(stmt, :val) || return OOB_TOKEN
382382
op == 1 || return OOB_TOKEN
383383
return stmt.val
384+
elseif isa(stmt, Union{SSAValue, NewSSAValue})
385+
op == 1 || return OOB_TOKEN
386+
return stmt
384387
elseif isa(stmt, UpsilonNode)
385388
isdefined(stmt, :val) || return OOB_TOKEN
386389
op == 1 || return OOB_TOKEN
@@ -430,6 +433,9 @@ end
430433
elseif isa(stmt, ReturnNode)
431434
op == 1 || throw(BoundsError())
432435
stmt = typeof(stmt)(v)
436+
elseif isa(stmt, Union{SSAValue, NewSSAValue})
437+
op == 1 || throw(BoundsError())
438+
stmt = v
433439
elseif isa(stmt, UpsilonNode)
434440
op == 1 || throw(BoundsError())
435441
stmt = typeof(stmt)(v)
@@ -457,7 +463,7 @@ end
457463

458464
function userefs(@nospecialize(x))
459465
relevant = (isa(x, Expr) && is_relevant_expr(x)) ||
460-
isa(x, GotoIfNot) || isa(x, ReturnNode) ||
466+
isa(x, GotoIfNot) || isa(x, ReturnNode) || isa(x, SSAValue) || isa(x, NewSSAValue) ||
461467
isa(x, PiNode) || isa(x, PhiNode) || isa(x, PhiCNode) || isa(x, UpsilonNode)
462468
return UseRefIterator(x, relevant)
463469
end
@@ -480,50 +486,10 @@ end
480486

481487
# This function is used from the show code, which may have a different
482488
# `push!`/`used` type since it's in Base.
483-
function scan_ssa_use!(push!, used, @nospecialize(stmt))
484-
if isa(stmt, SSAValue)
485-
push!(used, stmt.id)
486-
end
487-
for useref in userefs(stmt)
488-
val = useref[]
489-
if isa(val, SSAValue)
490-
push!(used, val.id)
491-
end
492-
end
493-
end
489+
scan_ssa_use!(push!, used, @nospecialize(stmt)) = foreachssa(ssa -> push!(used, ssa.id), stmt)
494490

495491
# Manually specialized copy of the above with push! === Compiler.push!
496-
function scan_ssa_use!(used::IdSet, @nospecialize(stmt))
497-
if isa(stmt, SSAValue)
498-
push!(used, stmt.id)
499-
end
500-
for useref in userefs(stmt)
501-
val = useref[]
502-
if isa(val, SSAValue)
503-
push!(used, val.id)
504-
end
505-
end
506-
end
507-
508-
function ssamap(f, @nospecialize(stmt))
509-
urs = userefs(stmt)
510-
for op in urs
511-
val = op[]
512-
if isa(val, SSAValue)
513-
op[] = f(val)
514-
end
515-
end
516-
return urs[]
517-
end
518-
519-
function foreachssa(f, @nospecialize(stmt))
520-
for op in userefs(stmt)
521-
val = op[]
522-
if isa(val, SSAValue)
523-
f(val)
524-
end
525-
end
526-
end
492+
scan_ssa_use!(used::IdSet, @nospecialize(stmt)) = foreachssa(ssa -> push!(used, ssa.id), stmt)
527493

528494
function insert_node!(ir::IRCode, pos::Int, inst::NewInstruction, attach_after::Bool=false)
529495
node = add!(ir.new_nodes, pos, attach_after)
@@ -751,20 +717,13 @@ end
751717

752718
function count_added_node!(compact::IncrementalCompact, @nospecialize(v))
753719
needs_late_fixup = false
754-
if isa(v, SSAValue)
755-
compact.used_ssas[v.id] += 1
756-
elseif isa(v, NewSSAValue)
757-
compact.new_new_used_ssas[v.id] += 1
758-
needs_late_fixup = true
759-
else
760-
for ops in userefs(v)
761-
val = ops[]
762-
if isa(val, SSAValue)
763-
compact.used_ssas[val.id] += 1
764-
elseif isa(val, NewSSAValue)
765-
compact.new_new_used_ssas[val.id] += 1
766-
needs_late_fixup = true
767-
end
720+
for ops in userefs(v)
721+
val = ops[]
722+
if isa(val, SSAValue)
723+
compact.used_ssas[val.id] += 1
724+
elseif isa(val, NewSSAValue)
725+
compact.new_new_used_ssas[val.id] += 1
726+
needs_late_fixup = true
768727
end
769728
end
770729
return needs_late_fixup
@@ -931,6 +890,27 @@ function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::Int)
931890
return compact
932891
end
933892

893+
__set_check_ssa_counts(onoff::Bool) = __check_ssa_counts__[] = onoff
894+
const __check_ssa_counts__ = fill(false)
895+
896+
function _oracle_check(compact::IncrementalCompact)
897+
observed_used_ssas = Core.Compiler.find_ssavalue_uses1(compact)
898+
for i = 1:length(observed_used_ssas)
899+
if observed_used_ssas[i] != compact.used_ssas[i]
900+
return observed_used_ssas
901+
end
902+
end
903+
return nothing
904+
end
905+
906+
function oracle_check(compact::IncrementalCompact)
907+
maybe_oracle_used_ssas = _oracle_check(compact)
908+
if maybe_oracle_used_ssas !== nothing
909+
@eval Main (compact = $compact; oracle_used_ssas = $maybe_oracle_used_ssas)
910+
error("Oracle check failed, inspect Main.compact and Main.oracle_used_ssas")
911+
end
912+
end
913+
934914
getindex(view::TypesView, idx::SSAValue) = getindex(view, idx.id)
935915
function getindex(view::TypesView, idx::Int)
936916
if isa(view.ir, IncrementalCompact) && idx < view.ir.result_idx
@@ -1425,7 +1405,6 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}=
14251405
# result_idx is not, incremented, but that's ok and expected
14261406
compact.result[old_result_idx] = compact.ir.stmts[idx]
14271407
result_idx = process_node!(compact, old_result_idx, compact.ir.stmts[idx], idx, idx, active_bb, true)
1428-
stmt_if_any = old_result_idx == result_idx ? nothing : compact.result[old_result_idx][:inst]
14291408
compact.result_idx = result_idx
14301409
if idx == last(bb.stmts) && !attach_after_stmt_after(compact, idx)
14311410
finish_current_bb!(compact, active_bb, old_result_idx)
@@ -1464,11 +1443,7 @@ function maybe_erase_unused!(
14641443
callback(val)
14651444
end
14661445
if effect_free
1467-
if isa(stmt, SSAValue)
1468-
kill_ssa_value(stmt)
1469-
else
1470-
foreachssa(kill_ssa_value, stmt)
1471-
end
1446+
foreachssa(kill_ssa_value, stmt)
14721447
inst[:inst] = nothing
14731448
return true
14741449
end
@@ -1570,6 +1545,9 @@ end
15701545
function complete(compact::IncrementalCompact)
15711546
result_bbs = resize!(compact.result_bbs, compact.active_result_bb-1)
15721547
cfg = CFG(result_bbs, Int[first(result_bbs[i].stmts) for i in 2:length(result_bbs)])
1548+
if __check_ssa_counts__[]
1549+
oracle_check(compact)
1550+
end
15731551
return IRCode(compact.ir, compact.result, cfg, compact.new_new_nodes)
15741552
end
15751553

base/compiler/ssair/passes.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,15 +1151,6 @@ function adce_erase!(phi_uses::Vector{Int}, extra_worklist::Vector{Int}, compact
11511151
end
11521152
end
11531153

1154-
function count_uses(@nospecialize(stmt), uses::Vector{Int})
1155-
for ur in userefs(stmt)
1156-
use = ur[]
1157-
if isa(use, SSAValue)
1158-
uses[use.id] += 1
1159-
end
1160-
end
1161-
end
1162-
11631154
function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::SPCSet, phi::Int)
11641155
worklist = Int[]
11651156
push!(worklist, phi)

base/compiler/ssair/slot2ssa.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,6 @@ function make_ssa!(ci::CodeInfo, code::Vector{Any}, idx, slot, @nospecialize(typ
7272
end
7373

7474
function new_to_regular(@nospecialize(stmt), new_offset::Int)
75-
if isa(stmt, NewSSAValue)
76-
return SSAValue(stmt.id + new_offset)
77-
end
7875
urs = userefs(stmt)
7976
for op in urs
8077
val = op[]

base/compiler/utilities.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,27 @@ end
228228
# SSAValues/Slots #
229229
###################
230230

231+
function ssamap(f, @nospecialize(stmt))
232+
urs = userefs(stmt)
233+
for op in urs
234+
val = op[]
235+
if isa(val, SSAValue)
236+
op[] = f(val)
237+
end
238+
end
239+
return urs[]
240+
end
241+
242+
function foreachssa(f, @nospecialize(stmt))
243+
urs = userefs(stmt)
244+
for op in urs
245+
val = op[]
246+
if isa(val, SSAValue)
247+
f(val)
248+
end
249+
end
250+
end
251+
231252
function find_ssavalue_uses(body::Vector{Any}, nvals::Int)
232253
uses = BitSet[ BitSet() for i = 1:nvals ]
233254
for line in 1:length(body)
@@ -333,6 +354,38 @@ end
333354
@inline slot_id(s) = isa(s, SlotNumber) ? (s::SlotNumber).id :
334355
isa(s, Argument) ? (s::Argument).n : (s::TypedSlot).id
335356

357+
######################
358+
# IncrementalCompact #
359+
######################
360+
361+
# specifically meant to be used with body1 = compact.result and body2 = compact.new_new_nodes, with nvals == length(compact.used_ssas)
362+
function find_ssavalue_uses1(compact)
363+
body1, body2 = compact.result.inst, compact.new_new_nodes.stmts.inst
364+
nvals = length(compact.used_ssas)
365+
nbody1 = length(body1)
366+
nbody2 = length(body2)
367+
368+
uses = zeros(Int, nvals)
369+
function increment_uses(ssa::SSAValue)
370+
uses[ssa.id] += 1
371+
end
372+
373+
for line in 1:(nbody1 + nbody2)
374+
# index into the right body
375+
if line <= nbody1
376+
isassigned(body1, line) || continue
377+
e = body1[line]
378+
else
379+
line -= nbody1
380+
isassigned(body2, line) || continue
381+
e = body2[line]
382+
end
383+
384+
foreachssa(increment_uses, e)
385+
end
386+
return uses
387+
end
388+
336389
###########
337390
# options #
338391
###########

test/compiler/ssair.jl

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using Base.Meta
44
using Core.IR
55
const Compiler = Core.Compiler
6-
using .Compiler: CFG, BasicBlock
6+
using .Compiler: CFG, BasicBlock, NewSSAValue
77

88
make_bb(preds, succs) = BasicBlock(Compiler.StmtRange(0, 0), preds, succs)
99

@@ -334,3 +334,66 @@ f_if_typecheck() = (if nothing; end; unsafe_load(Ptr{Int}(0)))
334334
stderr = IOBuffer()
335335
success(pipeline(Cmd(cmd); stdout=stdout, stderr=stderr)) && isempty(String(take!(stderr)))
336336
end
337+
338+
let
339+
function test_useref(stmt, v, op)
340+
if isa(stmt, Expr)
341+
@test stmt.args[op] === v
342+
elseif isa(stmt, GotoIfNot)
343+
@test stmt.cond === v
344+
elseif isa(stmt, ReturnNode) || isa(stmt, UpsilonNode)
345+
@test stmt.val === v
346+
elseif isa(stmt, SSAValue) || isa(stmt, NewSSAValue)
347+
@test stmt === v
348+
elseif isa(stmt, PiNode)
349+
@test stmt.val === v && stmt.typ === typeof(stmt)
350+
elseif isa(stmt, PhiNode) || isa(stmt, PhiCNode)
351+
@test stmt.values[op] === v
352+
end
353+
end
354+
355+
function _test_userefs(@nospecialize stmt)
356+
ex = Expr(:call, :+, Core.SSAValue(3), 1)
357+
urs = Core.Compiler.userefs(stmt)::Core.Compiler.UseRefIterator
358+
it = Core.Compiler.iterate(urs)
359+
while it !== nothing
360+
ur = getfield(it, 1)::Core.Compiler.UseRef
361+
op = getfield(it, 2)::Int
362+
v1 = Core.Compiler.getindex(ur)
363+
# set to dummy expression and then back to itself to test `_useref_setindex!`
364+
v2 = Core.Compiler.setindex!(ur, ex)
365+
test_useref(v2, ex, op)
366+
Core.Compiler.setindex!(ur, v1)
367+
@test Core.Compiler.getindex(ur) === v1
368+
it = Core.Compiler.iterate(urs, op)
369+
end
370+
end
371+
372+
function test_userefs(body)
373+
for stmt in body
374+
_test_userefs(stmt)
375+
end
376+
end
377+
378+
# this isn't valid code, we just care about looking at a variety of IR nodes
379+
body = Any[
380+
Expr(:enter, 11),
381+
Expr(:call, :+, SSAValue(3), 1),
382+
Expr(:throw_undef_if_not, :expected, false),
383+
Expr(:leave, 1),
384+
Expr(:(=), SSAValue(1), Expr(:call, :+, SSAValue(3), 1)),
385+
UpsilonNode(),
386+
UpsilonNode(SSAValue(2)),
387+
PhiCNode(Any[SSAValue(5), SSAValue(7), SSAValue(9)]),
388+
PhiCNode(Any[SSAValue(6)]),
389+
PhiNode(Int32[8], Any[SSAValue(7)]),
390+
PiNode(SSAValue(6), GotoNode),
391+
GotoIfNot(SSAValue(3), 10),
392+
GotoNode(5),
393+
SSAValue(7),
394+
NewSSAValue(9),
395+
ReturnNode(SSAValue(11)),
396+
]
397+
398+
test_userefs(body)
399+
end

0 commit comments

Comments
 (0)