Skip to content

Commit 65b5880

Browse files
committed
RFC: Change lowering of destructuring to avoid const prop dependence
I'm currently doing some work with inference passes that have const prop (temporarily) disabled and I noticed we actually rely on it quite a bit for basic things. That's not terrible - const prop works pretty well after all, but it still imposes a cost and while I want to support it in my AD use case also, it makes destructuring quite expensive, because everything needs to be inferred twice. This PR is an experiment in changing the lowering to avoid having to const prop the index. Rather than lowering `(a,b,c) = foo()` as: ``` it = foo() a, s = indexed_iterate(it, 1) b, s = indexed_iterate(it, 2) c, s = indexed_iterate(it, 3) ``` we lower as: ``` it = foo() iterate, index = index_and_itereate(it) x = iterate(it) a = index(x, 1) y = iterate(it, y) b = index(y, 2) z = iterate(it, z) c = index(z, 3) ``` For tuples `iterate` would simply return the first argument and `index` would be `getfield`. That way, there is no const prop, since `getfield` is called directly and inference can directly use its tfunc. For the fallback case `iterate` is basically just `Base.iterate`, with just a slight tweak to give an intelligent error for short iterables. On simple functions, there isn't much of a difference in execution time, but benchmarking something more complicated like: ``` function g() a, = getfield(((1,),(2.0,3),("x",),(:x,)), Base.inferencebarrier(1)) nothing end ``` shows about a 20% improvement in end-to-end inference/optimize time, which is substantial.
1 parent 6de97d5 commit 65b5880

File tree

9 files changed

+51
-31
lines changed

9 files changed

+51
-31
lines changed

base/missing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ convert(::Type{T}, x::T) where {T>:Union{Missing, Nothing}} = x
6969
convert(::Type{T}, x) where {T>:Missing} = convert(nonmissingtype_checked(T), x)
7070
convert(::Type{T}, x) where {T>:Union{Missing, Nothing}} = convert(nonmissingtype_checked(nonnothingtype_checked(T)), x)
7171

72+
index_and_iterate(::Missing) = throw(MethodError(iterate, (missing,)))
7273

7374
# Comparison operators
7475
==(::Missing, ::Missing) = missing

base/namedtuple.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ firstindex(t::NamedTuple) = 1
112112
lastindex(t::NamedTuple) = nfields(t)
113113
getindex(t::NamedTuple, i::Int) = getfield(t, i)
114114
getindex(t::NamedTuple, i::Symbol) = getfield(t, i)
115-
indexed_iterate(t::NamedTuple, i::Int, state=1) = (getfield(t, i), i+1)
116115
isempty(::NamedTuple{()}) = true
117116
isempty(::NamedTuple) = false
118117
empty(::NamedTuple) = NamedTuple()
118+
index_and_iterate(t::NamedTuple) = (arg1, getfield)
119119

120120
convert(::Type{NamedTuple{names,T}}, nt::NamedTuple{names,T}) where {names,T<:Tuple} = nt
121121
convert(::Type{NamedTuple{names}}, nt::NamedTuple{names}) where {names} = nt

base/pair.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Pair, =>
4747

4848
eltype(p::Type{Pair{A, B}}) where {A, B} = Union{A, B}
4949
iterate(p::Pair, i=1) = i > 2 ? nothing : (getfield(p, i), i + 1)
50-
indexed_iterate(p::Pair, i::Int, state=1) = (getfield(p, i), i + 1)
50+
index_and_iterate(p::Pair) = (arg1, getfield)
5151

5252
hash(p::Pair, h::UInt) = hash(p.second, hash(p.first, h))
5353

base/tuple.jl

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,41 @@ function _maxlength(t::Tuple, t2::Tuple, t3::Tuple...)
8181
max(length(t), _maxlength(t2, t3...))
8282
end
8383

84-
# this allows partial evaluation of bounded sequences of next() calls on tuples,
85-
# while reducing to plain next() for arbitrary iterables.
86-
indexed_iterate(t::Tuple, i::Int, state=1) = (@_inline_meta; (getfield(t, i), i+1))
87-
indexed_iterate(a::Array, i::Int, state=1) = (@_inline_meta; (a[i], i+1))
88-
function indexed_iterate(I, i)
89-
x = iterate(I)
90-
x === nothing && throw(BoundsError(I, i))
91-
x
92-
end
93-
function indexed_iterate(I, i, state)
94-
x = iterate(I, state)
95-
x === nothing && throw(BoundsError(I, i))
96-
x
84+
# this allows partial evaluation of bounded sequences of iterate() calls on tuples,
85+
# while reducing to plain iterate() for arbitrary iterables.
86+
87+
arg1(a) = a
88+
arg1(a, b) = a
89+
index_and_iterate(t::Tuple) = (arg1, getfield)
90+
index_and_iterate(t::Array) = (arg1, getindex)
91+
92+
struct BadSlurp
93+
a
94+
end
95+
96+
function slurp_iterate(a)
97+
@_inline_meta
98+
s = iterate(a)
99+
s === nothing && return BadSlurp(a)
100+
s
97101
end
98102

103+
function slurp_iterate(a, b)
104+
@_inline_meta
105+
s = iterate(a, getfield(b, 2))
106+
s === nothing && return BadSlurp(a)
107+
s
108+
end
109+
110+
select_first(a::BadSlurp, i) = throw(BoundsError(a.a, i))
111+
select_first(a, i) = getfield(a, 1)
112+
113+
index_and_iterate(x) = (slurp_iterate, select_first)
114+
115+
# Nothing is often union'ed into other things. Kill that as quickly as possible
116+
# to make inference's life easier.
117+
index_and_iterate(::Nothing) = throw(MethodError(iterate, (nothing,)))
118+
99119
# Use dispatch to avoid a branch in first
100120
first(::Tuple{}) = throw(ArgumentError("tuple must be non-empty"))
101121
first(t::Tuple) = t[1]

src/common_symbols1.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jl_symbol("*"),
3333
jl_symbol("bitcast"),
3434
jl_symbol("slt_int"),
3535
jl_symbol("isempty"),
36-
jl_symbol("indexed_iterate"),
36+
jl_symbol("index_and_iterate"),
3737
jl_symbol("size"),
3838
jl_symbol("!"),
3939
jl_symbol("nothing"),

src/julia-syntax.scm

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2049,17 +2049,21 @@
20492049
x (make-ssavalue)))
20502050
(ini (if (eq? x xx) '() (list (sink-assignment xx (expand-forms x)))))
20512051
(n (length lhss))
2052+
(funcs (make-ssavalue))
2053+
(iterate (make-ssavalue))
2054+
(index (make-ssavalue))
20522055
(st (gensy)))
20532056
`(block
20542057
,@ini
2058+
,(lower-tuple-assignment
2059+
(list iterate index)
2060+
`(call (top index_and_iterate) ,xx))
20552061
,.(map (lambda (i lhs)
2056-
(expand-forms
2057-
(lower-tuple-assignment
2058-
(if (= i (- n 1))
2059-
(list lhs)
2060-
(list lhs st))
2061-
`(call (top indexed_iterate)
2062-
,xx ,(+ i 1) ,.(if (eq? i 0) '() `(,st))))))
2062+
(expand-forms
2063+
`(block
2064+
(= ,st (call ,iterate
2065+
,xx ,.(if (eq? i 0) '() `(,st))))
2066+
(= ,lhs (call ,index ,st ,(+ i 1))))))
20632067
(iota n)
20642068
lhss)
20652069
(unnecessary ,xx))))))

stdlib/Serialization/src/Serialization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ const TAGS = Any[
6666
:(=), :(==), :(===), :gotoifnot, :A, :B, :C, :M, :N, :T, :S, :X, :Y, :a, :b, :c, :d, :e, :f,
6767
:g, :h, :i, :j, :k, :l, :m, :n, :o, :p, :q, :r, :s, :t, :u, :v, :w, :x, :y, :z, :add_int,
6868
:sub_int, :mul_int, :add_float, :sub_float, :new, :mul_float, :bitcast, :start, :done, :next,
69-
:indexed_iterate, :getfield, :meta, :eq_int, :slt_int, :sle_int, :ne_int, :push_loc, :pop_loc,
69+
:index_and_iterate, :getfield, :meta, :eq_int, :slt_int, :sle_int, :ne_int, :push_loc, :pop_loc,
7070
:pop, :arrayset, :arrayref, :apply_type, :inbounds, :getindex, :setindex!, :Core, :!, :+,
7171
:Base, :static_parameter, :convert, :colon, Symbol("#self#"), Symbol("#temp#"), :tuple, Symbol(""),
7272

@@ -78,7 +78,7 @@ const TAGS = Any[
7878

7979
@assert length(TAGS) == 255
8080

81-
const ser_version = 11 # do not make changes without bumping the version #!
81+
const ser_version = 12 # do not make changes without bumping the version #!
8282

8383
const NTAGS = length(TAGS)
8484

test/compiler/inference.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -562,10 +562,7 @@ end
562562

563563
function g19348(x)
564564
a, b = x
565-
g = 1
566-
g = 2
567-
c = Base.indexed_iterate(x, g, g)
568-
return a + b + c[1]
565+
return a + b
569566
end
570567

571568
for (codetype, all_ssa) in Any[

test/dict.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ using Random
1111
@test iterate(p, iterate(p, iterate(p)[2])[2]) == nothing
1212
@test firstindex(p) == 1
1313
@test lastindex(p) == length(p) == 2
14-
@test Base.indexed_iterate(p, 1, nothing) == (10,2)
15-
@test Base.indexed_iterate(p, 2, nothing) == (20,3)
1614
@test (1=>2) < (2=>3)
1715
@test (2=>2) < (2=>3)
1816
@test !((2=>3) < (2=>3))

0 commit comments

Comments
 (0)