Skip to content

Commit fd178c3

Browse files
committed
inference: revive CachedMethodTable mechanism
`CachedMethodTable` was removed within #44240 as we couldn't confirm any performance improvement then. However it turns out the optimization was critical in some real world cases (e.g. #46492), so this commit revives the mechanism with the following tweaks that should make it more effective: - create method table cache per inference (rather than per local inference on a function call as on the previous implementation) - only use cache mechanism for abstract types (since we already cache lookup result at the next level as for concrete types) As a result, the following snippet reported at #46492 recovers the compilation performance: ```julia using ControlSystems a_2 = [-5 -3; 2 -9] C_212 = ss(a_2, [1; 2], [1 0; 0 1], [0; 0]) @time norm(C_212) ``` > on master ``` julia> @time norm(C_212) 364.489044 seconds (724.44 M allocations: 92.524 GiB, 6.01% gc time, 100.00% compilation time) 0.5345224838248489 ``` > on this commit ``` julia> @time norm(C_212) 26.539016 seconds (62.09 M allocations: 5.537 GiB, 5.55% gc time, 100.00% compilation time) 0.5345224838248489 ```
1 parent fb19a0a commit fd178c3

File tree

4 files changed

+68
-37
lines changed

4 files changed

+68
-37
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
293293
if result === missing
294294
return FailedMethodMatch("For one of the union split cases, too many methods matched")
295295
end
296-
matches, overlayed = result
296+
(; matches, overlayed) = result
297297
nonoverlayed &= !overlayed
298298
push!(infos, MethodMatchInfo(matches))
299299
for m in matches
@@ -334,7 +334,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
334334
# (assume this will always be true, so we don't compute / update valid age in this case)
335335
return FailedMethodMatch("Too many methods matched")
336336
end
337-
matches, overlayed = result
337+
(; matches, overlayed) = result
338338
fullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
339339
return MethodMatches(matches.matches,
340340
MethodMatchInfo(matches),

base/compiler/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,11 @@ something(x::Any, y...) = x
139139
############
140140

141141
include("compiler/cicache.jl")
142+
include("compiler/methodtable.jl")
142143
include("compiler/effects.jl")
143144
include("compiler/types.jl")
144145
include("compiler/utilities.jl")
145146
include("compiler/validation.jl")
146-
include("compiler/methodtable.jl")
147147

148148
function argextype end # imported by EscapeAnalysis
149149
function stmt_effect_free end # imported by EscapeAnalysis

base/compiler/methodtable.jl

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,27 @@
22

33
abstract type MethodTableView; end
44

5+
struct MethodLookupResult
6+
# Really Vector{Core.MethodMatch}, but it's easier to represent this as
7+
# and work with Vector{Any} on the C side.
8+
matches::Vector{Any}
9+
valid_worlds::WorldRange
10+
ambig::Bool
11+
end
12+
length(result::MethodLookupResult) = length(result.matches)
13+
function iterate(result::MethodLookupResult, args...)
14+
r = iterate(result.matches, args...)
15+
r === nothing && return nothing
16+
match, state = r
17+
return (match::MethodMatch, state)
18+
end
19+
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch
20+
21+
struct MethodMatchResult
22+
matches::MethodLookupResult
23+
overlayed::Bool
24+
end
25+
526
"""
627
struct InternalMethodTable <: MethodTableView
728
@@ -23,25 +44,21 @@ struct OverlayMethodTable <: MethodTableView
2344
mt::Core.MethodTable
2445
end
2546

26-
struct MethodLookupResult
27-
# Really Vector{Core.MethodMatch}, but it's easier to represent this as
28-
# and work with Vector{Any} on the C side.
29-
matches::Vector{Any}
30-
valid_worlds::WorldRange
31-
ambig::Bool
32-
end
33-
length(result::MethodLookupResult) = length(result.matches)
34-
function iterate(result::MethodLookupResult, args...)
35-
r = iterate(result.matches, args...)
36-
r === nothing && return nothing
37-
match, state = r
38-
return (match::MethodMatch, state)
47+
"""
48+
struct CachedMethodTable <: MethodTableView
49+
50+
Overlays another method table view with an additional local fast path cache that
51+
can respond to repeated, identical queries faster than the original method table.
52+
"""
53+
struct CachedMethodTable{T} <: MethodTableView
54+
cache::IdDict{Any, Union{Missing, MethodMatchResult}}
55+
table::T
3956
end
40-
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch
57+
CachedMethodTable(table::T) where T = CachedMethodTable{T}(IdDict{Any, Union{Missing, MethodMatchResult}}(), table)
4158

4259
"""
4360
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) ->
44-
(matches::MethodLookupResult, overlayed::Bool) or missing
61+
MethodMatchResult(matches::MethodLookupResult, overlayed::Bool) or missing
4562
4663
Find all methods in the given method table `view` that are applicable to the given signature `sig`.
4764
If no applicable methods are found, an empty result is returned.
@@ -51,7 +68,7 @@ If the number of applicable methods exceeded the specified limit, `missing` is r
5168
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=Int(typemax(Int32)))
5269
result = _findall(sig, nothing, table.world, limit)
5370
result === missing && return missing
54-
return result, false
71+
return MethodMatchResult(result, false)
5572
end
5673

5774
function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=Int(typemax(Int32)))
@@ -60,18 +77,20 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
6077
nr = length(result)
6178
if nr 1 && result[nr].fully_covers
6279
# no need to fall back to the internal method table
63-
return result, true
80+
return MethodMatchResult(result, true)
6481
end
6582
# fall back to the internal method table
6683
fallback_result = _findall(sig, nothing, table.world, limit)
6784
fallback_result === missing && return missing
6885
# merge the fallback match results with the internal method table
69-
return MethodLookupResult(
70-
vcat(result.matches, fallback_result.matches),
71-
WorldRange(
72-
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
73-
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
74-
result.ambig | fallback_result.ambig), !isempty(result)
86+
return MethodMatchResult(
87+
MethodLookupResult(
88+
vcat(result.matches, fallback_result.matches),
89+
WorldRange(
90+
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
91+
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
92+
result.ambig | fallback_result.ambig),
93+
!isempty(result))
7594
end
7695

7796
function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
@@ -85,6 +104,17 @@ function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
85104
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
86105
end
87106

107+
function findall(@nospecialize(sig::Type), table::CachedMethodTable; limit::Int=typemax(Int))
108+
if isconcretetype(sig)
109+
# as for concrete types, we cache result at on the next level
110+
return findall(sig, table.table; limit)
111+
end
112+
box = Core.Box(sig)
113+
return get!(table.cache, sig) do
114+
findall(box.contents, table.table; limit)
115+
end
116+
end
117+
88118
"""
89119
findsup(sig::Type, view::MethodTableView) ->
90120
(match::MethodMatch, valid_worlds::WorldRange, overlayed::Bool) or nothing
@@ -129,6 +159,10 @@ function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
129159
return match, valid_worlds
130160
end
131161

162+
# This query is not cached
163+
findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table)
164+
132165
isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
133166
isoverlayed(::InternalMethodTable) = false
134167
isoverlayed(::OverlayMethodTable) = true
168+
isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)

base/compiler/types.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ struct NativeInterpreter <: AbstractInterpreter
158158
cache::Vector{InferenceResult}
159159
# The world age we're working inside of
160160
world::UInt
161+
# method table to lookup for during inference on this world age
162+
method_table::CachedMethodTable{InternalMethodTable}
161163

162164
# Parameters for inference and optimization
163165
inf_params::InferenceParams
@@ -167,27 +169,21 @@ struct NativeInterpreter <: AbstractInterpreter
167169
inf_params = InferenceParams(),
168170
opt_params = OptimizationParams(),
169171
)
172+
cache = Vector{InferenceResult}() # Initially empty cache
173+
170174
# Sometimes the caller is lazy and passes typemax(UInt).
171175
# we cap it to the current world age
172176
if world == typemax(UInt)
173177
world = get_world_counter()
174178
end
175179

180+
method_table = CachedMethodTable(InternalMethodTable(world))
181+
176182
# If they didn't pass typemax(UInt) but passed something more subtly
177183
# incorrect, fail out loudly.
178184
@assert world <= get_world_counter()
179185

180-
return new(
181-
# Initially empty cache
182-
Vector{InferenceResult}(),
183-
184-
# world age counter
185-
world,
186-
187-
# parameters for inference and optimization
188-
inf_params,
189-
opt_params,
190-
)
186+
return new(cache, world, method_table, inf_params, opt_params)
191187
end
192188
end
193189

@@ -251,6 +247,7 @@ External `AbstractInterpreter` can optionally return `OverlayMethodTable` here
251247
to incorporate customized dispatches for the overridden methods.
252248
"""
253249
method_table(interp::AbstractInterpreter) = InternalMethodTable(get_world_counter(interp))
250+
method_table(interp::NativeInterpreter) = interp.method_table
254251

255252
"""
256253
By default `AbstractInterpreter` implements the following inference bail out logic:

0 commit comments

Comments
 (0)