diff --git a/base/reduce.jl b/base/reduce.jl index c132ef385358d..bd29d7eb3d889 100644 --- a/base/reduce.jl +++ b/base/reduce.jl @@ -29,6 +29,12 @@ mul_prod(x::BitSignedSmall, y::BitSignedSmall) = Int(x) * Int(y) mul_prod(x::BitUnsignedSmall, y::BitUnsignedSmall) = UInt(x) * UInt(y) mul_prod(x::Real, y::Real)::Real = x * y +and_all(x, y) = (x && y)::Bool +or_any(x, y) = (x || y)::Bool +# As a performance optimization, avoid runtime branches: +and_all(x::Bool, y::Bool) = (x & y)::Bool +or_any(x::Bool, y::Bool) = (x | y)::Bool + ## foldl && mapfoldl function mapfoldl_impl(f::F, op::OP, nt, itr) where {F,OP} @@ -338,6 +344,8 @@ reduce_empty(::typeof(*), ::Type{T}) where {T} = one(T) reduce_empty(::typeof(*), ::Type{<:AbstractChar}) = "" reduce_empty(::typeof(&), ::Type{Bool}) = true reduce_empty(::typeof(|), ::Type{Bool}) = false +reduce_empty(::typeof(and_all), ::Type{T}) where {T} = true +reduce_empty(::typeof(or_any), ::Type{T}) where {T} = false reduce_empty(::typeof(add_sum), ::Type{T}) where {T} = reduce_empty(+, T) reduce_empty(::typeof(add_sum), ::Type{T}) where {T<:BitSignedSmall} = zero(Int) diff --git a/base/reducedim.jl b/base/reducedim.jl index 0478afe1a46b6..ba3434494bc0b 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -45,7 +45,7 @@ end initarray!(a::AbstractArray{T}, f, ::Union{typeof(min),typeof(max),typeof(_extrema_rf)}, init::Bool, src::AbstractArray) where {T} = (init && mapfirst!(f, a, src); a) -for (Op, initval) in ((:(typeof(&)), true), (:(typeof(|)), false)) +for (Op, initval) in ((:(typeof(and_all)), true), (:(typeof(or_any)), false)) @eval initarray!(a::AbstractArray, ::Any, ::$(Op), init::Bool, src::AbstractArray) = (init && fill!(a, $initval); a) end @@ -173,6 +173,10 @@ end reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(max), A::AbstractArray{T}, region) where {T} = reducedim_initarray(A, region, zero(f(zero(T))), _realtype(f, T)) +reducedim_init(f, op::typeof(and_all), A::AbstractArrayOrBroadcasted, region) = reducedim_initarray(A, region, true) +reducedim_init(f, op::typeof(or_any), A::AbstractArrayOrBroadcasted, region) = reducedim_initarray(A, region, false) + +# These definitions are wrong in general; Cf. JuliaLang/julia#45562 reducedim_init(f, op::typeof(&), A::AbstractArrayOrBroadcasted, region) = reducedim_initarray(A, region, true) reducedim_init(f, op::typeof(|), A::AbstractArrayOrBroadcasted, region) = reducedim_initarray(A, region, false) @@ -883,13 +887,13 @@ julia> A = [true false; true false] 1 0 1 0 -julia> all!([1; 1], A) -2-element Vector{Int64}: +julia> all!(Bool[1; 1], A) +2-element Vector{Bool}: 0 0 -julia> all!([1 1], A) -1×2 Matrix{Int64}: +julia> all!(Bool[1 1], A) +1×2 Matrix{Bool}: 1 0 ``` """ @@ -958,13 +962,13 @@ julia> A = [true false; true false] 1 0 1 0 -julia> any!([1; 1], A) -2-element Vector{Int64}: +julia> any!(Bool[1; 1], A) +2-element Vector{Bool}: 1 1 -julia> any!([1 1], A) -1×2 Matrix{Int64}: +julia> any!(Bool[1 1], A) +1×2 Matrix{Bool}: 1 0 ``` """ @@ -994,7 +998,7 @@ _all(a, ::Colon) = _all(identity, a, :) for (fname, op) in [(:sum, :add_sum), (:prod, :mul_prod), (:maximum, :max), (:minimum, :min), - (:all, :&), (:any, :|), + (:all, :and_all), (:any, :or_any), (:extrema, :_extrema_rf)] fname! = Symbol(fname, '!') _fname = Symbol('_', fname) diff --git a/test/reduce.jl b/test/reduce.jl index f5140c8a34bd9..bb54462bb78f0 100644 --- a/test/reduce.jl +++ b/test/reduce.jl @@ -681,6 +681,27 @@ end end end +@testset "issue #45562" begin + @test all([true, true, true], dims = 1) == [true] + @test any([true, true, true], dims = 1) == [true] + @test_throws TypeError all([3, 3, 3], dims = 1) + @test_throws TypeError any([3, 3, 3], dims = 1) + @test_throws TypeError all(Any[true, 3, 3], dims = 1) + @test_throws TypeError any(Any[false, 3, 3], dims = 1) + @test_throws TypeError all([1, 1, 1], dims = 1) + @test_throws TypeError any([0, 0, 0], dims = 1) + @test_throws TypeError all!([false], [3, 3, 3]) + @test_throws TypeError any!([false], [3, 3, 3]) + @test_throws TypeError all!([false], Any[true, 3, 3]) + @test_throws TypeError any!([false], Any[false, 3, 3]) + @test_throws TypeError all!([false], [1, 1, 1]) + @test_throws TypeError any!([false], [0, 0, 0]) + @test reduce(|, Bool[]) == false + @test reduce(&, Bool[]) == true + @test reduce(|, Bool[], dims=1) == [false] + @test reduce(&, Bool[], dims=1) == [true] +end + # issue #45748 @testset "foldl's stability for nested Iterators" begin a = Iterators.flatten((1:3, 1:3))