diff --git a/base/array.jl b/base/array.jl index 3dbaa5957e885..aa0042429075c 100644 --- a/base/array.jl +++ b/base/array.jl @@ -1992,6 +1992,19 @@ function findall(A) end collect(first(p) for p in _pairs(A) if last(p) != 0) end +# Allocating result upfront is faster (possible only when collection can be iterated twice) +function findall(A::AbstractArray{Bool}) + n = count(A) + I = Vector{eltype(keys(A))}(uninitialized, n) + cnt = 1 + for (i,a) in pairs(A) + if a + I[cnt] = i + cnt += 1 + end + end + I +end findall(x::Bool) = x ? [1] : Vector{Int}() findall(testf::Function, x::Number) = testf(x) ? [1] : Vector{Int}() diff --git a/test/arrayops.jl b/test/arrayops.jl index 101c465037000..df224959f76c6 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -506,6 +506,19 @@ end @test findlast(!iszero, g3) == CartesianIndex(9, 2) @test findfirst(equalto(2), g3) === nothing @test findlast(equalto(2), g3) === nothing + + g4 = (x for x in [true, false, true, false]) + @test findall(g4) == [1, 3] + @test findfirst(g4) == 1 + @test findlast(g4) == 3 + + g5 = (x for x in [true false; true false]) + @test findall(g5) == findall(collect(g5)) + @test findfirst(g5) == CartesianIndex(1, 1) + @test findlast(g5) == CartesianIndex(2, 1) + + @test findfirst(x for x in Bool[]) === nothing + @test findlast(x for x in Bool[]) === nothing end @testset "findmin findmax argmin argmax" begin