@@ -19,17 +19,19 @@ If `x` or `X` are not specified, `quadgk` internally creates a new `BatchIntegra
1919user-supplied `y` buffer and a freshly-allocated `x` buffer based on the domain types. So,
2020if you want to reuse the `x` buffer between calls, supply `{Y,X}` or pass `y,x` explicitly.
2121"""
22- struct BatchIntegrand{Y,X,Ty<: AbstractVector{Y} ,Tx<: AbstractVector{X} ,F}
22+ struct BatchIntegrand{Y,X,Ty<: AbstractVector{Y} ,Tx<: AbstractVector{X} ,F,T }
2323 # in-place function f!(y, x) that takes an array of x values and outputs an array of results in-place
2424 f!:: F
2525 y:: Ty
2626 x:: Tx
27+ t:: T
2728 max_batch:: Int # maximum number of x to supply in parallel
2829end
2930
3031function BatchIntegrand (f!, y:: AbstractVector , x:: AbstractVector = similar (y, Nothing); max_batch:: Integer = typemax (Int))
3132 max_batch > 0 || throw (ArgumentError (" max_batch must be positive" ))
32- return BatchIntegrand (f!, y, x, max_batch)
33+ X = eltype (x)
34+ return BatchIntegrand (f!, y, x, X <: Nothing ? nothing : similar (x, typeof (one (X))), max_batch)
3335end
3436BatchIntegrand {Y,X} (f!; kws... ) where {Y,X} = BatchIntegrand (f!, Y[], X[]; kws... )
3537BatchIntegrand {Y} (f!; kws... ) where {Y} = BatchIntegrand (f!, Y[]; kws... )
@@ -65,20 +67,20 @@ function evalrules(f::BatchIntegrand, s::NTuple{N}, x,w,gw, nrm) where {N}
6567 l = length (x)
6668 m = 2 l- 1 # evaluations per segment
6769 n = (N- 1 )* m # total evaluations
68- resize! (f. x , n)
70+ resize! (f. t , n)
6971 resize! (f. y, n)
7072 for i in 1 : (N- 1 ) # fill buffer with evaluation points
7173 a = s[i]; b = s[i+ 1 ]
7274 check_endpoint_roundoff (a, b, x, throw_error= true )
7375 c = convert (eltype (x), 0.5 ) * (b- a)
7476 o = (i- 1 )* m
75- f. x [l+ o] = a + c
77+ f. t [l+ o] = a + c
7678 for j in 1 : l- 1
77- f. x [j+ o] = a + (1 + x[j]) * c
78- f. x [m+ 1 - j+ o] = a + (1 - x[j]) * c
79+ f. t [j+ o] = a + (1 + x[j]) * c
80+ f. t [m+ 1 - j+ o] = a + (1 - x[j]) * c
7981 end
8082 end
81- f. f! (f. y, f. x ) # evaluate integrand
83+ f. f! (f. y, f. t ) # evaluate integrand
8284 return ntuple (Val (N- 1 )) do i
8385 return batchevalrule (view (f. y, (1 + (i- 1 )* m): (i* m)), s[i], s[i+ 1 ], x,w,gw, nrm)
8486 end
@@ -105,7 +107,7 @@ function refine(f::BatchIntegrand, segs::Vector{T}, I, E, numevals, x,w,gw,n, at
105107 len > nsegs && DataStructures. percolate_down! (segs, 1 , y, Reverse, len- nsegs)
106108 end
107109
108- resize! (f. x , 2 m* nsegs)
110+ resize! (f. t , 2 m* nsegs)
109111 resize! (f. y, 2 m* nsegs)
110112 for i in 1 : nsegs # fill buffer with evaluation points
111113 s = segs[len- i+ 1 ]
@@ -114,15 +116,15 @@ function refine(f::BatchIntegrand, segs::Vector{T}, I, E, numevals, x,w,gw,n, at
114116 check_endpoint_roundoff (a, b, x) && return segs
115117 c = convert (eltype (x), 0.5 ) * (b- a)
116118 o = (2 i- j)* m
117- f. x [l+ o] = a + c
119+ f. t [l+ o] = a + c
118120 for k in 1 : l- 1
119121 # early return if integrand evaluated at endpoints
120- f. x [k+ o] = a + (1 + x[k]) * c
121- f. x [m+ 1 - k+ o] = a + (1 - x[k]) * c
122+ f. t [k+ o] = a + (1 + x[k]) * c
123+ f. t [m+ 1 - k+ o] = a + (1 - x[k]) * c
122124 end
123125 end
124126 end
125- f. f! (f. y, f. x )
127+ f. f! (f. y, f. t )
126128
127129 resize! (segs, len+ nsegs)
128130 for i in 1 : nsegs # evaluate segments and update estimates & heap
@@ -145,27 +147,26 @@ end
145147function handle_infinities (workfunc, f:: BatchIntegrand , s)
146148 s1, s2 = s[1 ], s[end ]
147149 u = float (real (oneunit (s1))) # the units of the segment
148- tbuf = similar (f. x, typeof (s1/ oneunit (s1)))
149150 if realone (s1) && realone (s2) # check for infinite or semi-infinite intervals
150151 inf1, inf2 = isinf (s1), isinf (s2)
151152 if inf1 || inf2
152153 if inf1 && inf2 # x = t/(1-t^2) coordinate transformation
153154 I, E = workfunc (BatchIntegrand ((v, t) -> begin resize! (f. x, length (t));
154- f. f! (v, f. x .= u .* t ./ (1 .- t .* t)); v .*= (1 .+ t .* t) ./ (1 .- t .* t) .^ 2 ; end , f. y, tbuf , f. max_batch),
155+ f. f! (v, f. x .= u .* t ./ (1 .- t .* t)); v .*= (1 .+ t .* t) ./ (1 .- t .* t) .^ 2 ; end , f. y, f . x, f . t , f. max_batch),
155156 map (x -> isinf (x) ? (signbit (x) ? - one (x) : one (x)) : 2 x / (oneunit (x)+ hypot (oneunit (x),2 x)), s),
156157 t -> u * t / (1 - t^ 2 ))
157158 return u * I, u * E
158159 end
159160 let (s0,si) = inf1 ? (s2,s1) : (s1,s2) # let is needed for JuliaLang/julia#15276
160161 if si < zero (si) # x = s0 - t/(1-t)
161162 I, E = workfunc (BatchIntegrand ((v, t) -> begin resize! (f. x, length (t));
162- f. f! (v, f. x .= s0 .- u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2 ; end , f. y, tbuf , f. max_batch),
163+ f. f! (v, f. x .= s0 .- u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2 ; end , f. y, f . x, f . t , f. max_batch),
163164 reverse (map (x -> 1 / (1 + oneunit (x) / (s0 - x)), s)),
164165 t -> s0 - u* t/ (1 - t))
165166 return u * I, u * E
166167 else # x = s0 + t/(1-t)
167168 I, E = workfunc (BatchIntegrand ((v, t) -> begin resize! (f. x, length (t));
168- f. f! (v, f. x .= s0 .+ u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2 ; end , f. y, tbuf , f. max_batch),
169+ f. f! (v, f. x .= s0 .+ u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2 ; end , f. y, f . x, f . t , f. max_batch),
169170 map (x -> 1 / (1 + oneunit (x) / (x - s0)), s),
170171 t -> s0 + u* t/ (1 - t))
171172 return u * I, u * E
@@ -174,7 +175,7 @@ function handle_infinities(workfunc, f::BatchIntegrand, s)
174175 end
175176 end
176177 I, E = workfunc (BatchIntegrand ((y, t) -> begin resize! (f. x, length (t));
177- f. f! (y, f. x .= u .* t); end , f. y, tbuf , f. max_batch),
178+ f. f! (y, f. x .= u .* t); end , f. y, f . x, f . t , f. max_batch),
178179 map (x -> x/ oneunit (x), s),
179180 identity)
180181 return u * I, u * E
@@ -199,6 +200,6 @@ simultaneously. In particular, there are two differences from `quadgk`
199200"""
200201function quadgk (f:: BatchIntegrand{Y,Nothing} , segs:: T... ; kws... ) where {Y,T}
201202 FT = float (T) # the gk points are floating-point
202- g = BatchIntegrand (f. f!, f. y, similar (f. x, FT), f. max_batch)
203+ g = BatchIntegrand (f. f!, f. y, similar (f. x, FT), similar (f . x, typeof ( float ( one (FT)))), f. max_batch)
203204 return quadgk (g, segs... ; kws... )
204205end
0 commit comments