Skip to content

Commit b96ba19

Browse files
committed
Automatically parallelize map
1 parent bbdb280 commit b96ba19

File tree

5 files changed

+219
-1
lines changed

5 files changed

+219
-1
lines changed

base/Base.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ include("task.jl")
273273
include("threads_overloads.jl")
274274
include("weakkeydict.jl")
275275

276+
include("parallelism.jl")
277+
276278
include("env.jl")
277279

278280
# BinaryPlatforms, used by Artifacts

base/abstractarray.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2862,7 +2862,12 @@ function map!(f::F, dest::AbstractArray, A::AbstractArray) where F
28622862
end
28632863

28642864
# map on collections
2865-
map(f, A::AbstractArray) = collect_similar(A, Generator(f,A))
2865+
function map(f, A::AbstractArray)
2866+
iter = Generator(f, A)
2867+
ans = _maybe_parallelize_collect(iter)
2868+
ans === nothing || return something(ans)
2869+
return collect_similar(A, iter)
2870+
end
28662871

28672872
mapany(f, A::AbstractArray) = map!(f, Vector{Any}(undef, length(A)), A)
28682873
mapany(f, itr) = Any[f(x) for x in itr]

base/array.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,12 @@ else
775775
end
776776

777777
function collect(itr::Generator)
778+
ans = _maybe_parallelize_collect(itr)
779+
ans === nothing || return something(ans) # Note: `@something` not available yet here
780+
return _serial_collect(itr)
781+
end
782+
783+
function _serial_collect(itr::Generator)
778784
isz = IteratorSize(itr.iter)
779785
et = @default_eltype(itr)
780786
if isa(isz, SizeUnknown)

base/compiler/typeinfer.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,3 +1023,50 @@ function _return_type(interp::AbstractInterpreter, @nospecialize(f), @nospeciali
10231023
end
10241024
return rt
10251025
end
1026+
1027+
function infer_effects(@nospecialize(f), @nospecialize(t))
1028+
world = ccall(:jl_get_tls_world_age, UInt, ())
1029+
return ccall(:jl_call_in_typeinf_world, Any, (Ptr{Ptr{Cvoid}}, Cint), Any[_infer_effects, f, t, world], 4)
1030+
end
1031+
1032+
_infer_effects(@nospecialize(f), @nospecialize(t), world) = _infer_effects(NativeInterpreter(world), f, t)
1033+
1034+
function _infer_effects(interp::AbstractInterpreter, @nospecialize(f), @nospecialize(t))
1035+
if isa(f, Builtin)
1036+
argtypes = Any[t.parameters...]
1037+
rt = builtin_tfunction(interp, f, argtypes, nothing)
1038+
return builtin_effect(interp, f, argtypes, widenconst(rt))
1039+
else
1040+
eff = nothing
1041+
for match in _methods(f, t, -1, get_world_counter(interp))::Vector
1042+
match = match::MethodMatch
1043+
ans = _infer_effects(interp, match.method, match.spec_types, match.sparams)
1044+
eff = (eff === nothing ? ans : tristate_merge(eff::Effects, ans))::Effects
1045+
eff == Effects() && return eff
1046+
end
1047+
eff === nothing && return Effects()
1048+
return eff
1049+
end
1050+
end
1051+
1052+
# like `typeinf_type` but for effects
1053+
function _infer_effects(interp::AbstractInterpreter, method::Method, @nospecialize(atype), sparams::SimpleVector)
1054+
if contains_is(unwrap_unionall(atype).parameters, Union{})
1055+
return Effects()
1056+
end
1057+
mi = specialize_method(method, atype, sparams)::MethodInstance
1058+
for i = 1:2 # test-and-lock-and-test
1059+
i == 2 && ccall(:jl_typeinf_begin, Cvoid, ())
1060+
code = get(code_cache(interp), mi, nothing)
1061+
if code isa CodeInstance
1062+
# see if this CodeInstance already exists in the cache
1063+
i == 2 && ccall(:jl_typeinf_end, Cvoid, ())
1064+
return ipo_effects(code)
1065+
end
1066+
end
1067+
result = InferenceResult(mi)
1068+
typeinf(interp, result, :global)
1069+
ccall(:jl_typeinf_end, Cvoid, ())
1070+
result.result isa InferenceState && return Effects()
1071+
return result.ipo_effects
1072+
end

base/parallelism.jl

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
"""
4+
_is_effect_free(f, t::Type{<:Tuple}) -> ans::Bool
5+
6+
Return `true` if `f(args...)` with `args::t` is consdiered to be `:effect_free`
7+
as defined in `@assume_effects`.
8+
"""
9+
function _is_effect_free(f::F, t) where {F}
10+
eff = Core.Compiler.infer_effects(f, t)
11+
return eff.effect_free == Core.Compiler.ALWAYS_TRUE
12+
end
13+
14+
# Choosing a rather very conservative bound for parallelism to avoid slowing
15+
# down the case parallelism is not helpful. This number is chosen s.t. the
16+
# sequential `map(identity, ::Vector{Float64})` of this size takes about > 100
17+
# μs. So, the cost of task spawn/sync is likely neglegible even for very
18+
# trivial `f`. This is a global mutable so that it is easy to check the
19+
# parallel path in end-to-end tests.
20+
const _MAP_MIN_BASESIZE = Ref(2^18)
21+
22+
function _maybe_parallelize_collect(iter::Generator)
23+
# TODO: use transducer to cleanly implement this
24+
# TODO: support filter, flatten, product, etc.
25+
arrays = if iter.iter isa AbstractArray
26+
(iter.iter,)
27+
elseif iter.iter isa Iterators.Zip{<:Tuple{Vararg{AbstractArray}}}
28+
iter.iter.is
29+
else
30+
return nothing
31+
end
32+
arrays === () && return nothing
33+
34+
# TODO: Maybe avoid parallel implementation if `f` and `getindex` are also
35+
# `:consistent`? It can happen for something like FillArrays.
36+
37+
# TODO: guess a good number from the cost of `f` and `getindex`?
38+
min_basesize = max(2, _MAP_MIN_BASESIZE[])
39+
length(arrays[1]) < min_basesize && return nothing
40+
41+
# Only handle compatible shape (and thus uniform length) for now.
42+
all(==(axes(arrays[1])), map(axes, tail(arrays))) || return nothing
43+
shape = size(arrays[1]) # relies on the check above
44+
45+
Threads.nthreads() == 1 && return nothing
46+
47+
indices = eachindex(arrays...)
48+
indextype = eltype(indices)
49+
_is_effect_free(ith_all, Tuple{indextype,typeof(arrays)}) || return nothing
50+
51+
# FIXME: `getvalue` captures `iter` just in case `f` is a `Type` (which
52+
# would be captured as a `DataType`)
53+
getvalue = if iter.iter isa AbstractArray
54+
_is_effect_free(iter.f, Tuple{eltype(arrays[1])}) || return nothing
55+
@inline getvalue1(i) = iter.f(@inbounds arrays[1][i])
56+
else
57+
_is_effect_free(iter.f, Tuple{Tuple{map(eltype, arrays)...}}) || return nothing
58+
@inline getvalue2(i) = iter.f((@inbounds ith_all(i, arrays)))
59+
end
60+
61+
# Cap the `basesize` assuming that the workload of `f` is uniform. This may
62+
# not be a good choice especially once we support filter and flatten.
63+
# However, since this code path is enabled automatically, it may be better
64+
# to play on the very safe side.
65+
basesize = min(min_basesize, cld(length(indices), Threads.nthreads()))
66+
67+
# Note: `@default_eltype` is incorrect if `T(....)` (with `T::Type`) does
68+
# not return a `T`. However, `collect(::Generator)` already uses `@default_eltype`.
69+
et = @default_eltype(iter)
70+
if isconcretetype(et)
71+
# We do not leak compiler internal here even though `et` is visible from
72+
# the user because we have checked that input is not empty. It is not
73+
# perfect since the sequential implementation of `collect(::Generator)`
74+
# itself does leak the compiler internal by returning `Array{et}`.
75+
# However, if/when `collect(::Generator)` solved this issue, the
76+
# parallel implementation does not get in the way of typocalyps-free
77+
# Base.
78+
dest = Array{et}(undef, size(arrays[1]))
79+
return Some(_parallel_map!(getvalue, dest, indices))
80+
else
81+
# TODO: use `_parallel_map!` if `allocatedinline(et)` and then refine
82+
# type (and fuse the mapping and the type bound computation)
83+
ys = _parallel_map(getvalue, indices; basesize)::Array{<:et}
84+
if length(shape) == 1
85+
return Some(ys)
86+
else
87+
return Some(reshape(ys, shape))
88+
end
89+
end
90+
end
91+
92+
"""
93+
_parallel_map!(f, dest, xs) -> dest
94+
95+
A parallel version of `map!` (i.e., `dest .= f.(xs)`).
96+
97+
Before turning this to a proper API (say) `Threads.map!`:
98+
* (ideally: define infrastructure for making it extensible)
99+
* use basesize to control parallelism in a compositional manner
100+
* reject obviously wrong inputs like `dest::BitArray` (or just support it)
101+
"""
102+
function _parallel_map!(f, dest, xs)
103+
# TODO: use divide-and-conquer strategy for fast spawn and sync
104+
# TODO: don't use `@threads`
105+
# TODO: since the caller allocates `dest` and `f` is effect-free, we can use
106+
# `@simd ivdep`
107+
Threads.@threads for i in eachindex(dest, xs)
108+
@inbounds dest[i] = f(xs[i])
109+
end
110+
return dest
111+
end
112+
113+
function _halve(xs::AbstractVector)
114+
f = firstindex(xs)
115+
l = lastindex(xs)
116+
h = length(xs) ÷ 2
117+
return view(xs, f:f+h), view(xs, f+h+1:l)
118+
end
119+
120+
_halve(xs::AbstractArray) = _halve_ndarray(xs)
121+
122+
function _halve_ndarray(xs::AbstractArray{N}, ::Val{D} = Val(N)) where {N,D}
123+
if D > 1
124+
size(xs, D) < 2 && return _halve_ndarray(xs, Val(D - 1))
125+
end
126+
f = firstindex(xs, D)
127+
l = lastindex(xs, D)
128+
h = size(xs, D) ÷ 2
129+
cs1 = ntuple(_ -> :, Val(D - 1))
130+
cs2 = ntuple(_ -> :, Val(N - D))
131+
return view(xs, cs1..., f:f+h, cs2...), view(xs, cs1..., f+h+1:l, cs2...)
132+
end
133+
134+
"""
135+
_parallel_map(f, xs; basesize) -> ys::Vector
136+
137+
Note: The output is always a `Vector` even if the input can have arbitrary
138+
`ndims`. The caller is responsible for `reshape`ing the output properly.
139+
"""
140+
function _parallel_map(f, xs; basesize)
141+
length(xs) <= max(2, basesize) && return vec(_serial_collect(Iterators.map(f, xs)))
142+
xs1, xs2 = _halve(xs)
143+
task = Threads.@spawn _parallel_map(f, xs2; basesize)
144+
ys1 = _parallel_map(f, xs1; basesize)::Vector
145+
ys2 = fetch(task)::Vector
146+
if eltype(ys2) <: eltype(ys1)
147+
return append!(ys1, ys2)
148+
elseif eltype(ys1) <: eltype(ys2)
149+
insert!(ys1, firstindex(ys1), ys2)
150+
return ys2
151+
else
152+
# Note: we cannot use `vcat` here since `collect` uses
153+
# `promote_typejoin` instead of `promote`
154+
T = promote_typejoin(eltype(ys1), eltype(ys2))
155+
ys3 = empty!(Vector{T}(undef, length(ys1) + length(ys2)))
156+
return append!(ys3, ys1, ys2)
157+
end
158+
end

0 commit comments

Comments
 (0)