Skip to content

Commit d8d66fe

Browse files
maxbennedich authored and KristofferC committed
Faster findall for bitarrays (#29888)
* Faster findall for bitarrays
* Add a few tests for findall for bitarrays
* Code review updates for bitarray findall (#29888)

(cherry picked from commit 96ce5ba)
1 parent 3b160d5 commit d8d66fe

File tree

2 files changed

+79
-25
lines changed

2 files changed

+79
-25
lines changed

base/bitarray.jl

Lines changed: 58 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1510,37 +1510,70 @@ function findprev(testf::Function, B::BitArray, start::Integer)
15101510
end
15111511
#findlast(testf::Function, B::BitArray) = findprev(testf, B, 1) ## defined in array.jl
15121512

1513+
# findall helper functions
1514+
# Generic case (any number of dimensions other than 1 or 2, including 0;
# vectors and matrices have their own methods).
#
# Store every index of `B`, in the iteration order of `keys(B)`, into
# `I[1], ..., I[length(B)]` and return `I`.
function allindices!(I, B::BitArray)
    # Iterating `keys(B)` directly visits each index exactly once, so no index
    # is computed past the last element.
    for (k, ind) in enumerate(keys(B))
        I[k] = ind
    end
    return I
end
1522+
1523+
# Optimized case for vector
#
# The indices of a `BitVector` are its linear indices, so write them into `I`
# with a single broadcast assignment and return `I`.
function allindices!(I, B::BitVector)
    I[:] .= eachindex(B)
    return I
end
1527+
1528+
# Optimized case for matrix
#
# Write `CartesianIndex(r, c)` for every element of `B` into `I`, walking the
# columns in the outer loop and the rows in the inner loop (column-major
# order), and return `I`.
function allindices!(I, B::BitMatrix)
    k = 1
    for c in axes(B, 2), r in axes(B, 1)
        I[k] = CartesianIndex(r, c)
        k += 1
    end
    return I
end
1536+
1537+
# Normalize a multi-dimensional index whose first component `i1` may exceed the
# first extent `size[1]`: repeatedly subtract the extent from `i1` and carry one
# into the next component, then do the same for the remaining components.
# Returns `(i1, irest)` with the carries applied.

# No higher dimensions left: nothing to carry into, return the input unchanged.
@inline _overflowind(i1, irest::Tuple{}, size) = (i1, irest)

@inline function _overflowind(i1, irest, size)
    extent = size[1]
    carried = irest[1]
    while i1 > extent
        i1 -= extent
        carried += 1
    end
    # The carried component may itself overflow its own extent; recurse on the tail.
    carried, higher = _overflowind(carried, tail(irest), tail(size))
    return (i1, (carried, higher...))
end
1547+
1548+
# Build the index value stored by `findall`: with no trailing components the
# first component is returned as-is, otherwise all components are combined
# into a `CartesianIndex`.
@inline function _toind(i1, irest::Tuple{})
    return i1
end

@inline function _toind(i1, irest)
    return CartesianIndex(i1, irest...)
end
1550+
15131551
# findall(B::BitArray): return a vector holding the index of every `true`
# element of `B`; the element type is `eltype(keys(B))`.
#
# NOTE: this span is a rendered diff hunk. A code line preceded by an `NNNN-`
# gutter line belongs to the removed implementation, one preceded by `NNNN+`
# to its replacement, and one preceded by two fused line numbers is unchanged.
function findall(B::BitArray)
1514-
l = length(B)
15151552
# number of set bits; this is the length of the result
nnzB = count(B)
1516-
ind = first(keys(B))
1517-
I = Vector{typeof(ind)}(undef, nnzB)
1553+
I = Vector{eltype(keys(B))}(undef, nnzB)
15181554
# no set bits: return the empty result
nnzB == 0 && return I
1555+
# every bit is set: fill I with all indices of B via the allindices! helper
nnzB == length(B) && (allindices!(I, B); return I)
15191556
Bc = B.chunks
1520-
Icount = 1
1521-
for i = 1:length(Bc)-1
1522-
u = UInt64(1)
1523-
c = Bc[i]
1524-
for j = 1:64
1525-
if c & u != 0
1526-
I[Icount] = ind
1527-
Icount += 1
1528-
end
1529-
ind = nextind(B, ind)
1530-
u <<= 1
1531-
end
1532-
end
1533-
u = UInt64(1)
1534-
c = Bc[end]
1535-
for j = 0:_mod64(l-1)
1536-
if c & u != 0
1537-
I[Icount] = ind
1538-
Icount += 1
1557+
Bs = size(B)
1558+
# Bi: position of the current chunk in Bc; i: next write position in I;
# i1: first-dimension index component, advanced by 64 for each chunk skipped
Bi = i1 = i = 1
1559+
# remaining index components (dimensions 2..N), all starting at 1
irest = ntuple(one, ndims(B) - 1)
1560+
c = Bc[1]
1561+
@inbounds while true
1562+
# skip chunks with no (remaining) set bits; finished once the last chunk is used up
while c == 0
1563+
Bi == length(Bc) && return I
1564+
i1 += 64
1565+
Bi += 1
1566+
c = Bc[Bi]
15391567
end
1540-
ind = nextind(B, ind)
1541-
u <<= 1
1568+
1569+
# tz: number of zero bits below the lowest set bit of c
tz = trailing_zeros(c)
1570+
# NOTE(review): _blsr is defined outside this view; presumably it clears the
# lowest set bit of c so the next iteration finds the following one — confirm
c = _blsr(c)
1571+
1572+
# add the bit offset to i1 and carry any overflow past Bs[1] into irest
i1, irest = _overflowind(i1 + tz, irest, Bs)
1573+
I[i] = _toind(i1, irest)
1574+
i += 1
1575+
# subtract tz again so the next bit's offset is added to the same chunk base
i1 -= tz
15421576
end
1543-
return I
15441577
end
15451578

15461579
# For performance

test/bitarray.jl

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1159,9 +1159,30 @@ timesofar("datamove")
11591159
@test findnextnot((.~(b1 >> i)) .⊻ submask, j) == i+1
11601160
end
11611161

1162+
# Do a few more thorough tests for findall
11621163
b1 = bitrand(n1, n2)
11631164
@check_bit_operation findall(b1) Vector{CartesianIndex{2}}
11641165
@check_bit_operation findall(!iszero, b1) Vector{CartesianIndex{2}}
1166+
1167+
# tall-and-skinny (test index overflow logic in findall)
1168+
@check_bit_operation findall(bitrand(1, 1, 1, 250)) Vector{CartesianIndex{4}}
1169+
1170+
# empty dimensions
1171+
@check_bit_operation findall(bitrand(0, 0, 10)) Vector{CartesianIndex{3}}
1172+
1173+
# sparse (test empty 64-bit chunks in findall)
1174+
b1 = falses(8, 8, 8)
1175+
b1[3,3,3] = b1[6,6,6] = true
1176+
@check_bit_operation findall(b1) Vector{CartesianIndex{3}}
1177+
1178+
# BitArrays of various dimensions
1179+
for dims = 0:8
1180+
t = Tuple(fill(2, dims))
1181+
ret_type = Vector{dims == 1 ? Int : CartesianIndex{dims}}
1182+
@check_bit_operation findall(trues(t)) ret_type
1183+
@check_bit_operation findall(falses(t)) ret_type
1184+
@check_bit_operation findall(bitrand(t)) ret_type
1185+
end
11651186
end
11661187

11671188
timesofar("find")

0 commit comments

Comments
 (0)