Skip to content

Commit abcc700

Browse files
committed
add unitless buffer to BatchIntegrand
1 parent eda5c03 commit abcc700

File tree

2 files changed

+20
-19
lines changed

2 files changed

+20
-19
lines changed

src/batch.jl

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,19 @@ If `x` or `X` are not specified, `quadgk` internally creates a new `BatchIntegra
1919
user-supplied `y` buffer and a freshly-allocated `x` buffer based on the domain types. So,
2020
if you want to reuse the `x` buffer between calls, supply `{Y,X}` or pass `y,x` explicitly.
2121
"""
22-
struct BatchIntegrand{Y,X,Ty<:AbstractVector{Y},Tx<:AbstractVector{X},F}
22+
struct BatchIntegrand{Y,X,Ty<:AbstractVector{Y},Tx<:AbstractVector{X},F,T}
2323
# in-place function f!(y, x) that takes an array of x values and outputs an array of results in-place
2424
f!::F
2525
y::Ty
2626
x::Tx
27+
t::T
2728
max_batch::Int # maximum number of x to supply in parallel
2829
end
2930

3031
function BatchIntegrand(f!, y::AbstractVector, x::AbstractVector=similar(y, Nothing); max_batch::Integer=typemax(Int))
3132
max_batch > 0 || throw(ArgumentError("max_batch must be positive"))
32-
return BatchIntegrand(f!, y, x, max_batch)
33+
X = eltype(x)
34+
return BatchIntegrand(f!, y, x, X <: Nothing ? nothing : similar(x, typeof(one(X))), max_batch)
3335
end
3436
BatchIntegrand{Y,X}(f!; kws...) where {Y,X} = BatchIntegrand(f!, Y[], X[]; kws...)
3537
BatchIntegrand{Y}(f!; kws...) where {Y} = BatchIntegrand(f!, Y[]; kws...)
@@ -65,20 +67,20 @@ function evalrules(f::BatchIntegrand, s::NTuple{N}, x,w,gw, nrm) where {N}
6567
l = length(x)
6668
m = 2l-1 # evaluations per segment
6769
n = (N-1)*m # total evaluations
68-
resize!(f.x, n)
70+
resize!(f.t, n)
6971
resize!(f.y, n)
7072
for i in 1:(N-1) # fill buffer with evaluation points
7173
a = s[i]; b = s[i+1]
7274
check_endpoint_roundoff(a, b, x, throw_error=true)
7375
c = convert(eltype(x), 0.5) * (b-a)
7476
o = (i-1)*m
75-
f.x[l+o] = a + c
77+
f.t[l+o] = a + c
7678
for j in 1:l-1
77-
f.x[j+o] = a + (1 + x[j]) * c
78-
f.x[m+1-j+o] = a + (1 - x[j]) * c
79+
f.t[j+o] = a + (1 + x[j]) * c
80+
f.t[m+1-j+o] = a + (1 - x[j]) * c
7981
end
8082
end
81-
f.f!(f.y, f.x) # evaluate integrand
83+
f.f!(f.y, f.t) # evaluate integrand
8284
return ntuple(Val(N-1)) do i
8385
return batchevalrule(view(f.y, (1+(i-1)*m):(i*m)), s[i], s[i+1], x,w,gw, nrm)
8486
end
@@ -105,7 +107,7 @@ function refine(f::BatchIntegrand, segs::Vector{T}, I, E, numevals, x,w,gw,n, at
105107
len > nsegs && DataStructures.percolate_down!(segs, 1, y, Reverse, len-nsegs)
106108
end
107109

108-
resize!(f.x, 2m*nsegs)
110+
resize!(f.t, 2m*nsegs)
109111
resize!(f.y, 2m*nsegs)
110112
for i in 1:nsegs # fill buffer with evaluation points
111113
s = segs[len-i+1]
@@ -114,15 +116,15 @@ function refine(f::BatchIntegrand, segs::Vector{T}, I, E, numevals, x,w,gw,n, at
114116
check_endpoint_roundoff(a, b, x) && return segs
115117
c = convert(eltype(x), 0.5) * (b-a)
116118
o = (2i-j)*m
117-
f.x[l+o] = a + c
119+
f.t[l+o] = a + c
118120
for k in 1:l-1
119121
# early return if integrand evaluated at endpoints
120-
f.x[k+o] = a + (1 + x[k]) * c
121-
f.x[m+1-k+o] = a + (1 - x[k]) * c
122+
f.t[k+o] = a + (1 + x[k]) * c
123+
f.t[m+1-k+o] = a + (1 - x[k]) * c
122124
end
123125
end
124126
end
125-
f.f!(f.y, f.x)
127+
f.f!(f.y, f.t)
126128

127129
resize!(segs, len+nsegs)
128130
for i in 1:nsegs # evaluate segments and update estimates & heap
@@ -145,27 +147,26 @@ end
145147
function handle_infinities(workfunc, f::BatchIntegrand, s)
146148
s1, s2 = s[1], s[end]
147149
u = float(real(oneunit(s1))) # the units of the segment
148-
tbuf = similar(f.x, typeof(s1/oneunit(s1)))
149150
if realone(s1) && realone(s2) # check for infinite or semi-infinite intervals
150151
inf1, inf2 = isinf(s1), isinf(s2)
151152
if inf1 || inf2
152153
if inf1 && inf2 # x = t/(1-t^2) coordinate transformation
153154
I, E = workfunc(BatchIntegrand((v, t) -> begin resize!(f.x, length(t));
154-
f.f!(v, f.x .= u .* t ./ (1 .- t .* t)); v .*= (1 .+ t .* t) ./ (1 .- t .* t) .^ 2; end, f.y, tbuf, f.max_batch),
155+
f.f!(v, f.x .= u .* t ./ (1 .- t .* t)); v .*= (1 .+ t .* t) ./ (1 .- t .* t) .^ 2; end, f.y, f.x, f.t, f.max_batch),
155156
map(x -> isinf(x) ? (signbit(x) ? -one(x) : one(x)) : 2x / (oneunit(x)+hypot(oneunit(x),2x)), s),
156157
t -> u * t / (1 - t^2))
157158
return u * I, u * E
158159
end
159160
let (s0,si) = inf1 ? (s2,s1) : (s1,s2) # let is needed for JuliaLang/julia#15276
160161
if si < zero(si) # x = s0 - t/(1-t)
161162
I, E = workfunc(BatchIntegrand((v, t) -> begin resize!(f.x, length(t));
162-
f.f!(v, f.x .= s0 .- u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2; end, f.y, tbuf, f.max_batch),
163+
f.f!(v, f.x .= s0 .- u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2; end, f.y, f.x, f.t, f.max_batch),
163164
reverse(map(x -> 1 / (1 + oneunit(x) / (s0 - x)), s)),
164165
t -> s0 - u*t/(1-t))
165166
return u * I, u * E
166167
else # x = s0 + t/(1-t)
167168
I, E = workfunc(BatchIntegrand((v, t) -> begin resize!(f.x, length(t));
168-
f.f!(v, f.x .= s0 .+ u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2; end, f.y, tbuf, f.max_batch),
169+
f.f!(v, f.x .= s0 .+ u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2; end, f.y, f.x, f.t, f.max_batch),
169170
map(x -> 1 / (1 + oneunit(x) / (x - s0)), s),
170171
t -> s0 + u*t/(1-t))
171172
return u * I, u * E
@@ -174,7 +175,7 @@ function handle_infinities(workfunc, f::BatchIntegrand, s)
174175
end
175176
end
176177
I, E = workfunc(BatchIntegrand((y, t) -> begin resize!(f.x, length(t));
177-
f.f!(y, f.x .= u .* t); end, f.y, tbuf, f.max_batch),
178+
f.f!(y, f.x .= u .* t); end, f.y, f.x, f.t, f.max_batch),
178179
map(x -> x/oneunit(x), s),
179180
identity)
180181
return u * I, u * E
@@ -199,6 +200,6 @@ simultaneously. In particular, there are two differences from `quadgk`
199200
"""
200201
function quadgk(f::BatchIntegrand{Y,Nothing}, segs::T...; kws...) where {Y,T}
201202
FT = float(T) # the gk points are floating-point
202-
g = BatchIntegrand(f.f!, f.y, similar(f.x, FT), f.max_batch)
203+
g = BatchIntegrand(f.f!, f.y, similar(f.x, FT), similar(f.x, typeof(float(one(FT)))), f.max_batch)
203204
return quadgk(g, segs...; kws...)
204205
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ end
338338
end
339339

340340
# test constructors
341-
ref = BatchIntegrand(f!, Float64[], Nothing[], typemax(Int))
341+
ref = BatchIntegrand(f!, Float64[], Nothing[], nothing, typemax(Int))
342342
for b in (
343343
BatchIntegrand(f!, Float64[]),
344344
BatchIntegrand(f!, Float64[], Nothing[]),

0 commit comments

Comments
 (0)