Skip to content

Commit fb89009

Browse files
committed
Use map and filter transducers in foldl as an optimization
1 parent a68237f commit fb89009

File tree

1 file changed

+79
-30
lines changed

1 file changed

+79
-30
lines changed

base/reduce.jl

Lines changed: 79 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,31 +36,80 @@ mul_prod(x::Real, y::Real)::Real = x * y
3636

3737
## foldl && mapfoldl
3838

39-
function mapfoldl_impl(f, op, nt::NamedTuple{(:init,)}, itr, i...)
40-
init = nt.init
39+
mapfoldl_impl(f, op, nt, itr) = foldl_impl(op, nt, Generator(f, itr))
40+
41+
function foldl_impl(op, nt, itr)
42+
op′, itr′ = _xfadjoint(BottomRF(op), itr)
43+
return _foldl_impl(op′, nt, itr′)
44+
end
45+
46+
function _foldl_impl(op, nt, itr)
47+
init = get(nt, :init, _InitialValue())
4148
# Unroll the while loop once; if init is known, the call to op may
4249
# be evaluated at compile time
43-
y = iterate(itr, i...)
50+
y = iterate(itr)
4451
y === nothing && return init
45-
v = op(init, f(y[1]))
52+
v = op(init, y[1])
4653
while true
4754
y = iterate(itr, y[2])
4855
y === nothing && break
49-
v = op(v, f(y[1]))
56+
v = op(v, y[1])
5057
end
58+
v isa _InitialValue && reduce_empty_iter(op, itr, IteratorEltype(itr))
5159
return v
5260
end
5361

54-
function mapfoldl_impl(f, op, nt::NamedTuple{()}, itr)
55-
y = iterate(itr)
56-
if y === nothing
57-
return Base.mapreduce_empty_iter(f, op, itr, IteratorEltype(itr))
58-
end
59-
x, i = y
60-
init = mapreduce_first(f, op, x)
61-
return mapfoldl_impl(f, op, (init=init,), itr, i)
62+
struct _InitialValue end
63+
64+
struct BottomRF{T}
65+
rf::T
6266
end
6367

68+
@inline (op::BottomRF)(::_InitialValue, x) = x
69+
@inline (op::BottomRF)(acc, x) = op.rf(acc, x)
70+
71+
struct MappingRF{F, T}
72+
f::F
73+
rf::T
74+
end
75+
76+
@inline (op::MappingRF)(acc, x) = op.rf(acc, op.f(x))
77+
78+
struct FilteringRF{F, T}
79+
f::F
80+
rf::T
81+
end
82+
83+
@inline (op::FilteringRF)(acc, x) = op.f(x) ? op.rf(acc, x) : acc
84+
85+
"""
86+
_xfadjoint(op, itr) -> op′, itr′
87+
88+
Given a pair of reducing function `op` and an iterator `itr`, return a pair
89+
`(op′, itr′)` of similar types. If the iterator `itr` is transformed by an
90+
iterator transform `ixf` whose adjoint transducer `xf` is known, `op′ = xf(op)`
91+
and `itr′ = "parent" of itr` is returned. Otherwise, `op` and `itr` are
92+
returned as-is. For example, transducer `rf -> MappingRF(f, rf)` is the
93+
adjoint of iterator transform `itr -> Generator(f, itr)`.
94+
95+
Nested iterator transforms are converted recursively. That is to say,
96+
given `op` and
97+
98+
itr = (ixf₁ ∘ ixf₂ ∘ ... ∘ ixfₙ)(itr′)
99+
100+
what is returned is `itr′` and
101+
102+
op′ = (xfₙ ∘ ... ∘ xf₂ ∘ xf₁)(op)
103+
"""
104+
_xfadjoint(op, itr) = (op, itr)
105+
_xfadjoint(op, itr::Generator) =
106+
if itr.f === identity
107+
op, itr.iter
108+
else
109+
_xfadjoint(MappingRF(itr.f, op), itr.iter)
110+
end
111+
_xfadjoint(op, itr::Filter) =
112+
_xfadjoint(FilteringRF(itr.flt, op), itr.itr)
64113

65114
"""
66115
mapfoldl(f, op, itr; [init])
@@ -92,21 +141,14 @@ foldl(op, itr; kw...) = mapfoldl(identity, op, itr; kw...)
92141
## foldr & mapfoldr
93142

94143
mapfoldr_impl(f, op, nt::NamedTuple{(:init,)}, itr) =
95-
mapfoldl_impl(f, (x,y) -> op(y,x), nt, Iterators.reverse(itr))
96-
97-
# we can't just call mapfoldl_impl with (x,y) -> op(y,x), because
98-
# we need to use the type of op for mapreduce_empty_iter and mapreduce_first.
99-
function mapfoldr_impl(f, op, nt::NamedTuple{()}, itr)
100-
ritr = Iterators.reverse(itr)
101-
y = iterate(ritr)
102-
if y === nothing
103-
return Base.mapreduce_empty_iter(f, op, itr, IteratorEltype(itr))
104-
end
105-
x, i = y
106-
init = mapreduce_first(f, op, x)
107-
return mapfoldl_impl(f, (x,y) -> op(y,x), (init=init,), ritr, i)
144+
mapfoldl_impl(f, FlipArgs(op), nt, Iterators.reverse(itr))
145+
146+
struct FlipArgs{F}
147+
f::F
108148
end
109149

150+
@inline (f::FlipArgs)(x, y) = f.f(y, x)
151+
110152
"""
111153
mapfoldr(f, op, itr; [init])
112154
@@ -234,6 +276,11 @@ reduce_empty(::typeof(mul_prod), T) = reduce_empty(*, T)
234276
reduce_empty(::typeof(mul_prod), ::Type{T}) where {T<:SmallSigned} = one(Int)
235277
reduce_empty(::typeof(mul_prod), ::Type{T}) where {T<:SmallUnsigned} = one(UInt)
236278

279+
reduce_empty(op::BottomRF, T) = reduce_empty(op.rf, T)
280+
reduce_empty(op::MappingRF, T) = mapreduce_empty(op.f, op.rf, T)
281+
reduce_empty(op::FilteringRF, T) = reduce_empty(op.rf, T)
282+
reduce_empty(op::FlipArgs, T) = reduce_empty(op.f, T)
283+
237284
"""
238285
Base.mapreduce_empty(f, op, T)
239286
@@ -251,10 +298,12 @@ mapreduce_empty(::typeof(abs2), op, T) = abs2(reduce_empty(op, T))
251298
mapreduce_empty(f::typeof(abs), ::typeof(max), T) = abs(zero(T))
252299
mapreduce_empty(f::typeof(abs2), ::typeof(max), T) = abs2(zero(T))
253300

254-
mapreduce_empty_iter(f, op, itr, ::HasEltype) = mapreduce_empty(f, op, eltype(itr))
255-
mapreduce_empty_iter(f, op::typeof(&), itr, ::EltypeUnknown) = true
256-
mapreduce_empty_iter(f, op::typeof(|), itr, ::EltypeUnknown) = false
257-
mapreduce_empty_iter(f, op, itr, ::EltypeUnknown) = _empty_reduce_error()
301+
# For backward compatibility:
302+
mapreduce_empty_iter(f, op, itr, ItrEltype) =
303+
reduce_empty_iter(MappingRF(f, op), itr, ItrEltype)
304+
305+
reduce_empty_iter(op, itr, ::HasEltype) = reduce_empty(op, eltype(itr))
306+
reduce_empty_iter(op, itr, ::EltypeUnknown) = _empty_reduce_error()
258307

259308
# handling of single-element iterators
260309
"""

0 commit comments

Comments
 (0)