diff --git a/src/Statistics.jl b/src/Statistics.jl index 29aae101..0b6946e3 100644 --- a/src/Statistics.jl +++ b/src/Statistics.jl @@ -165,16 +165,22 @@ mean(A::AbstractArray; dims=:) = _mean(identity, A, dims) _mean_promote(x::T, y::S) where {T,S} = convert(promote_type(T, S), y) +# calls f(A[1]) twice +_promoted_sum(f, A::AbstractArray; init, dims) = sum(x -> _mean_promote(init, f(x)), A; dims) + # calls f(A[1]) once +_promoted_sum(f, A::AbstractVector; init, dims) = + sum(x -> _mean_promote(init, f(x)), @view A[begin+1:end]; init, dims) + # ::Dims is there to force specializing on Colon (as it is a Function) function _mean(f, A::AbstractArray, dims::Dims=:) where Dims - isempty(A) && return sum(f, A, dims=dims)/0 + isempty(A) && return sum(f, A; dims)/0 if dims === (:) n = length(A) else n = mapreduce(i -> size(A, i), *, unique(dims); init=1) end - x1 = f(first(A)) / 1 - result = sum(x -> _mean_promote(x1, f(x)), A, dims=dims) + init = f(first(A)) / 1 + result = _promoted_sum(f, A; init, dims) if dims === (:) return result / n else @@ -986,9 +992,9 @@ end require_one_based_indexing(v) n = length(v) - + @assert n > 0 # this case should never happen here - + m = alpha + p * (one(alpha) - alpha - beta) aleph = n*p + oftype(p, m) j = clamp(trunc(Int, aleph), 1, n-1) @@ -1001,7 +1007,7 @@ end a = v[j] b = v[j + 1] end - + if isfinite(a) && isfinite(b) return a + γ*(b-a) else diff --git a/test/runtests.jl b/test/runtests.jl index 00cdad10..e5118ffa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -152,7 +152,7 @@ end ≈ float(typemax(Int))) end let x = rand(10000) # mean should use sum's accurate pairwise algorithm - @test mean(x) == sum(x) / length(x) + @test mean(x) == sum((@view x[begin + 1:end]), init=x[1]) / length(x) end @test mean(Number[1, 1.5, 2+3im]) === 1.5+1im # mixed-type array @test mean(v for v in Number[1, 1.5, 2+3im]) === 1.5+1im @@ -162,6 +162,15 @@ end @test (@inferred mean(Iterators.filter(x -> true, Int[]))) === 0/0 @test (@inferred mean(Iterators.filter(x -> true, Float32[]))) === 0.f0/0 @test (@inferred mean(Iterators.filter(x -> true, Float64[]))) === 0/0 + # Check that mean does not call function argument an extra time + let _cnt = 0, N = 100, x = rand(Int, N) + f(x) = (_cnt += 1; x) + @test mean(1:N) == mean(f, 1:N) + @test _cnt == N + _cnt = 0 + @test mean(x) == mean(f, x) + @test _cnt == N + end end @testset "mean/median for ranges" begin