Skip to content

Commit c8ca350

Browse files
authored
inference: Model type propagation through exceptions (#51754)
Currently the type of a caught exception is always modeled as `Any`. This isn't a huge problem, because control flow in Julia is generally assumed to be somewhat slow, so the extra type imprecision of not knowing the return type does not matter all that much. However, there are a few situations where it matters. For example: ``` maybe_getindex(A, i) = try; A[i]; catch e; isa(e, BoundsError) && return nothing; rethrow(); end ``` At present, we cannot infer :nothrow for this method, even if that is the only error type that `A[i]` can throw. This is particularly noticable, since we can now optimize away `:nothrow` exception frames entirely (#51674). Note that this PR still does not make the above example particularly efficient (at least interprocedurally), though specialized codegen could be added on top of this to make that happen. It does however improve the inference result. A second major motivation of this change is that reasoning about exception types is likely to be a major aspect of any future work on interface checking (since interfaces imply the absence of MethodErrors), so this PR lays the groundwork for appropriate modeling of these error paths. Note that this PR adds all the required plumbing, but does not yet have a particularly precise model of error types for our builtins, bailing to `Any` for any builtin not known to be `:nothrow`. This can be improved in follow up PRs as required.
1 parent f5d189f commit c8ca350

26 files changed

+517
-244
lines changed

base/boot.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,13 +479,13 @@ eval(Core, quote
479479
end)
480480

481481
function CodeInstance(
482-
mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const),
482+
mi::MethodInstance, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
483483
@nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt,
484484
ipo_effects::UInt32, effects::UInt32, @nospecialize(analysis_results),
485485
relocatability::UInt8)
486486
return ccall(:jl_new_codeinst, Ref{CodeInstance},
487-
(Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
488-
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world,
487+
(Any, Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
488+
mi, rettype, exctype, inferred_const, inferred, const_flags, min_world, max_world,
489489
ipo_effects, effects, analysis_results,
490490
relocatability)
491491
end

base/compiler/abstractinterpretation.jl

Lines changed: 202 additions & 107 deletions
Large diffs are not rendered by default.

base/compiler/effects.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,68 @@ function Effects(effects::Effects = _EFFECTS_UNKNOWN;
171171
nonoverlayed)
172172
end
173173

174+
function is_better_effects(new::Effects, old::Effects)
175+
any_improved = false
176+
if new.consistent == ALWAYS_TRUE
177+
any_improved |= old.consistent != ALWAYS_TRUE
178+
else
179+
if !iszero(new.consistent & CONSISTENT_IF_NOTRETURNED)
180+
old.consistent == ALWAYS_TRUE && return false
181+
any_improved |= iszero(old.consistent & CONSISTENT_IF_NOTRETURNED)
182+
elseif !iszero(new.consistent & CONSISTENT_IF_INACCESSIBLEMEMONLY)
183+
old.consistent == ALWAYS_TRUE && return false
184+
any_improved |= iszero(old.consistent & CONSISTENT_IF_INACCESSIBLEMEMONLY)
185+
else
186+
return false
187+
end
188+
end
189+
if new.effect_free == ALWAYS_TRUE
190+
any_improved |= old.consistent != ALWAYS_TRUE
191+
elseif new.effect_free == EFFECT_FREE_IF_INACCESSIBLEMEMONLY
192+
old.effect_free == ALWAYS_TRUE && return false
193+
any_improved |= old.effect_free != EFFECT_FREE_IF_INACCESSIBLEMEMONLY
194+
elseif new.effect_free != old.effect_free
195+
return false
196+
end
197+
if new.nothrow
198+
any_improved |= !old.nothrow
199+
elseif new.nothrow != old.nothrow
200+
return false
201+
end
202+
if new.terminates
203+
any_improved |= !old.terminates
204+
elseif new.terminates != old.terminates
205+
return false
206+
end
207+
if new.notaskstate
208+
any_improved |= !old.notaskstate
209+
elseif new.notaskstate != old.notaskstate
210+
return false
211+
end
212+
if new.inaccessiblememonly == ALWAYS_TRUE
213+
any_improved |= old.inaccessiblememonly != ALWAYS_TRUE
214+
elseif new.inaccessiblememonly == INACCESSIBLEMEM_OR_ARGMEMONLY
215+
old.inaccessiblememonly == ALWAYS_TRUE && return false
216+
any_improved |= old.inaccessiblememonly != INACCESSIBLEMEM_OR_ARGMEMONLY
217+
elseif new.inaccessiblememonly != old.inaccessiblememonly
218+
return false
219+
end
220+
if new.noub == ALWAYS_TRUE
221+
any_improved |= old.noub != ALWAYS_TRUE
222+
elseif new.noub == NOUB_IF_NOINBOUNDS
223+
old.noub == ALWAYS_TRUE && return false
224+
any_improved |= old.noub != NOUB_IF_NOINBOUNDS
225+
elseif new.noub != old.noub
226+
return false
227+
end
228+
if new.nonoverlayed
229+
any_improved |= !old.nonoverlayed
230+
elseif new.nonoverlayed != old.nonoverlayed
231+
return false
232+
end
233+
return any_improved
234+
end
235+
174236
function merge_effects(old::Effects, new::Effects)
175237
return Effects(
176238
merge_effectbits(old.consistent, new.consistent),

base/compiler/inferencestate.jl

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization allowed
203203
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization allowed
204204
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed
205205

206+
mutable struct TryCatchFrame
207+
exct
208+
const enter_idx::Int
209+
TryCatchFrame(@nospecialize(exct), enter_idx::Int) = new(exct, enter_idx)
210+
end
211+
206212
mutable struct InferenceState
207213
#= information about this method instance =#
208214
linfo::MethodInstance
@@ -218,7 +224,8 @@ mutable struct InferenceState
218224
currbb::Int
219225
currpc::Int
220226
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
221-
handler_at::Vector{Int} # current exception handler info
227+
handlers::Vector{TryCatchFrame}
228+
handler_at::Vector{Tuple{Int, Int}} # tuple of current (handler, exception stack) value at the pc
222229
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
223230
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
224231
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
@@ -239,6 +246,7 @@ mutable struct InferenceState
239246
unreachable::BitSet # statements that were found to be statically unreachable
240247
valid_worlds::WorldRange
241248
bestguess #::Type
249+
exc_bestguess
242250
ipo_effects::Effects
243251

244252
#= flags =#
@@ -266,7 +274,7 @@ mutable struct InferenceState
266274

267275
currbb = currpc = 1
268276
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
269-
handler_at = compute_trycatch(code, BitSet())
277+
handler_at, handlers = compute_trycatch(code, BitSet())
270278
nssavalues = src.ssavaluetypes::Int
271279
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
272280
nstmts = length(code)
@@ -296,6 +304,7 @@ mutable struct InferenceState
296304

297305
valid_worlds = WorldRange(src.min_world, src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
298306
bestguess = Bottom
307+
exc_bestguess = Bottom
299308
ipo_effects = EFFECTS_TOTAL
300309

301310
insert_coverage = should_insert_coverage(mod, src)
@@ -315,9 +324,9 @@ mutable struct InferenceState
315324

316325
return new(
317326
linfo, world, mod, sptypes, slottypes, src, cfg, method_info,
318-
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
327+
currbb, currpc, ip, handlers, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
319328
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
320-
result, unreachable, valid_worlds, bestguess, ipo_effects,
329+
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
321330
restrict_abstract_call_sites, cache_mode, insert_coverage,
322331
interp)
323332
end
@@ -347,16 +356,19 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
347356
empty!(ip)
348357
ip.offset = 0 # for _bits_findnext
349358
push!(ip, n + 1)
350-
handler_at = fill(0, n)
359+
handler_at = fill((0, 0), n)
360+
handlers = TryCatchFrame[]
351361

352362
# start from all :enter statements and record the location of the try
353363
for pc = 1:n
354364
stmt = code[pc]
355365
if isexpr(stmt, :enter)
356366
l = stmt.args[1]::Int
357-
handler_at[pc + 1] = pc
367+
push!(handlers, TryCatchFrame(Bottom, pc))
368+
handler_id = length(handlers)
369+
handler_at[pc + 1] = (handler_id, 0)
358370
push!(ip, pc + 1)
359-
handler_at[l] = pc
371+
handler_at[l] = (handler_id, handler_id)
360372
push!(ip, l)
361373
end
362374
end
@@ -369,25 +381,26 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
369381
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
370382
pc´ = pc + 1 # next program-counter (after executing instruction)
371383
delete!(ip, pc)
372-
cur_hand = handler_at[pc]
373-
@assert cur_hand != 0 "unbalanced try/catch"
384+
cur_stacks = handler_at[pc]
385+
@assert cur_stacks != (0, 0) "unbalanced try/catch"
374386
stmt = code[pc]
375387
if isa(stmt, GotoNode)
376388
pc´ = stmt.label
377389
elseif isa(stmt, GotoIfNot)
378390
l = stmt.dest::Int
379-
if handler_at[l] != cur_hand
380-
@assert handler_at[l] == 0 "unbalanced try/catch"
381-
handler_at[l] = cur_hand
391+
if handler_at[l] != cur_stacks
392+
@assert handler_at[l][1] == 0 || handler_at[l][1] == cur_stacks[1] "unbalanced try/catch"
393+
handler_at[l] = cur_stacks
382394
push!(ip, l)
383395
end
384396
elseif isa(stmt, ReturnNode)
385-
@assert !isdefined(stmt, :val) "unbalanced try/catch"
397+
@assert !isdefined(stmt, :val) || cur_stacks[1] == 0 "unbalanced try/catch"
386398
break
387399
elseif isa(stmt, Expr)
388400
head = stmt.head
389401
if head === :enter
390-
cur_hand = pc
402+
# Already set above
403+
cur_stacks = (handler_at[pc´][1], cur_stacks[2])
391404
elseif head === :leave
392405
l = 0
393406
for j = 1:length(stmt.args)
@@ -403,19 +416,21 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
403416
end
404417
l += 1
405418
end
419+
cur_hand = cur_stacks[1]
406420
for i = 1:l
407-
cur_hand = handler_at[cur_hand]
421+
cur_hand = handler_at[handlers[cur_hand].enter_idx][1]
408422
end
409-
cur_hand == 0 && break
423+
cur_stacks = (cur_hand, cur_stacks[2])
424+
cur_stacks == (0, 0) && break
425+
elseif head === :pop_exception
426+
cur_stacks = (cur_stacks[1], handler_at[(stmt.args[1]::SSAValue).id][2])
427+
cur_stacks == (0, 0) && break
410428
end
411429
end
412430

413431
pc´ > n && break # can't proceed with the fast-path fall-through
414-
if handler_at[pc´] != cur_hand
415-
if handler_at[pc´] != 0
416-
@assert false "unbalanced try/catch"
417-
end
418-
handler_at[pc´] = cur_hand
432+
if handler_at[pc´] != cur_stacks
433+
handler_at[pc´] = cur_stacks
419434
elseif !in(pc´, ip)
420435
break # already visited
421436
end
@@ -424,7 +439,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
424439
end
425440

426441
@assert first(ip) == n + 1
427-
return handler_at
442+
return handler_at, handlers
428443
end
429444

430445
# check if coverage mode is enabled

base/compiler/optimize.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -925,8 +925,10 @@ function run_passes_ipo_safe(
925925
# @timeit "verify 2" verify_ir(ir)
926926
@pass "compact 2" ir = compact!(ir)
927927
@pass "SROA" ir = sroa_pass!(ir, sv.inlining)
928-
@pass "ADCE" ir = adce_pass!(ir, sv.inlining)
929-
@pass "compact 3" ir = compact!(ir, true)
928+
@pass "ADCE" (ir, made_changes) = adce_pass!(ir, sv.inlining)
929+
if made_changes
930+
@pass "compact 3" ir = compact!(ir, true)
931+
end
930932
if JLOptions().debug_level == 2
931933
@timeit "verify 3" (verify_ir(ir, true, false, optimizer_lattice(sv.inlining.interp)); verify_linetable(ir.linetable))
932934
end

base/compiler/ssair/ir.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ struct CFGTransformState
587587
result_bbs::Vector{BasicBlock}
588588
bb_rename_pred::Vector{Int}
589589
bb_rename_succ::Vector{Int}
590+
domtree::Union{Nothing, DomTree}
590591
end
591592

592593
# N.B.: Takes ownership of the CFG array
@@ -622,11 +623,14 @@ function CFGTransformState!(blocks::Vector{BasicBlock}, allow_cfg_transforms::Bo
622623
let blocks = blocks, bb_rename = bb_rename
623624
result_bbs = BasicBlock[blocks[i] for i = 1:length(blocks) if bb_rename[i] != -1]
624625
end
626+
# TODO: This could be done by just renaming the domtree
627+
domtree = construct_domtree(result_bbs)
625628
else
626629
bb_rename = Vector{Int}()
627630
result_bbs = blocks
631+
domtree = nothing
628632
end
629-
return CFGTransformState(allow_cfg_transforms, allow_cfg_transforms, result_bbs, bb_rename, bb_rename)
633+
return CFGTransformState(allow_cfg_transforms, allow_cfg_transforms, result_bbs, bb_rename, bb_rename, domtree)
630634
end
631635

632636
mutable struct IncrementalCompact
@@ -681,7 +685,7 @@ mutable struct IncrementalCompact
681685
bb_rename = Vector{Int}()
682686
pending_nodes = NewNodeStream()
683687
pending_perm = Int[]
684-
return new(code, parent.result, CFGTransformState(false, false, parent.cfg_transform.result_bbs, bb_rename, bb_rename),
688+
return new(code, parent.result, CFGTransformState(false, false, parent.cfg_transform.result_bbs, bb_rename, bb_rename, nothing),
685689
ssa_rename, parent.used_ssas,
686690
parent.late_fixup, perm, 1,
687691
parent.new_new_nodes, parent.new_new_used_ssas, pending_nodes, pending_perm,
@@ -942,6 +946,14 @@ function insert_node_here!(compact::IncrementalCompact, newinst::NewInstruction,
942946
return inst
943947
end
944948

949+
function delete_inst_here!(compact)
950+
# Delete the statement, update refcounts etc
951+
compact[SSAValue(compact.result_idx-1)] = nothing
952+
# Pretend that we never compacted this statement in the first place
953+
compact.result_idx -= 1
954+
return nothing
955+
end
956+
945957
function getindex(view::TypesView, v::OldSSAValue)
946958
id = v.id
947959
ir = view.ir.ir
@@ -1222,19 +1234,25 @@ end
12221234

12231235
# N.B.: from and to are non-renamed indices
12241236
function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to::Int)
1225-
# Note: We recursively kill as many edges as are obviously dead. However, this
1226-
# may leave dead loops in the IR. We kill these later in a CFG cleanup pass (or
1227-
# worstcase during codegen).
1228-
(; bb_rename_pred, bb_rename_succ, result_bbs) = compact.cfg_transform
1237+
# Note: We recursively kill as many edges as are obviously dead.
1238+
(; bb_rename_pred, bb_rename_succ, result_bbs, domtree) = compact.cfg_transform
12291239
preds = result_bbs[bb_rename_succ[to]].preds
12301240
succs = result_bbs[bb_rename_pred[from]].succs
12311241
deleteat!(preds, findfirst(x::Int->x==bb_rename_pred[from], preds)::Int)
12321242
deleteat!(succs, findfirst(x::Int->x==bb_rename_succ[to], succs)::Int)
1243+
if domtree !== nothing
1244+
domtree_delete_edge!(domtree, result_bbs, bb_rename_pred[from], bb_rename_succ[to])
1245+
end
12331246
# Check if the block is now dead
1234-
if length(preds) == 0
1235-
for succ in copy(result_bbs[bb_rename_succ[to]].succs)
1236-
kill_edge!(compact, active_bb, to, findfirst(x::Int->x==succ, bb_rename_pred)::Int)
1247+
if length(preds) == 0 || (domtree !== nothing && bb_unreachable(domtree, bb_rename_succ[to]))
1248+
to_succs = result_bbs[bb_rename_succ[to]].succs
1249+
for succ in copy(to_succs)
1250+
new_succ = findfirst(x::Int->x==succ, bb_rename_pred)
1251+
new_succ === nothing && continue
1252+
kill_edge!(compact, active_bb, to, new_succ)
12371253
end
1254+
empty!(preds)
1255+
empty!(to_succs)
12381256
if to < active_bb
12391257
# Kill all statements in the block
12401258
stmts = result_bbs[bb_rename_succ[to]].stmts

0 commit comments

Comments
 (0)