@@ -43,6 +43,18 @@ julia> mean(skipmissing([1, missing, 3]))
4343"""
4444mean (itr) = mean (identity, itr)
4545
46+ struct Counter{F} <: Function
47+ f:: F
48+ n:: Base.RefValue{Int}
49+ end
50+ Counter (f:: F ) where {F} = Counter {F} (f, Ref (0 ))
51+ (f:: Counter )(x) = (f. n[] += 1 ; f. f (x))
52+
53+ struct DivOne{F} <: Function
54+ f:: F
55+ end
56+ (f:: DivOne )(x) = f. f (x)/ 1
57+
4658"""
4759 mean(f, itr)
4860
@@ -59,23 +71,13 @@ julia> mean([√1, √2, √3])
5971```
6072"""
6173function mean (f, itr)
62- y = iterate (itr)
63- if y === nothing
64- return Base. mapreduce_empty_iter (f, + , itr,
65- Base. IteratorEltype (itr)) / 0
66- end
67- count = 1
68- value, state = y
69- f_value = f (value)/ 1
70- total = Base. reduce_first (+ , f_value)
71- y = iterate (itr, state)
72- while y != = nothing
73- value, state = y
74- total += _mean_promote (total, f (value))
75- count += 1
76- y = iterate (itr, state)
74+ if Base. IteratorSize (itr) === Base. SizeUnknown ()
75+ g = Counter (DivOne (f))
76+ result = mapfoldl (g, add_mean, itr)
77+ return result/ g. n[]
78+ else
79+ return mapfoldl (DivOne (f), add_mean, itr)/ length (itr)
7780 end
78- return total/ count
7981end
8082
8183"""
@@ -180,20 +182,24 @@ mean(A::AbstractArray; dims=:) = _mean(identity, A, dims)
180182
181183_mean_promote (x:: T , y:: S ) where {T,S} = convert (promote_type (T, S), y)
182184
185+ add_mean (x, y) = Base. add_sum (x, _mean_promote (x, y))
186+
187+ Base. reduce_empty (:: typeof (add_mean), T) = Base. reduce_empty (Base. add_sum, T)
188+ Base. mapreduce_empty (g:: DivOne , :: typeof (add_mean), T) = Base. mapreduce_empty (g. f, Base. add_sum, T)/ 1
189+ Base. mapreduce_empty (g:: Counter{<:DivOne} , :: typeof (add_mean), T) = Base. mapreduce_empty (g. f. f, Base. add_sum, T)/ 1
190+
191+
183192# ::Dims is there to force specializing on Colon (as it is a Function)
184193function _mean (f, A:: AbstractArray , dims:: Dims = :) where Dims
185- isempty (A) && return sum (f, A, dims= dims)/ 0
186194 if dims === (:)
195+ result = mapreduce (DivOne (f), add_mean, A, dims= dims)
187196 n = length (A)
188- else
189- n = mapreduce (i -> size (A, i), * , unique (dims); init= 1 )
190- end
191- x1 = f (first (A)) / 1
192- result = sum (x -> _mean_promote (x1, f (x)), A, dims= dims)
193- if dims === (:)
194197 return result / n
195198 else
196- return result ./= n
199+ result = mapreduce (DivOne (f), add_mean, A, dims= dims)
200+ n = prod (i -> size (A, i), unique (dims); init= 1 )
201+ result ./= n
202+ return result
197203 end
198204end
199205
@@ -211,6 +217,7 @@ realXcY(x::Complex, y::Complex) = real(x)*real(y) + imag(x)*imag(y)
211217var (iterable; corrected:: Bool = true , mean= nothing ) = _var (iterable, corrected, mean)
212218
213219function _var (iterable, corrected:: Bool , mean)
220+ ismissing (mean) && return missing
214221 y = iterate (iterable)
215222 if y === nothing
216223 T = eltype (iterable)
@@ -252,61 +259,36 @@ function _var(iterable, corrected::Bool, mean)
252259 end
253260end
254261
255- centralizedabs2fun (m) = x -> abs2 .(x - m)
256- centralize_sumabs2 (A:: AbstractArray , m) =
257- mapreduce (centralizedabs2fun (m), + , A)
258- centralize_sumabs2 (A:: AbstractArray , m, ifirst:: Int , ilast:: Int ) =
259- Base. mapreduce_impl (centralizedabs2fun (m), + , A, ifirst, ilast)
260-
261- function centralize_sumabs2! (R:: AbstractArray{S} , A:: AbstractArray , means:: AbstractArray ) where S
262- # following the implementation of _mapreducedim! at base/reducedim.jl
263- lsiz = Base. check_reducedims (R,A)
264- for i in 1 : max (ndims (R), ndims (means))
265- if axes (means, i) != axes (R, i)
266- throw (DimensionMismatch (" dimension $i of `mean` should have indices $(axes (R, i)) , but got $(axes (means, i)) " ))
267- end
268- end
269- isempty (R) || fill! (R, zero (S))
270- isempty (A) && return R
271-
272- if Base. has_fast_linear_indexing (A) && lsiz > 16 && ! has_offset_axes (R, means)
273- nslices = div (length (A), lsiz)
274- ibase = first (LinearIndices (A))- 1
275- for i = 1 : nslices
276- @inbounds R[i] = centralize_sumabs2 (A, means[i], ibase+ 1 , ibase+ lsiz)
277- ibase += lsiz
278- end
279- return R
280- end
281- indsAt, indsRt = Base. safe_tail (axes (A)), Base. safe_tail (axes (R)) # handle d=1 manually
282- keep, Idefault = Broadcast. shapeindexer (indsRt)
283- if Base. reducedim1 (R, A)
284- i1 = first (Base. axes1 (R))
285- @inbounds for IA in CartesianIndices (indsAt)
286- IR = Broadcast. newindex (IA, keep, Idefault)
287- r = R[i1,IR]
288- m = means[i1,IR]
289- @simd for i in axes (A, 1 )
290- r += abs2 (A[i,IA] - m)
291- end
292- R[i1,IR] = r
293- end
294- else
295- @inbounds for IA in CartesianIndices (indsAt)
296- IR = Broadcast. newindex (IA, keep, Idefault)
297- @simd for i in axes (A, 1 )
298- R[i,IR] += abs2 (A[i,IA] - means[i,IR])
299- end
300- end
301- end
302- return R
262+ struct CentralizedAbs2Fun{T,S} <: Function
263+ mean:: S
303264end
265+ CentralizedAbs2Fun {T} (means) where {T} = CentralizedAbs2Fun {T,typeof(means)} (means)
266+ CentralizedAbs2Fun (means) = CentralizedAbs2Fun {typeof(means)} (means)
267+ CentralizedAbs2Fun (means, extrude) = CentralizedAbs2Fun {eltype(means)} (Broadcast. extrude (means))
268+ # Division is generally costly, but Julia is typically able to constant propagate a /1
269+ # and simply ensure we get the type right at no cost, allowing the division in-place later
270+ (f:: CentralizedAbs2Fun )(x) = abs2 .(x - f. mean)/ 1
271+ (f:: CentralizedAbs2Fun{<:Any,<:Broadcast.Extruded} )((i, x),) = abs2 .(x - Broadcast. _broadcast_getindex (f. mean, i))/ 1
272+ _doubled (x) = x+ x
273+ Base. mapreduce_empty (:: CentralizedAbs2Fun{T,<:Broadcast.Extruded} , :: typeof (Base. add_sum), :: Type{Tuple{_Any,S}} ) where {T<: Number , S<: Number , _Any} = _doubled (abs2 (zero (T)- zero (S)))/ 1
274+ Base. mapreduce_empty (:: CentralizedAbs2Fun{T,<:Broadcast.Extruded} , :: typeof (Base. add_sum), :: Type{Tuple{_Any, Union{Missing, S}}} ) where {T<: Number , S<: Number , _Any} = _doubled (abs2 (zero (T)- zero (S)))/ 1
275+ Base. mapreduce_empty (:: CentralizedAbs2Fun{T} , :: typeof (Base. add_sum), :: Type{S} ) where {T<: Number , S<: Number } = _doubled (abs2 (zero (T)- zero (S)))/ 1
276+ Base. mapreduce_empty (:: CentralizedAbs2Fun{T} , :: typeof (Base. add_sum), :: Type{Union{Missing, S}} ) where {T<: Number , S<: Number } = _doubled (abs2 (zero (T)- zero (S)))/ 1
277+
278+ centralize_sumabs2 (A:: AbstractArray , m) =
279+ sum (CentralizedAbs2Fun (m), A)
280+ centralize_sumabs2 (A:: AbstractArray , m:: AbstractArray , region) =
281+ sum (CentralizedAbs2Fun (m, true ), Base. PairsArray (A), dims= region)
282+ centralize_sumabs2! (R:: AbstractArray , A:: AbstractArray , means:: AbstractArray ) =
283+ sum! (CentralizedAbs2Fun (means, true ), R, Base. PairsArray (A))
284+
304285
305286function varm! (R:: AbstractArray{S} , A:: AbstractArray , m:: AbstractArray ; corrected:: Bool = true ) where S
306- if isempty (A)
287+ _checkm (R, m, ntuple (identity, Val (max (ndims (R), ndims (m)))))
288+ if isempty (A) || length (A) == 1 && corrected
307289 fill! (R, convert (S, NaN ))
308290 else
309- rn = div ( length (A), length (R )) - Int (corrected)
291+ rn = prod ( ntuple (d -> size (R, d) == 1 ? size (A, d) : 1 , Val ( max ( ndims (A), ndims (R))) )) - Int (corrected)
310292 centralize_sumabs2! (R, A, m)
311293 R .= R .* (1 // rn)
312294 end
@@ -339,15 +321,33 @@ over dimensions. In that case, `mean` must be an array with the same shape as
339321"""
340322varm (A:: AbstractArray , m:: AbstractArray ; corrected:: Bool = true , dims= :) = _varm (A, m, corrected, dims)
341323
342- _varm (A:: AbstractArray{T} , m, corrected:: Bool , region) where {T} =
343- varm! (Base. reducedim_init (t -> abs2 (t)/ 2 , + , A, region), A, m; corrected= corrected)
324+ _throw_mean_mismatch (A, m, region) = throw (DimensionMismatch (" axes of means ($(axes (m)) ) does not match reduction over $(region) of $(axes (A)) " ))
325+ function _checkm (A:: AbstractArray , m:: AbstractArray , region)
326+ for d in 1 : max (ndims (A), ndims (m))
327+ if d in region
328+ size (m, d) == 1 || _throw_mean_mismatch (A, m, region)
329+ else
330+ axes (m, d) == axes (A, d) || _throw_mean_mismatch (A, m, region)
331+ end
332+ end
333+ end
334+ function _varm (A:: AbstractArray , m, corrected:: Bool , region)
335+ _checkm (A, m, region)
336+ rn = prod (ntuple (d-> d in region ? size (A, d) : 1 , Val (ndims (A)))) - Int (corrected)
337+ R = centralize_sumabs2 (A, m, region)
338+ if rn <= 0
339+ R .= R ./ 0
340+ else
341+ R .= R .* 1 // rn # why use Rational?
342+ end
343+ return R
344+ end
344345
345346varm (A:: AbstractArray , m; corrected:: Bool = true ) = _varm (A, m, corrected, :)
346347
347348function _varm (A:: AbstractArray{T} , m, corrected:: Bool , :: Colon ) where T
348- n = length (A)
349- n == 0 && return oftype ((abs2 (zero (T)) + abs2 (zero (T)))/ 2 , NaN )
350- return centralize_sumabs2 (A, m) / (n - Int (corrected))
349+ rn = max (length (A) - Int (corrected), 0 )
350+ centralize_sumabs2 (A, m)/ rn
351351end
352352
353353
0 commit comments