diff --git a/src/batch.jl b/src/batch.jl index fcc236c..0a7d8eb 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -1,8 +1,8 @@ struct BatchClosure{F,A,C} # C is a Val{Bool} triggering local storage f::F end -function (b::BatchClosure{F,A,C})(p::Ptr{UInt}) where {F,A,C} - (offset, args) = ThreadingUtilities.load(p, A, 2 * sizeof(UInt)) +function (b::BatchClosure{F,A,C})(p::Ptr{UInt}, offset) where {F,A,C} + (offset, args) = ThreadingUtilities.load(p, A, offset) (offset, start) = ThreadingUtilities.load(p, UInt, offset) (offset, stop) = ThreadingUtilities.load(p, UInt, offset) if C @@ -15,14 +15,39 @@ function (b::BatchClosure{F,A,C})(p::Ptr{UInt}) where {F,A,C} nothing end +(b::BatchClosure{F,A,C})(p::Ptr{UInt}) where {F,A,C} = b(p, 2 * sizeof(UInt)) + + +struct FakeClosure{F,A,C} end + +function (::FakeClosure{F,A,C})(p::Ptr{UInt}) where {F,A,C} + (offset, bc) = ThreadingUtilities.load(p, Reference{BatchClosure{F,A,C}}, 2 * sizeof(UInt)) + return bc(p, offset) +end + + +# Same condition as in `emit_cfunction` in 'julia/src/codegen.cpp' +const CFUNCTION_CLOSURES_UNAVAILABLE = Sys.ARCH in ( + :aarch64, :aarch64_be, :aarch64_32, # isAArch64 + :arm, :armeb, # isARM + :ppc64, :ppc64le # isPPC64 +) + + @generated function batch_closure(f::F, args::A, ::Val{C}) where {F,A,C} q = if Base.issingletontype(F) bc = BatchClosure{F,A,C}(F.instance) - :(@cfunction($bc, Cvoid, (Ptr{UInt},))) + :(return @cfunction($bc, Cvoid, (Ptr{UInt},)), nothing) + elseif CFUNCTION_CLOSURES_UNAVAILABLE + fc = FakeClosure{F,A,C}() + quote + bc = BatchClosure{F,A,C}(f) + return @cfunction($fc, Cvoid, (Ptr{UInt},)), bc + end else quote bc = BatchClosure{F,A,C}(f) - @cfunction($(Expr(:$, :bc)), Cvoid, (Ptr{UInt},)) + return @cfunction($(Expr(:$, :bc)), Cvoid, (Ptr{UInt},)), nothing end end return Expr(:block, Expr(:meta, :inline), q) @@ -32,14 +57,17 @@ end # @cfunction($bc, Cvoid, (Ptr{UInt},)) # end + @inline function setup_batch!( p::Ptr{UInt}, fptr::Ptr{Cvoid}, + closure_obj, argtup, start::UInt, stop::UInt, ) offset = ThreadingUtilities.store!(p, fptr, sizeof(UInt)) + !isnothing(closure_obj) && (offset = ThreadingUtilities.store!(p, Reference(closure_obj), offset)) offset = ThreadingUtilities.store!(p, argtup, offset) offset = ThreadingUtilities.store!(p, start, offset) offset = ThreadingUtilities.store!(p, stop, offset) @@ -48,35 +76,38 @@ end @inline function setup_batch!( p::Ptr{UInt}, fptr::Ptr{Cvoid}, + closure_obj, argtup, start::UInt, stop::UInt, i::UInt, ) offset = ThreadingUtilities.store!(p, fptr, sizeof(UInt)) + !isnothing(closure_obj) && (offset = ThreadingUtilities.store!(p, Reference(closure_obj), offset)) offset = ThreadingUtilities.store!(p, argtup, offset) offset = ThreadingUtilities.store!(p, start, offset) offset = ThreadingUtilities.store!(p, stop, offset) offset = ThreadingUtilities.store!(p, i, offset) nothing end -@inline function launch_batched_thread!(cfunc, tid, argtup, start, stop) +@inline function launch_batched_thread!(cfunc, closure_obj, tid, argtup, start, stop) fptr = Base.unsafe_convert(Ptr{Cvoid}, cfunc) - ThreadingUtilities.launch(tid, fptr, argtup, start, stop) do p, fptr, argtup, start, stop - setup_batch!(p, fptr, argtup, start, stop) + ThreadingUtilities.launch(tid, fptr, closure_obj, argtup, start, stop) do p, fptr, closure_obj, argtup, start, stop + setup_batch!(p, fptr, closure_obj, argtup, start, stop) end end -@inline function launch_batched_thread!(cfunc, tid, argtup, start, stop, i) +@inline function launch_batched_thread!(cfunc, closure_obj, tid, argtup, start, stop, i) fptr = Base.unsafe_convert(Ptr{Cvoid}, cfunc) ThreadingUtilities.launch( tid, fptr, + closure_obj, argtup, start, stop, i, - ) do p, fptr, argtup, start, stop, i - setup_batch!(p, fptr, argtup, start, stop, i) + ) do p, fptr, closure_obj, argtup, start, stop, i + setup_batch!(p, fptr, closure_obj, argtup, start, stop, i) end end _extract_params(::Type{T}) where {T<:Tuple} = T.parameters @@ -128,9 +159,9 @@ end Ndp = Nd + one(Nd) end launch_quote = if thread_local - :(launch_batched_thread!(cfunc, tid, argtup, start, stop, i % UInt)) + :(launch_batched_thread!(cfunc, closure_obj, tid, argtup, start, stop, tid % UInt)) else - :(launch_batched_thread!(cfunc, tid, argtup, start, stop)) + :(launch_batched_thread!(cfunc, closure_obj, tid, argtup, start, stop)) end rem_quote = if thread_local :(f!(arguments, (start + one(UInt)) % Int, ulen % Int, (sum(nthread_tuple) + 1) % Int)) @@ -173,7 +204,7 @@ end free_threads!(torelease_tuple) nothing end - gcpr = Expr(:gc_preserve, block, :cfunc) + gcpr = Expr(:gc_preserve, block, :cfunc, :closure_obj) argt = Expr(:tuple) for k ∈ 1:K add_var!(q, argt, gcpr, args[k], :args, :gcp, k) @@ -182,7 +213,7 @@ end q.args, :(arguments = $argt), :(argtup = Reference(arguments)), - :(cfunc = batch_closure(f!, argtup, Val{$thread_local}())), + :((cfunc, closure_obj) = batch_closure(f!, argtup, Val{$thread_local}())), gcpr, ) push!(q.args, nothing)