Skip to content

Commit 2cfa6a5

Browse files
tkluck, Sacha0
authored and committed
Add fast implementation of findnext and findprev for sparse vectors and matrices.
* Sparse vector/matrix: add fast implementation of findnext and findprev. Before this commit, findnext() just used the default implementation, which loops over every element. When findnext is called without a function filter as first argument, we *know* that the semantics are to find elements x satisfying x != 0, so for sparse matrices/vectors we only need to loop over the stored elements. Some care must be taken with stored zero values; that is the reason for the indirection between _sparse_findnext (which only finds the next stored element) and the actual findnext (which does the actual non-zero checks). * Optimized findnext() for sparse: update now that the predicate needs to be explicit. Since we now need explicit predicates [1], this optimization only works if we know that the predicate is a function that is false for zero values. As suggested in that pull request, we could find out by calling `f(zero(eltype(array)))` and hoping that `f` is pure, but I prefer to be a bit more conservative and apply this optimization only in the case where we *know* that `f` is equal to `!iszero`. For clarity, this commit also renames the helper method _sparse_findnext() to _sparse_findnextnz(): now that the predicate-less version no longer exists, the `nz` part is no longer implicit either.
1 parent 08620e5 commit 2cfa6a5

File tree

5 files changed

+109
-3
lines changed

5 files changed

+109
-3
lines changed

base/sparse/abstractsparse.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,24 @@ function Base.reinterpret(::Type, A::AbstractSparseArray)
4242
Try reinterpreting the value itself instead.
4343
""")
4444
end
45+
46+
# The following two methods should be overloaded by concrete types to avoid
47+
# allocating the I = find(...)
48+
# Generic fallback: smallest index >= `i` at which `v` is nonzero, or
# `zero(indtype(v))` when there is none. Materializes every nonzero index with
# `find`, which is why concrete types are expected to provide their own method.
# The binary search relies on `find` returning indices in ascending order.
function _sparse_findnextnz(v::AbstractSparseArray, i::Integer)
    nzidx = find(!iszero, v)
    pos = searchsortedfirst(nzidx, i)
    pos > length(nzidx) && return zero(indtype(v))
    return nzidx[pos]
end
49+
# Generic fallback: largest index <= `i` at which `v` is nonzero, or
# `zero(indtype(v))` when there is none. Materializes every nonzero index with
# `find`, which is why concrete types are expected to provide their own method.
# The binary search relies on `find` returning indices in ascending order.
function _sparse_findprevnz(v::AbstractSparseArray, i::Integer)
    nzidx = find(!iszero, v)
    pos = searchsortedlast(nzidx, i)
    # `searchsortedlast` yields 0 when every nonzero index is greater than `i`.
    iszero(pos) && return zero(indtype(v))
    return nzidx[pos]
end
50+
51+
# Fast path for `findnext(!iszero, v, i)` on sparse arrays: only the entries
# reported by `_sparse_findnextnz` are visited. Each candidate is re-checked
# with `f`, so an explicitly stored zero is skipped rather than returned.
# Returns the zero index produced by `_sparse_findnextnz` when nothing matches.
function findnext(f::typeof(!iszero), v::AbstractSparseArray, i::Integer)
    idx = _sparse_findnextnz(v, i)
    while true
        # Stop on "not found" (zero index) or on the first truly nonzero value.
        (iszero(idx) || f(v[idx])) && return idx
        idx = _sparse_findnextnz(v, idx + 1)
    end
end
58+
59+
# Fast path for `findprev(!iszero, v, i)` on sparse arrays: only the entries
# reported by `_sparse_findprevnz` are visited. Each candidate is re-checked
# with `f`, so an explicitly stored zero is skipped rather than returned.
# Returns the zero index produced by `_sparse_findprevnz` when nothing matches.
function findprev(f::typeof(!iszero), v::AbstractSparseArray, i::Integer)
    idx = _sparse_findprevnz(v, i)
    while true
        # Stop on "not found" (zero index) or on the first truly nonzero value.
        (iszero(idx) || f(v[idx])) && return idx
        idx = _sparse_findprevnz(v, idx - 1)
    end
end

base/sparse/sparse.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ import Base.LinAlg: mul!, ldiv!, rdiv!
1515
import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
1616
atan, atand, atanh, broadcast!, chol, conj!, cos, cosc, cosd, cosh, cospi, cot,
1717
cotd, coth, count, csc, cscd, csch, adjoint!, diag, diff, done, dot, eig,
18-
exp10, exp2, findn, floor, hash, indmin, inv, issymmetric, istril, istriu,
19-
log10, log2, lu, next, sec, secd, sech, show, sin,
20-
sinc, sind, sinh, sinpi, squeeze, start, sum, summary, tan,
18+
exp10, exp2, findn, findprev, findnext, floor, hash, indmin, inv,
19+
issymmetric, istril, istriu, log10, log2, lu, next, sec, secd, sech, show,
20+
sin, sinc, sind, sinh, sinpi, squeeze, start, sum, summary, tan,
2121
tand, tanh, trace, transpose!, tril!, triu!, trunc, vecnorm, abs, abs2,
2222
broadcast, ceil, complex, cond, conj, convert, copy, copyto!, adjoint, diagm,
2323
exp, expm1, factorize, find, findmax, findmin, findnz, float, getindex,

base/sparse/sparsematrix.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,6 +1315,42 @@ function findnz(S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
13151315
return (I, J, V)
13161316
end
13171317

1318+
# Linear index of the first stored entry of `m` at or after linear index `i`
# (explicitly stored zeros count), or `zero(indtype(m))` if there is none.
function _sparse_findnextnz(m::SparseMatrixCSC, i::Integer)
    i > length(m) && return zero(indtype(m))

    row, col = Tuple(CartesianIndices(m)[i])
    colstart = m.colptr[col]
    colstop = m.colptr[col+1]

    # Look inside column `col` (positions colstart:colstop-1 of `rowval`) for
    # the first stored row that is >= `row`.
    k = searchsortedfirst(m.rowval, row, colstart, colstop - 1, Base.Order.Forward)
    if colstart <= k < colstop
        return LinearIndices(m)[m.rowval[k], col]
    end

    # Nothing left in this column. The first `colptr` entry exceeding `colstop`
    # sits one past the next column that holds a stored entry.
    ptr = findnext(c -> (c > colstop), m.colptr, col + 1)
    iszero(ptr) && return zero(indtype(m))

    nzcol = ptr - 1
    firstpos = m.colptr[nzcol]
    return LinearIndices(m)[m.rowval[firstpos], nzcol]
end
1335+
1336+
# Linear index of the last stored entry of `m` at or before linear index `i`
# (explicitly stored zeros count), or `zero(indtype(m))` if there is none.
#
# Any `i < 1` is reported as "not found". Guarding only on `i == 0` let a
# negative start index reach `CartesianIndices(m)[i]` and throw a `BoundsError`,
# whereas the `SparseVector` method returns zero for the same input.
function _sparse_findprevnz(m::SparseMatrixCSC, i::Integer)
    if i < 1
        return zero(indtype(m))
    end
    row, col = Tuple(CartesianIndices(m)[i])
    lo, hi = m.colptr[col], m.colptr[col+1]
    # Look inside column `col` (positions lo:hi-1 of `rowval`) for the last
    # stored row that is <= `row`.
    n = searchsortedlast(m.rowval, row, lo, hi-1, Base.Order.Forward)
    if lo <= n <= hi-1
        return LinearIndices(m)[m.rowval[n], col]
    end
    # Nothing at or before `row` in this column. The last `colptr` entry below
    # `lo` marks the closest earlier column that holds a stored entry.
    prevcol = findprev(c->(c<lo), m.colptr, col-1)
    if iszero(prevcol)
        return zero(indtype(m))
    end
    # The last stored entry of `prevcol` sits just before the next column's start.
    prevhi = m.colptr[prevcol+1]
    return LinearIndices(m)[m.rowval[prevhi-1], prevcol]
end
1353+
13181354
import Base.Random.GLOBAL_RNG
13191355
function sprand_IJ(r::AbstractRNG, m::Integer, n::Integer, density::AbstractFloat)
13201356
((m < 0) || (n < 0)) && throw(ArgumentError("invalid Array dimensions"))

base/sparse/sparsevector.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,24 @@ function findnz(x::SparseVector{Tv,Ti}) where {Tv,Ti}
735735
return (I, V)
736736
end
737737

738+
# Index of the first stored entry of `v` at or after `i` (explicitly stored
# zeros count), or `zero(indtype(v))` if there is none. Searches `v.nzind`
# directly instead of allocating an index list.
function _sparse_findnextnz(v::SparseVector, i::Integer)
    stored = v.nzind
    pos = searchsortedfirst(stored, i)
    return pos <= length(stored) ? stored[pos] : zero(indtype(v))
end
746+
747+
# Index of the last stored entry of `v` at or before `i` (explicitly stored
# zeros count), or `zero(indtype(v))` if there is none. Searches `v.nzind`
# directly instead of allocating an index list.
function _sparse_findprevnz(v::SparseVector, i::Integer)
    stored = v.nzind
    pos = searchsortedlast(stored, i)
    # `searchsortedlast` yields 0 when every stored index is greater than `i`.
    return iszero(pos) ? zero(indtype(v)) : stored[pos]
end
755+
738756
### Generic functions operating on AbstractSparseVector
739757

740758
### getindex

test/sparse/sparse.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2169,6 +2169,37 @@ end
21692169
@test count(SparseMatrixCSC(2, 2, Int[1, 2, 3], Int[1, 2], Bool[true, true, true])) == 2
21702170
end
21712171

2172+
@testset "sparse findprev/findnext operations" begin
    # Dense vector and its sparse counterpart.
    x = [0,0,0,0,1,0,1,0,1,1,0]
    x_sp = sparse(x)

    # Dense matrix (including an all-zero column) and its sparse counterpart.
    y = [0 0 0 0 0;
         1 0 1 0 0;
         1 0 0 0 1;
         0 0 1 0 0;
         1 0 1 1 0]
    y_sp = sparse(y)

    # Sparse vector holding an explicitly stored zero (index 8), which the
    # sparse methods must skip just like the dense reference does.
    z_sp = sparsevec(Dict(1=>1, 5=>1, 8=>0, 10=>1))
    z = collect(z_sp)

    # From every start index, the sparse result must match the dense one.
    for (dense, sp) in ((x, x_sp), (y, y_sp), (z, z_sp))
        for i in 1:length(dense)
            @test findnext(!iszero, dense, i) == findnext(!iszero, sp, i)
            @test findprev(!iszero, dense, i) == findprev(!iszero, sp, i)
        end
    end
end
2202+
21722203
# #20711
21732204
@testset "vec returns a view" begin
21742205
local A = sparse(Matrix(1.0I, 3, 3))

0 commit comments

Comments
 (0)