diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 0816dc898ee80..c992259a60e13 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -3099,5 +3099,16 @@ function _keepat!(a::AbstractVector, m::AbstractVector{Bool}) end end deleteat!(a, j:lastindex(a)) +end + +## 1-d circshift ## +function circshift!(a::AbstractVector, shift::Integer) + n = length(a) + n == 0 && return + shift = mod(shift, n) + shift == 0 && return + reverse!(a, 1, shift) + reverse!(a, shift+1, length(a)) + reverse!(a) return a end diff --git a/base/combinatorics.jl b/base/combinatorics.jl index daa534e068af6..2dd69fbce4c42 100644 --- a/base/combinatorics.jl +++ b/base/combinatorics.jl @@ -103,6 +103,18 @@ function swapcols!(a::AbstractMatrix, i, j) @inbounds a[k,i],a[k,j] = a[k,j],a[k,i] end end + +# swap rows i and j of a, in-place +function swaprows!(a::AbstractMatrix, i, j) + i == j && return + rows = axes(a,1) + @boundscheck i in rows || throw(BoundsError(a, (:,i))) + @boundscheck j in rows || throw(BoundsError(a, (:,j))) + for k in axes(a,2) + @inbounds a[i,k],a[j,k] = a[j,k],a[i,k] + end +end + # like permute!! applied to each row of a, in-place in a (overwriting p). function permutecols!!(a::AbstractMatrix, p::AbstractVector{<:Integer}) require_one_based_indexing(a, p) diff --git a/stdlib/SparseArrays/src/sparsematrix.jl b/stdlib/SparseArrays/src/sparsematrix.jl index da08e8a72981c..0ef0d06f854ab 100644 --- a/stdlib/SparseArrays/src/sparsematrix.jl +++ b/stdlib/SparseArrays/src/sparsematrix.jl @@ -2189,7 +2189,7 @@ end getindex(A::AbstractSparseMatrixCSC, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2]) function getindex(A::AbstractSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T - if !(1 <= i0 <= size(A, 1) && 1 <= i1 <= size(A, 2)); throw(BoundsError()); end + @boundscheck checkbounds(A, i0, i1) r1 = Int(getcolptr(A)[i1]) r2 = Int(getcolptr(A)[i1+1]-1) (r1 > r2) && return zero(T) @@ -3840,3 +3840,91 @@ end circshift!(O::AbstractSparseMatrixCSC, X::AbstractSparseMatrixCSC, (r,)::Base.DimsInteger{1}) = circshift!(O, X, (r,0)) circshift!(O::AbstractSparseMatrixCSC, X::AbstractSparseMatrixCSC, r::Real) = circshift!(O, X, (Integer(r),0)) + +## swaprows! / swapcols! +macro swap(a, b) + esc(:(($a, $b) = ($b, $a))) +end + +function Base.swapcols!(A::AbstractSparseMatrixCSC, i, j) + i == j && return + + # For simplicitly, let i denote the smaller of the two columns + j < i && @swap(i, j) + + colptr = getcolptr(A) + irow = colptr[i]:(colptr[i+1]-1) + jrow = colptr[j]:(colptr[j+1]-1) + + function rangeexchange!(arr, irow, jrow) + if length(irow) == length(jrow) + for (a, b) in zip(irow, jrow) + @inbounds @swap(arr[i], arr[j]) + end + return + end + # This is similar to the triple-reverse tricks for + # circshift!, except that we have three ranges here, + # so it ends up being 4 reverse calls (but still + # 2 overall reversals for the memory range). Like + # circshift!, there's also a cycle chasing algorithm + # with optimal memory complexity, but the performance + # tradeoffs against this implementation are non-trivial, + # so let's just do this simple thing for now. + # See https://github.com/JuliaLang/julia/pull/42676 for + # discussion of circshift!-like algorithms. + reverse!(@view arr[irow]) + reverse!(@view arr[jrow]) + reverse!(@view arr[(last(irow)+1):(first(jrow)-1)]) + reverse!(@view arr[first(irow):last(jrow)]) + end + rangeexchange!(rowvals(A), irow, jrow) + rangeexchange!(nonzeros(A), irow, jrow) + + if length(irow) != length(jrow) + @inbounds colptr[i+1:j] .+= length(jrow) - length(irow) + end + return nothing +end + +function Base.swaprows!(A::AbstractSparseMatrixCSC, i, j) + # For simplicitly, let i denote the smaller of the two rows + j < i && @swap(i, j) + + rows = rowvals(A) + vals = nonzeros(A) + for col = 1:size(A, 2) + rr = nzrange(A, col) + iidx = searchsortedfirst(@view(rows[rr]), i) + has_i = iidx <= length(rr) && rows[rr[iidx]] == i + + jrange = has_i ? (iidx:last(rr)) : rr + jidx = searchsortedlast(@view(rows[jrange]), j) + has_j = jidx != 0 && rows[jrange[jidx]] == j + + if !has_j && !has_i + # Has neither row - nothing to do + continue + elseif has_i && has_j + # This column had both i and j rows - swap them + @swap(vals[rr[iidx]], vals[jrange[jidx]]) + elseif has_i + # Update the rowval and then rotate both nonzeros + # and the remaining rowvals into the correct place + rows[rr[iidx]] = j + jidx == 0 && continue + rotate_range = rr[iidx]:jrange[jidx] + circshift!(@view(vals[rotate_range]), -1) + circshift!(@view(rows[rotate_range]), -1) + else + # Same as i, but in the opposite direction + @assert has_j + rows[jrange[jidx]] = i + iidx > length(rr) && continue + rotate_range = rr[iidx]:jrange[jidx] + circshift!(@view(vals[rotate_range]), 1) + circshift!(@view(rows[rotate_range]), 1) + end + end + return nothing +end diff --git a/stdlib/SparseArrays/src/sparsevector.jl b/stdlib/SparseArrays/src/sparsevector.jl index 55ad738a7eb77..e2eaf7cdfe143 100644 --- a/stdlib/SparseArrays/src/sparsevector.jl +++ b/stdlib/SparseArrays/src/sparsevector.jl @@ -2085,18 +2085,6 @@ function fill!(A::Union{SparseVector, AbstractSparseMatrixCSC}, x) return A end - - -# in-place swaps (dense) blocks start:split and split+1:fin in col -function _swap!(col::AbstractVector, start::Integer, fin::Integer, split::Integer) - split == fin && return - reverse!(col, start, split) - reverse!(col, split + 1, fin) - reverse!(col, start, fin) - return -end - - # in-place shifts a sparse subvector by r. Used also by sparsematrix.jl function subvector_shifter!(R::AbstractVector, V::AbstractVector, start::Integer, fin::Integer, m::Integer, r::Integer) split = fin @@ -2110,16 +2098,14 @@ function subvector_shifter!(R::AbstractVector, V::AbstractVector, start::Integer end end # ...but rowval should be sorted within columns - _swap!(R, start, fin, split) - _swap!(V, start, fin, split) + circshift!(@view(R[start:fin]), split-start+1) + circshift!(@view(V[start:fin]), split-start+1) end - function circshift!(O::SparseVector, X::SparseVector, (r,)::Base.DimsInteger{1}) copy!(O, X) subvector_shifter!(nonzeroinds(O), nonzeros(O), 1, length(nonzeroinds(O)), length(O), mod(r, length(X))) return O end - circshift!(O::SparseVector, X::SparseVector, r::Real,) = circshift!(O, X, (Integer(r),)) diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl index ff955e967b433..df9afba0427fb 100644 --- a/stdlib/SparseArrays/test/sparse.jl +++ b/stdlib/SparseArrays/test/sparse.jl @@ -3294,4 +3294,26 @@ end @test eval(Meta.parse(repr(m))) == m end +using Base: swaprows!, swapcols! +@testset "swaprows!, swapcols!" begin + S = sparse( + [ 0 0 0 0 0 0 + 0 -1 1 1 0 0 + 0 0 0 1 1 0 + 0 0 1 1 1 -1]) + + for (f!, i, j) in + ((swaprows!, 1, 2), # Test swapping rows where one row is fully sparse + (swaprows!, 2, 3), # Test swapping rows of unequal length + (swaprows!, 2, 4), # Test swapping non-adjacent rows + (swapcols!, 1, 2), # Test swapping columns where one column is fully sparse + (swapcols!, 2, 3), # Test swapping coulms of unequal length + (swapcols!, 2, 4)) # Test swapping non-adjacent columns + Scopy = copy(S) + Sdense = Array(S) + f!(Scopy, i, j); f!(Sdense, i, j) + @test Scopy == Sdense + end +end + end # module