From 5a16dcb3043d232cbd9ad565e4a84115a48c0c0b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 9 Feb 2024 15:29:09 -0500 Subject: [PATCH 01/17] [BlockSparseArrays] with mismatched blocking --- .../src/BlockSparseArrays.jl | 1 + .../src/abstractblocksparsearray/map.jl | 4 +- .../map_mismatched_blocking.jl | 79 +++++++++++++++++++ 3 files changed, 81 insertions(+), 3 deletions(-) create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index 58aad7e425..91af0516fc 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -12,6 +12,7 @@ include("abstractblocksparsearray/sparsearrayinterface.jl") include("abstractblocksparsearray/linearalgebra.jl") include("abstractblocksparsearray/broadcast.jl") include("abstractblocksparsearray/map.jl") +include("abstractblocksparsearray/map_mismatched_blocking.jl") include("blocksparsearray/defaults.jl") include("blocksparsearray/blocksparsearray.jl") include("BlockArraysExtensions/BlockArraysExtensions.jl") diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl index 4033875a44..78a07626b8 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl @@ -19,9 +19,7 @@ function SparseArrayInterface.sparse_map!( # map based on the blocks. map!(f, blocks(a_dest), blocks.(a_srcs)...) else - # Else, loop over all sparse elements naively. - # TODO: Make sure this is optimized, taking advantage of sparsity. - sparse_map!(SparseArrayStyle(Val(ndims(a_dest))), f, a_dest, a_srcs...) + map_mismatched_blocking!(f, a_dest, a_srcs...) end return a_dest end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl new file mode 100644 index 0000000000..b66c03b87f --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl @@ -0,0 +1,79 @@ +using BlockArrays: Block, BlockedUnitRange, block, blockindex, blocks, blocksize, findblock, findblockindex +using ..SparseArrayInterface: stored_indices + +function blockdiagonal(f!, elt::Type, axes::Tuple) + a = BlockSparseArray{elt}(axes) + for i in 1:minimum(blocksize(a)) + b = Block(ntuple(Returns(i), ndims(a))) + a[b] = f!(a[b]) + end + return a +end + +function cartesianindices(axes::Tuple, b::Block) + return CartesianIndices(ntuple(dim -> axes[dim][Tuple(b)[dim]], length(axes))) +end + +function blockindexrange(axis::BlockedUnitRange, r::UnitRange) + bi1 = findblockindex(axis, first(r)) + bi2 = findblockindex(axis, last(r)) + b = block(bi1) + # Range must fall within a single block. + @assert b == block(bi2) + i1 = blockindex(bi1) + i2 = blockindex(bi2) + return b[i1:i2] +end + +function blockindexrange(axes::Tuple, I::CartesianIndices) + brs = blockindexrange.(axes, I.indices) + b = Block(block.(brs)) + rs = map(br -> only(br.indices), brs) + return b[rs...] +end + +function blockindexrange(a::AbstractArray, I::CartesianIndices) + return blockindexrange(axes(a), I) +end + +function cartesianindices(a::AbstractArray, b::Block) + return cartesianindices(axes(a), b) +end + +# Output which blocks of `axis` are contained within the unit range `range`. +# The start and end points must match. +function findblocks(axis::AbstractUnitRange, range::AbstractUnitRange) + # TODO: Add a test that the start and end points of the ranges match. + return findblock(axis, first(range)):findblock(axis, last(range)) +end + +function block_stored_indices(a::AbstractArray) + return Block.(Tuple.(stored_indices(blocks(a)))) +end + +############################################################## +using BlockArrays: BlockArrays, BlockRange + +function map_mismatched_blocking!(f, a_dest::AbstractArray, a_src::AbstractArray) + # Create a common set of axes with a blocking that includes the + # blocking of `a_dest` and `a_src`. + matching_axes = BlockArrays.combine_blockaxes.(axes(a_dest), axes(a_src)) + for b in block_stored_indices(a_src) + # Get the subblocks of the matching axes + # TODO: `union` all `subblocks` of all `a_src` and `a_dest`. + subblocks = BlockRange(ntuple(ndims(a_dest)) do dim + findblocks(matching_axes[dim], axes(a_src, dim)[Tuple(b)[dim]]) + end) + for subblock in subblocks + I = cartesianindices(matching_axes, subblock) + I_dest = blockindexrange(a_dest, I) + I_src = blockindexrange(a_src, I) + + # TODO: Broken, need to fix. + # map!(f, view(a_dest, I_dest), view(a_src, I_src)) + + map!(f, view(a_dest, I_dest), a_src[I_src]) + end + end + return a_dest +end From f108102799cccfb2326227c415454d2622ca7f5d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 9 Feb 2024 15:47:57 -0500 Subject: [PATCH 02/17] Format --- .../map_mismatched_blocking.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl index b66c03b87f..6ec5de3c33 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl @@ -1,4 +1,5 @@ -using BlockArrays: Block, BlockedUnitRange, block, blockindex, blocks, blocksize, findblock, findblockindex +using BlockArrays: + Block, BlockedUnitRange, block, blockindex, blocks, blocksize, findblock, findblockindex using ..SparseArrayInterface: stored_indices function blockdiagonal(f!, elt::Type, axes::Tuple) @@ -61,9 +62,11 @@ function map_mismatched_blocking!(f, a_dest::AbstractArray, a_src::AbstractArray for b in block_stored_indices(a_src) # Get the subblocks of the matching axes # TODO: `union` all `subblocks` of all `a_src` and `a_dest`. - subblocks = BlockRange(ntuple(ndims(a_dest)) do dim - findblocks(matching_axes[dim], axes(a_src, dim)[Tuple(b)[dim]]) - end) + subblocks = BlockRange( + ntuple(ndims(a_dest)) do dim + findblocks(matching_axes[dim], axes(a_src, dim)[Tuple(b)[dim]]) + end, + ) for subblock in subblocks I = cartesianindices(matching_axes, subblock) I_dest = blockindexrange(a_dest, I) From a2aa734e3d6df304482b576d3ae720777359e343 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 18 Feb 2024 14:19:45 -0500 Subject: [PATCH 03/17] Progress towards slicing --- .../src/BlockSparseArrays.jl | 1 + .../abstractblocksparsearray.jl | 29 +++--- .../abstractblocksparsearray/arraylayouts.jl | 18 ++-- .../map_mismatched_blocking.jl | 18 +++- .../wrappedabstractblocksparsearray.jl | 25 ++++- .../src/blocksparsearray/blocksparsearray.jl | 7 +- .../blocksparsearrayinterface/arraylayouts.jl | 16 ++++ .../blocksparsearrayinterface.jl | 91 +++++++++++-------- 8 files changed, 130 insertions(+), 75 deletions(-) create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index 91af0516fc..c88c69e2ff 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -3,6 +3,7 @@ include("blocksparsearrayinterface/blocksparsearrayinterface.jl") include("blocksparsearrayinterface/linearalgebra.jl") include("blocksparsearrayinterface/blockzero.jl") include("blocksparsearrayinterface/broadcast.jl") +include("blocksparsearrayinterface/arraylayouts.jl") include("abstractblocksparsearray/abstractblocksparsearray.jl") include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl") include("abstractblocksparsearray/abstractblocksparsematrix.jl") diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl index 275bf7312e..5c60395b4b 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl @@ -19,20 +19,21 @@ function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) w return blocksparse_getindex(a, I...) end -# Fix ambiguity error with `BlockArrays`. -function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Block{N}) where {N} - return ArrayLayouts.layout_getindex(a, I) -end - -# Fix ambiguity error with `BlockArrays`. -function Base.getindex(a::AbstractBlockSparseArray{<:Any,1}, I::Block{1}) - return ArrayLayouts.layout_getindex(a, I) -end - -# Fix ambiguity error with `BlockArrays`. -function Base.getindex(a::AbstractBlockSparseArray, I::Vararg{AbstractVector}) - return blocksparse_getindex(a, I...) -end +## # Fix ambiguity error with `BlockArrays`. +## function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Block{N}) where {N} +## return ArrayLayouts.layout_getindex(a, I) +## end +## +## # Fix ambiguity error with `BlockArrays`. +## function Base.getindex(a::AbstractBlockSparseArray{<:Any,1}, I::Block{1}) +## return ArrayLayouts.layout_getindex(a, I) +## end +## +## # Fix ambiguity error with `BlockArrays`. +## function Base.getindex(a::AbstractBlockSparseArray, I::Vararg{AbstractVector}) +## ## return blocksparse_getindex(a, I...) +## return ArrayLayouts.layout_getindex(a, I...) +## end # Specialized in order to fix ambiguity error with `BlockArrays`. function Base.setindex!( diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl index 7e768bc73a..d8e79ba743 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl @@ -1,10 +1,9 @@ -using ArrayLayouts: ArrayLayouts, MemoryLayout, MatMulMatAdd, MulAdd +using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd using BlockArrays: BlockLayout using ..SparseArrayInterface: SparseLayout using LinearAlgebra: mul! -# TODO: Generalize to `BlockSparseArrayLike`. -function ArrayLayouts.MemoryLayout(arraytype::Type{<:AbstractBlockSparseArray}) +function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike}) outer_layout = typeof(MemoryLayout(blockstype(arraytype))) inner_layout = typeof(MemoryLayout(blocktype(arraytype))) return BlockLayout{outer_layout,inner_layout}() @@ -16,14 +15,9 @@ function Base.similar( return similar(BlockSparseArray{elt}, axes) end -function ArrayLayouts.materialize!( - m::MatMulMatAdd{ - <:BlockLayout{<:SparseLayout}, - <:BlockLayout{<:SparseLayout}, - <:BlockLayout{<:SparseLayout}, - }, -) - α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C - mul!(a_dest, a1, a2, α, β) +# Materialize a SubArray view. +function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes) + a_dest = BlockSparseArray{eltype(a)}(axes) + a_dest .= a return a_dest end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl index 6ec5de3c33..7612110519 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl @@ -15,6 +15,7 @@ function cartesianindices(axes::Tuple, b::Block) return CartesianIndices(ntuple(dim -> axes[dim][Tuple(b)[dim]], length(axes))) end +# Get the range within a block. function blockindexrange(axis::BlockedUnitRange, r::UnitRange) bi1 = findblockindex(axis, first(r)) bi2 = findblockindex(axis, last(r)) @@ -26,13 +27,27 @@ function blockindexrange(axis::BlockedUnitRange, r::UnitRange) return b[i1:i2] end -function blockindexrange(axes::Tuple, I::CartesianIndices) +# Fallback for non-blocked ranges. +function blockindexrange(axis::AbstractUnitRange, r::UnitRange) + return r +end + +function blockindexrange( + axes::Tuple{Vararg{BlockedUnitRange,N}}, I::CartesianIndices{N} +) where {N} brs = blockindexrange.(axes, I.indices) b = Block(block.(brs)) rs = map(br -> only(br.indices), brs) return b[rs...] end +# Fallback for non-blocked ranges. +function blockindexrange( + axes::Tuple{Vararg{AbstractUnitRange,N}}, I::CartesianIndices{N} +) where {N} + return I +end + function blockindexrange(a::AbstractArray, I::CartesianIndices) return blockindexrange(axes(a), I) end @@ -59,6 +74,7 @@ function map_mismatched_blocking!(f, a_dest::AbstractArray, a_src::AbstractArray # Create a common set of axes with a blocking that includes the # blocking of `a_dest` and `a_src`. matching_axes = BlockArrays.combine_blockaxes.(axes(a_dest), axes(a_src)) + # TODO: Also include `block_stored_indices(a_dest)`! for b in block_stored_indices(a_src) # Get the subblocks of the matching axes # TODO: `union` all `subblocks` of all `a_src` and `a_dest`. diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 5a5a84f6fc..ef7f7edeea 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -1,4 +1,13 @@ -using Adapt: WrappedArray +using BlockArrays: BlockedUnitRange, blockedrange +using SplitApplyCombine: groupcount +# TODO: Write a specialized version for `indices::AbstractUnitRange`. +# TODO: Write a generic version for blocked unit ranges (like `GradedAxes`). +function sub_unitrange(a::BlockedUnitRange, indices) + indices = sort(indices) + return blockedrange(collect(groupcount(i -> findblock(a, i), indices))) +end + +using Adapt: Adapt, WrappedArray const WrappedAbstractBlockSparseArray{T,N,A} = WrappedArray{ T,N,<:AbstractBlockSparseArray,<:AbstractBlockSparseArray{T,N} @@ -8,14 +17,20 @@ const BlockSparseArrayLike{T,N} = Union{ <:AbstractBlockSparseArray{T,N},<:WrappedAbstractBlockSparseArray{T,N} } +# AbstractArray interface +function Base.axes(a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray}) + return ntuple(i -> sub_unitrange(axes(parent(a), i), a.indices[i]), ndims(a)) +end + # BlockArrays `AbstractBlockArray` interface BlockArrays.blocks(a::BlockSparseArrayLike) = blocksparse_blocks(a) -blocktype(a::BlockSparseArrayLike) = eltype(blocks(a)) - -# TODO: Use `parenttype` from `Unwrap`. -blockstype(arraytype::Type{<:WrappedAbstractBlockSparseArray}) = parenttype(arraytype) +# TODO: Use `TypeParameterAccessors`. +function blockstype(arraytype::Type{<:WrappedAbstractBlockSparseArray}) + return blockstype(Adapt.parent_type(arraytype)) +end +blocktype(a::BlockSparseArrayLike) = eltype(blocks(a)) blocktype(arraytype::Type{<:BlockSparseArrayLike}) = eltype(blockstype(arraytype)) function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::BlockIndex{N}) where {N} diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl index 439a062b24..b22ee08a72 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl @@ -101,10 +101,11 @@ end # Base `AbstractArray` interface Base.axes(a::BlockSparseArray) = a.axes -# BlockArrays `AbstractBlockArray` interface -BlockArrays.blocks(a::BlockSparseArray) = a.blocks +# BlockArrays `AbstractBlockArray` interface. +# This is used by `blocks(::BlockSparseArrayLike)`. +blocksparse_blocks(a::BlockSparseArray) = a.blocks -# TODO: Use `SetParameters`. +# TODO: Use `TypeParameterAccessors`. blockstype(::Type{<:BlockSparseArray{<:Any,<:Any,<:Any,B}}) where {B} = B # Base interface diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl new file mode 100644 index 0000000000..bf4d515a34 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl @@ -0,0 +1,16 @@ +using ArrayLayouts: ArrayLayouts, MatMulMatAdd +using BlockArrays: BlockLayout +using ..SparseArrayInterface: SparseLayout +using LinearAlgebra: mul! + +function ArrayLayouts.materialize!( + m::MatMulMatAdd{ + <:BlockLayout{<:SparseLayout}, + <:BlockLayout{<:SparseLayout}, + <:BlockLayout{<:SparseLayout}, + }, +) + α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C + mul!(a_dest, a1, a2, α, β) + return a_dest +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 81362fb648..8b7ddd4c7f 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -11,48 +11,49 @@ using BlockArrays: using ..SparseArrayInterface: perm, iperm, nstored using MappedArrays: mappedarray -function blocksparse_blocks(a::AbstractArray) - return blocks(a) -end +blocksparse_blocks(a::AbstractArray) = error("Not implemented") function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N} @boundscheck checkbounds(a, I...) return a[findblockindex.(axes(a), I)...] end -# TODO: Implement as `copy(@view a[I...])`, which is then implemented -# through `ArrayLayouts.sub_materialize`. -function blocksparse_getindex( - a::AbstractArray{<:Any,N}, I::Vararg{AbstractVector{<:Block{1}},N} -) where {N} - blocks_a = blocks(a) - # Convert to cartesian indices of the underlying sparse array of blocks. - CI = map(i -> Int.(i), I) - subblocks_a = blocks_a[CI...] - subaxes = ntuple(ndims(a)) do i - return axes(a, i)[I[i]] - end - return typeof(a)(subblocks_a, subaxes) -end - -# Slice by block and merge according to the blocking structure of the indices. -function blocksparse_getindex( - a::AbstractArray{<:Any,N}, I::Vararg{AbstractBlockVector{<:Block{1}},N} -) where {N} - a_sub = a[Vector.(I)...] - # TODO: Define `blocklengths(::AbstractBlockVector)`? Maybe make a PR - # to `BlockArrays`. - blockmergers = blockedrange.(blocklengths.(only.(axes.(I)))) - # TODO: Need to implement this! - a_merged = block_merge(a_sub, blockmergers...) - return a_merged -end +## # TODO: Implement as `copy(@view a[I...])`, which is then implemented +## # through `ArrayLayouts.sub_materialize`. +## function blocksparse_getindex( +## a::AbstractArray{<:Any,N}, I::Vararg{AbstractVector{<:Block{1}},N} +## ) where {N} +## blocks_a = blocks(a) +## # Convert to cartesian indices of the underlying sparse array of blocks. +## CI = map(i -> Int.(i), I) +## subblocks_a = blocks_a[CI...] +## subaxes = ntuple(ndims(a)) do i +## return axes(a, i)[I[i]] +## end +## return typeof(a)(subblocks_a, subaxes) +## end +## +## # Slice by block and merge according to the blocking structure of the indices. +## function blocksparse_getindex( +## a::AbstractArray{<:Any,N}, I::Vararg{AbstractBlockVector{<:Block{1}},N} +## ) where {N} +## a_sub = a[Vector.(I)...] +## # TODO: Define `blocklengths(::AbstractBlockVector)`? Maybe make a PR +## # to `BlockArrays`. +## blockmergers = blockedrange.(blocklengths.(only.(axes.(I)))) +## # TODO: Need to implement this! +## a_merged = block_merge(a_sub, blockmergers...) +## return a_merged +## end +## +## # TODO: Need to implement this! +## function block_merge(a::AbstractArray{<:Any,N}, I::Vararg{BlockedUnitRange,N}) where {N} +## # Need to `block_merge` each axis. +## return a +## end # TODO: Need to implement this! -function block_merge(a::AbstractArray{<:Any,N}, I::Vararg{BlockedUnitRange,N}) where {N} - # Need to `block_merge` each axis. - return a -end +function block_merge end function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}) where {N} @boundscheck checkbounds(a, I...) @@ -72,7 +73,7 @@ function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Block{N}) wh # TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`. i = I.n @boundscheck blockcheckbounds(a, i...) - blocksparse_blocks(a)[i...] = value + blocks(a)[i...] = value return a end @@ -80,18 +81,17 @@ function blocksparse_viewblock(a::AbstractArray{<:Any,N}, I::Block{N}) where {N} # TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`. i = I.n @boundscheck blockcheckbounds(a, i...) - return blocksparse_blocks(a)[i...] + return blocks(a)[i...] end function block_nstored(a::AbstractArray) - return nstored(blocksparse_blocks(a)) + return nstored(blocks(a)) end -# Base +# BlockArrays -# PermutedDimsArray function blocksparse_blocks(a::PermutedDimsArray) - blocks_parent = blocksparse_blocks(parent(a)) + blocks_parent = blocks(parent(a)) # Lazily permute each block blocks_parent_mapped = mappedarray( Base.Fix2(PermutedDimsArray, perm(a)), @@ -100,3 +100,14 @@ function blocksparse_blocks(a::PermutedDimsArray) ) return PermutedDimsArray(blocks_parent_mapped, perm(a)) end + +function blocksparse_blocks(a::SubArray) + parent_blocks = blocks(parent(a)) + indices = sort.(a.indices) + @show parent_blocks + @show indices + + @show [findblockindex(axes(parent(a), 1), i) for i in indices[1]] + + return error("`blocksparse_blocks(::SubArray)` not implemented yet.") +end From 84a45782ffa0ad2c4695bb8701bc0cdf2c4a0bc4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 21 Feb 2024 13:50:13 -0500 Subject: [PATCH 04/17] Fix some slicing operations --- .../BlockArraysExtensions.jl | 66 +++++++++++++++++ .../map_mismatched_blocking.jl | 70 ------------------- .../blocksparsearrayinterface.jl | 39 +++++++++-- .../blocksparsearrayinterface/blockzero.jl | 4 ++ .../src/abstractsparsearray/baseinterface.jl | 9 ++- .../src/sparsearrayinterface/map.jl | 2 + 6 files changed, 112 insertions(+), 78 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index bf13b19837..43c938da78 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -38,3 +38,69 @@ end function block_reshape(a::AbstractArray, axes::Vararg{AbstractUnitRange}) return block_reshape(a, axes) end + +################################################################################## +using BlockArrays: + Block, BlockedUnitRange, block, blockindex, blocks, findblock, findblockindex +using ..SparseArrayInterface: stored_indices + +function cartesianindices(axes::Tuple, b::Block) + return CartesianIndices(ntuple(dim -> axes[dim][Tuple(b)[dim]], length(axes))) +end + +# Get the range within a block. +function blockindexrange(axis::BlockedUnitRange, r::UnitRange) + bi1 = findblockindex(axis, first(r)) + bi2 = findblockindex(axis, last(r)) + b = block(bi1) + # Range must fall within a single block. + @assert b == block(bi2) + i1 = blockindex(bi1) + i2 = blockindex(bi2) + return b[i1:i2] +end + +# Fallback for non-blocked ranges. +function blockindexrange(axis::AbstractUnitRange, r::UnitRange) + return r +end + +function blockindexrange( + axes::Tuple{Vararg{BlockedUnitRange,N}}, I::CartesianIndices{N} +) where {N} + brs = blockindexrange.(axes, I.indices) + b = Block(block.(brs)) + rs = map(br -> only(br.indices), brs) + return b[rs...] +end + +# Fallback for non-blocked ranges. +function blockindexrange( + axes::Tuple{Vararg{AbstractUnitRange,N}}, I::CartesianIndices{N} +) where {N} + return I +end + +function blockindexrange(a::AbstractArray, I::CartesianIndices) + return blockindexrange(axes(a), I) +end + +# Get the blocks the range spans across. +function blockrange(axis::BlockedUnitRange, r::UnitRange) + return findblock(axis, first(r)):findblock(axis, last(r)) +end + +function cartesianindices(a::AbstractArray, b::Block) + return cartesianindices(axes(a), b) +end + +# Output which blocks of `axis` are contained within the unit range `range`. +# The start and end points must match. +function findblocks(axis::AbstractUnitRange, range::AbstractUnitRange) + # TODO: Add a test that the start and end points of the ranges match. + return findblock(axis, first(range)):findblock(axis, last(range)) +end + +function block_stored_indices(a::AbstractArray) + return Block.(Tuple.(stored_indices(blocks(a)))) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl index 7612110519..7e7e3ade9c 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl @@ -1,73 +1,3 @@ -using BlockArrays: - Block, BlockedUnitRange, block, blockindex, blocks, blocksize, findblock, findblockindex -using ..SparseArrayInterface: stored_indices - -function blockdiagonal(f!, elt::Type, axes::Tuple) - a = BlockSparseArray{elt}(axes) - for i in 1:minimum(blocksize(a)) - b = Block(ntuple(Returns(i), ndims(a))) - a[b] = f!(a[b]) - end - return a -end - -function cartesianindices(axes::Tuple, b::Block) - return CartesianIndices(ntuple(dim -> axes[dim][Tuple(b)[dim]], length(axes))) -end - -# Get the range within a block. -function blockindexrange(axis::BlockedUnitRange, r::UnitRange) - bi1 = findblockindex(axis, first(r)) - bi2 = findblockindex(axis, last(r)) - b = block(bi1) - # Range must fall within a single block. - @assert b == block(bi2) - i1 = blockindex(bi1) - i2 = blockindex(bi2) - return b[i1:i2] -end - -# Fallback for non-blocked ranges. -function blockindexrange(axis::AbstractUnitRange, r::UnitRange) - return r -end - -function blockindexrange( - axes::Tuple{Vararg{BlockedUnitRange,N}}, I::CartesianIndices{N} -) where {N} - brs = blockindexrange.(axes, I.indices) - b = Block(block.(brs)) - rs = map(br -> only(br.indices), brs) - return b[rs...] -end - -# Fallback for non-blocked ranges. -function blockindexrange( - axes::Tuple{Vararg{AbstractUnitRange,N}}, I::CartesianIndices{N} -) where {N} - return I -end - -function blockindexrange(a::AbstractArray, I::CartesianIndices) - return blockindexrange(axes(a), I) -end - -function cartesianindices(a::AbstractArray, b::Block) - return cartesianindices(axes(a), b) -end - -# Output which blocks of `axis` are contained within the unit range `range`. -# The start and end points must match. -function findblocks(axis::AbstractUnitRange, range::AbstractUnitRange) - # TODO: Add a test that the start and end points of the ranges match. - return findblock(axis, first(range)):findblock(axis, last(range)) -end - -function block_stored_indices(a::AbstractArray) - return Block.(Tuple.(stored_indices(blocks(a)))) -end - -############################################################## using BlockArrays: BlockArrays, BlockRange function map_mismatched_blocking!(f, a_dest::AbstractArray, a_src::AbstractArray) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 8b7ddd4c7f..94ee3fe1d4 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -101,13 +101,38 @@ function blocksparse_blocks(a::PermutedDimsArray) return PermutedDimsArray(blocks_parent_mapped, perm(a)) end -function blocksparse_blocks(a::SubArray) - parent_blocks = blocks(parent(a)) - indices = sort.(a.indices) - @show parent_blocks - @show indices +function blockindices(a::AbstractArray, block::Block, indices::Tuple) + return blockindices(axes(a), block, indices) +end - @show [findblockindex(axes(parent(a), 1), i) for i in indices[1]] +function blockindices(axes::Tuple, block::Block, indices::Tuple) + return blockindices.(axes, Tuple(block), indices) +end - return error("`blocksparse_blocks(::SubArray)` not implemented yet.") +function blockindices(axis::AbstractUnitRange, block::Block, indices) + indices_within_block = intersect(indices, axis[block]) + if iszero(length(indices_within_block)) + # Falls outside of block + return 1:0 + end + return only(blockindexrange(axis, indices_within_block).indices) +end + +function blocksparse_blocks(a::SubArray) + # First slice blockwise. + blockranges = blockrange.(axes(parent(a)), a.indices) + # Then slice the blocks. + sliced_blocks = map(stored_indices(blocks(parent(a)))) do index + tuple_index = Tuple(index) + block = Block(tuple_index) + return view( + blocks(parent(a))[tuple_index...], blockindices(parent(a), block, a.indices)... + ) + end + # TODO: Use a `set_data` function, or some kind of `similar` or `zero` method? + blocks_a_sub = SparseArrayDOK( + sliced_blocks, size(blocks(parent(a))), blocks(parent(a)).zero + ) + # TODO: Avoid copying, use a view? + return blocks_a_sub[map(blockrange -> Int.(blockrange), blockranges)...] end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl index 479f78c334..b4618415fe 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl @@ -22,6 +22,10 @@ function (f::BlockZero)(a::AbstractArray, I) return f(eltype(a), I) end +function (f::BlockZero)(arraytype::Type{<:SubArray{<:Any,<:Any,P}}, I) where {P} + return f(P, I) +end + function (f::BlockZero)(arraytype::Type{<:AbstractArray}, I) # TODO: Make sure this works for sparse or block sparse blocks, immutable # blocks, diagonal blocks, etc.! diff --git a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl index 334a3f30cb..bd9dd326af 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl @@ -11,7 +11,14 @@ function Base.getindex(a::AbstractSparseArray, I...) end # Fixes ambiguity error with `ArrayLayouts`. -function Base.getindex(a::AbstractSparseArray, I1::AbstractVector, I2::AbstractVector) +function Base.getindex(a::AbstractSparseMatrix, I1::AbstractVector, I2::AbstractVector) + return SparseArrayInterface.sparse_getindex(a, I1, I2) +end + +# Fixes ambiguity error with `ArrayLayouts`. +function Base.getindex( + a::AbstractSparseMatrix, I1::AbstractUnitRange, I2::AbstractUnitRange +) return SparseArrayInterface.sparse_getindex(a, I1, I2) end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl index 993d4d8a7c..f637a5701e 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl @@ -10,6 +10,8 @@ end value(v::NotStoredValue) = v.value nstored(::NotStoredValue) = false Base.:*(x::Number, y::NotStoredValue) = false +Base.:*(x::NotStoredValue, y::Number) = false +Base.:/(x::NotStoredValue, y::Number) = false Base.:+(::NotStoredValue, ::NotStoredValue...) = false Base.:-(::NotStoredValue, ::NotStoredValue...) = false Base.:+(x::Number, ::NotStoredValue...) = x From da01bd73baa50ef5f870becea0e66d3d64132564 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 25 Feb 2024 18:05:17 -0500 Subject: [PATCH 05/17] Some progress on slicing --- .../map_mismatched_blocking.jl | 2 ++ .../blocksparsearrayinterface.jl | 29 +++++++++++++++++++ .../src/SparseArrayInterface.jl | 4 +-- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl index 7e7e3ade9c..f789ba2b28 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl @@ -1,6 +1,8 @@ using BlockArrays: BlockArrays, BlockRange function map_mismatched_blocking!(f, a_dest::AbstractArray, a_src::AbstractArray) + ## @show typeof(a_src) + # Create a common set of axes with a blocking that includes the # blocking of `a_dest` and `a_src`. matching_axes = BlockArrays.combine_blockaxes.(axes(a_dest), axes(a_src)) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 94ee3fe1d4..0d8de63757 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -101,14 +101,17 @@ function blocksparse_blocks(a::PermutedDimsArray) return PermutedDimsArray(blocks_parent_mapped, perm(a)) end +# TODO: Move to `BlockArraysExtensions`. function blockindices(a::AbstractArray, block::Block, indices::Tuple) return blockindices(axes(a), block, indices) end +# TODO: Move to `BlockArraysExtensions`. function blockindices(axes::Tuple, block::Block, indices::Tuple) return blockindices.(axes, Tuple(block), indices) end +# TODO: Move to `BlockArraysExtensions`. function blockindices(axis::AbstractUnitRange, block::Block, indices) indices_within_block = intersect(indices, axis[block]) if iszero(length(indices_within_block)) @@ -118,7 +121,33 @@ function blockindices(axis::AbstractUnitRange, block::Block, indices) return only(blockindexrange(axis, indices_within_block).indices) end +using ..SparseArrayInterface: SparseArrayInterface, AbstractSparseArray + +# Represents the array of arrays of a `SubArray` +# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`. +struct SubBlocks{T,N,Array<:SubArray{T,N}} <: AbstractSparseArray{T,N} + array::Array +end +function Base.axes(a::SubBlocks) + blockranges = blockrange.(axes(parent(a.array)), a.array.indices) + return map(blockrange -> Int.(blockrange), blockranges) +end +function SparseArrayInterface.stored_indices(a::SubBlocks) + return stored_indices(view(blocks(parent(a.array)), axes(a)...)) +end +function Base.getindex(a::SubBlocks{<:Any,N}, I::CartesianIndex{N}) where {N} + parent_blocks = view(blocks(parent(a.array)), axes(a)...) + return parent_blocks[blockindices(parent(a.array), Block(Tuple(I)), a.array.indices)...] +end +function SparseArrayInterface.sparse_storage(a::SubBlocks) + error() + # TODO: This also needs to slice the blocks! + return view(blocks(parent(a.array)), axes(a)...) +end + function blocksparse_blocks(a::SubArray) + return SubBlocks(a) + error() # First slice blockwise. blockranges = blockrange.(axes(parent(a)), a.indices) # Then slice the blocks. diff --git a/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl b/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl index 9c2df4b7a3..33647bf476 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl @@ -14,6 +14,8 @@ include("sparsearrayinterface/wrappers.jl") include("sparsearrayinterface/zero.jl") include("sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl") include("abstractsparsearray/abstractsparsearray.jl") +include("abstractsparsearray/abstractsparsematrix.jl") +include("abstractsparsearray/abstractsparsevector.jl") include("abstractsparsearray/wrappedabstractsparsearray.jl") include("abstractsparsearray/arraylayouts.jl") include("abstractsparsearray/sparsearrayinterface.jl") @@ -23,7 +25,5 @@ include("abstractsparsearray/map.jl") include("abstractsparsearray/baseinterface.jl") include("abstractsparsearray/convert.jl") include("abstractsparsearray/SparseArrayInterfaceSparseArraysExt.jl") -include("abstractsparsearray/abstractsparsematrix.jl") include("abstractsparsearray/SparseArrayInterfaceLinearAlgebraExt.jl") -include("abstractsparsearray/abstractsparsevector.jl") end From 693b9b85836e2c08ee920017269f4e01ae53e978 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 1 Mar 2024 11:49:50 -0500 Subject: [PATCH 06/17] More general block sparse map/broadcast functionality --- .../test/runtests.jl | 5 + .../BlockArraysExtensions.jl | 102 ++++++++++++++++-- .../BlockArraysSparseArrayInterfaceExt.jl | 6 ++ .../src/BlockSparseArrays.jl | 2 +- .../abstractblocksparsearray.jl | 2 +- .../src/abstractblocksparsearray/map.jl | 68 ++++++++++-- .../map_mismatched_blocking.jl | 30 ------ .../src/abstractblocksparsearray/view.jl | 14 +++ .../wrappedabstractblocksparsearray.jl | 59 ++++++++-- .../src/blocksparsearray/blocksparsearray.jl | 17 +-- .../blocksparsearrayinterface.jl | 93 +++++++++------- 11 files changed, 301 insertions(+), 97 deletions(-) delete mode 100644 NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index ae887791af..b1e1f41ed7 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -10,6 +10,11 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "BlockSparseArraysGradedAxesExt (eltype=$elt)" for elt in elts d1 = gradedrange([U1(0) => 1, U1(1) => 1]) d2 = gradedrange([U1(1) => 1, U1(0) => 1]) + + ## using BlockArrays: blockedrange + ## d1 = blockedrange([1, 1]) + ## d2 = blockedrange([2, 2]) + a = BlockSparseArray{elt}(d1, d2, d1, d2) for i in 1:minimum(blocksize(a)) b = Block(i, i, i, i) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 43c938da78..f258703db0 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -1,7 +1,52 @@ -using BlockArrays: AbstractBlockArray, AbstractBlockVector, Block, blockedrange +using BlockArrays: + BlockArrays, + AbstractBlockArray, + AbstractBlockVector, + Block, + BlockRange, + BlockedUnitRange, + BlockVector, + block, + blockedrange, + blockindex, + blocks, + findblock, + findblockindex using Dictionaries: Dictionary, Indices using ..SparseArrayInterface: stored_indices +# Outputs a `BlockUnitRange`. +function sub_unitrange(a::AbstractUnitRange, indices) + @show indices + @show typeof(indices) + return error("Not implemented") +end + +# TODO: Write a specialized version for `indices::AbstractUnitRange`. +# TODO: Write a generic version for blocked unit ranges (like `GradedAxes`). +# Outputs a `BlockUnitRange`. +function sub_unitrange(a::AbstractUnitRange, indices::AbstractUnitRange) + @show typeof(indices) + indices = sort(indices) + br = blockedrange(collect(groupcount(i -> findblock(a, i), indices))) + @show typeof(br) + return br +end + +# Outputs a `BlockUnitRange`. +function sub_unitrange(a::AbstractUnitRange, indices::Vector{<:Block}) + @show indices + @show [a[index] for index in indices] + return error("Not implemented") +end + +function sub_unitrange(a::AbstractUnitRange, indices::BlockVector{<:Block}) + println("TEST") + @show indices + @show [a[index] for index in indices] + return blockedrange([length(a[index]) for index in indices]) +end + # TODO: Use `Tuple` conversion once # BlockArrays.jl PR is merged. block_to_cartesianindex(b::Block) = CartesianIndex(b.n) @@ -39,11 +84,6 @@ function block_reshape(a::AbstractArray, axes::Vararg{AbstractUnitRange}) return block_reshape(a, axes) end -################################################################################## -using BlockArrays: - Block, BlockedUnitRange, block, blockindex, blocks, findblock, findblockindex -using ..SparseArrayInterface: stored_indices - function cartesianindices(axes::Tuple, b::Block) return CartesianIndices(ntuple(dim -> axes[dim][Tuple(b)[dim]], length(axes))) end @@ -86,10 +126,21 @@ function blockindexrange(a::AbstractArray, I::CartesianIndices) end # Get the blocks the range spans across. -function blockrange(axis::BlockedUnitRange, r::UnitRange) +function blockrange(axis::AbstractUnitRange, r::UnitRange) return findblock(axis, first(r)):findblock(axis, last(r)) end +function blockrange(axis::AbstractUnitRange, r::Int) + error("Slicing with integer values isn't supported.") + return findblock(axis, r) +end + +function blockrange(axis::AbstractUnitRange, r) + @show r + @show typeof(r) + return error("Not implemented") +end + function cartesianindices(a::AbstractArray, b::Block) return cartesianindices(axes(a), b) end @@ -104,3 +155,40 @@ end function block_stored_indices(a::AbstractArray) return Block.(Tuple.(stored_indices(blocks(a)))) end + +_block(indices) = block(indices) +_block(indices::CartesianIndices) = Block(ntuple(Returns(1), ndims(indices))) + +function combine_axes(as::Vararg{Tuple}) + @assert allequal(length.(as)) + ndims = length(first(as)) + return ntuple(ndims) do dim + dim_axes = map(a -> a[dim], as) + return reduce(BlockArrays.combine_blockaxes, dim_axes) + end +end + +# Returns `BlockRange` +# Convert the block of the axes to blocks of the subaxes. +function subblocks(axes::Tuple, subaxes::Tuple, block::Block) + @assert length(axes) == length(subaxes) + return BlockRange( + ntuple(length(axes)) do dim + findblocks(subaxes[dim], axes[dim][Tuple(block)[dim]]) + end, + ) +end + +# Returns `Vector{<:Block}` +function subblocks(axes::Tuple, subaxes::Tuple, blocks) + return mapreduce(vcat, blocks; init=eltype(blocks)[]) do block + return vec(subblocks(axes, subaxes, block)) + end +end + +# Returns `Vector{<:CartesianIndices}` +function blocked_cartesianindices(axes::Tuple, subaxes::Tuple, blocks) + return map(subblocks(axes, subaxes, blocks)) do block + return cartesianindices(subaxes, block) + end +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl index ca050c90ed..658fe4436d 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl @@ -1,5 +1,11 @@ +using BlockArrays: AbstractBlockArray, BlocksView using ..SparseArrayInterface: SparseArrayInterface, nstored function SparseArrayInterface.nstored(a::AbstractBlockArray) return sum(b -> nstored(b), blocks(a); init=zero(Int)) end + +# TODO: Handle `BlocksView` wrapping a sparse array? +function SparseArrayInterface.storage_indices(a::BlocksView) + return CartesianIndices(a) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index c88c69e2ff..d0430732fb 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -8,12 +8,12 @@ include("abstractblocksparsearray/abstractblocksparsearray.jl") include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl") include("abstractblocksparsearray/abstractblocksparsematrix.jl") include("abstractblocksparsearray/abstractblocksparsevector.jl") +include("abstractblocksparsearray/view.jl") include("abstractblocksparsearray/arraylayouts.jl") include("abstractblocksparsearray/sparsearrayinterface.jl") include("abstractblocksparsearray/linearalgebra.jl") include("abstractblocksparsearray/broadcast.jl") include("abstractblocksparsearray/map.jl") -include("abstractblocksparsearray/map_mismatched_blocking.jl") include("blocksparsearray/defaults.jl") include("blocksparsearray/blocksparsearray.jl") include("BlockArraysExtensions/BlockArraysExtensions.jl") diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl index 5c60395b4b..40c15b7d05 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl @@ -14,7 +14,7 @@ Base.axes(::AbstractBlockSparseArray) = error("Not implemented") blockstype(::Type{<:AbstractBlockSparseArray}) = error("Not implemented") -# Specialized in order to fix ambiguity error with `BlockArrays`. +## # Specialized in order to fix ambiguity error with `BlockArrays`. function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N} return blocksparse_getindex(a, I...) end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl index 78a07626b8..0420fdf80e 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl @@ -11,19 +11,75 @@ using ..SparseArrayInterface: sparse_iszero, sparse_isreal +## using BlockArrays: BlockArrays, BlockRange, block + +## _block(indices) = block(indices) +## _block(indices::CartesianIndices) = Block(ntuple(Returns(1), ndims(indices))) +## +## function combine_axes(as::Vararg{Tuple}) +## @assert allequal(length.(as)) +## ndims = length(first(as)) +## return ntuple(ndims) do dim +## dim_axes = map(a -> a[dim], as) +## return reduce(BlockArrays.combine_blockaxes, dim_axes) +## end +## end +## +## # Returns `BlockRange` +## # Convert the block of the axes to blocks of the subaxes. +## function subblocks(axes::Tuple, subaxes::Tuple, block::Block) +## @assert length(axes) == length(subaxes) +## return BlockRange( +## ntuple(length(axes)) do dim +## findblocks(subaxes[dim], axes[dim][Tuple(block)[dim]]) +## end, +## ) +## end +## +## # Returns `Vector{<:Block}` +## function subblocks(axes::Tuple, subaxes::Tuple, blocks) +## return mapreduce(vcat, blocks; init=eltype(blocks)[]) do block +## return vec(subblocks(axes, subaxes, block)) +## end +## end +## +## # Returns `Vector{<:CartesianIndices}` +## function stored_blocked_cartesianindices(a::AbstractArray, subaxes::Tuple) +## return map(subblocks(axes(a), subaxes, block_stored_indices(a))) do block +## return cartesianindices(subaxes, block) +## end +## end + +# Returns `Vector{<:CartesianIndices}` +function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray}) + stored_blocked_cartesianindices_as = map(as) do a + return blocked_cartesianindices( + axes(a), combine_axes(axes.(as)...), block_stored_indices(a) + ) + end + return ∪(stored_blocked_cartesianindices_as...) +end + function SparseArrayInterface.sparse_map!( ::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray} ) - if all(a_src -> blockisequal(axes(a_dest), axes(a_src)), a_srcs) - # If the axes/block structure are all the same, - # map based on the blocks. - map!(f, blocks(a_dest), blocks.(a_srcs)...) - else - map_mismatched_blocking!(f, a_dest, a_srcs...) + for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...) + BI_dest = blockindexrange(a_dest, I) + BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs) + block_dest = @view a_dest[_block(BI_dest)] + block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs)) + subblock_dest = @view block_dest[BI_dest.indices...] + subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs)) + # TODO: Use `map!!` to handle immutable blocks. + map!(f, subblock_dest, subblock_srcs...) + # Replace the entire block, handles initializing new blocks + # or if blocks are immutable. + blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] = block_dest end return a_dest end +# TODO: Implement this. # function SparseArrayInterface.sparse_mapreduce(::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}) # end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl deleted file mode 100644 index f789ba2b28..0000000000 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map_mismatched_blocking.jl +++ /dev/null @@ -1,30 +0,0 @@ -using BlockArrays: BlockArrays, BlockRange - -function map_mismatched_blocking!(f, a_dest::AbstractArray, a_src::AbstractArray) - ## @show typeof(a_src) - - # Create a common set of axes with a blocking that includes the - # blocking of `a_dest` and `a_src`. - matching_axes = BlockArrays.combine_blockaxes.(axes(a_dest), axes(a_src)) - # TODO: Also include `block_stored_indices(a_dest)`! - for b in block_stored_indices(a_src) - # Get the subblocks of the matching axes - # TODO: `union` all `subblocks` of all `a_src` and `a_dest`. - subblocks = BlockRange( - ntuple(ndims(a_dest)) do dim - findblocks(matching_axes[dim], axes(a_src, dim)[Tuple(b)[dim]]) - end, - ) - for subblock in subblocks - I = cartesianindices(matching_axes, subblock) - I_dest = blockindexrange(a_dest, I) - I_src = blockindexrange(a_src, I) - - # TODO: Broken, need to fix. - # map!(f, view(a_dest, I_dest), view(a_src, I_src)) - - map!(f, view(a_dest, I_dest), a_src[I_src]) - end - end - return a_dest -end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl new file mode 100644 index 0000000000..a9d389f04c --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl @@ -0,0 +1,14 @@ +using BlockArrays: BlockIndexRange, block + +function Base.view(a::BlockSparseArrayLike{<:Any,N}, index::Block{N}) where {N} + return blocks(a)[Int.(Tuple(index))...] +end + +# TODO: Define `AnyBlockSparseVector`. +function Base.view(a::BlockSparseArrayLike{<:Any,1}, index::Block{1}) + return blocks(a)[Int.(Tuple(index))...] +end + +function Base.view(a::BlockSparseArrayLike, indices::BlockIndexRange) + return view(view(a, block(indices)), indices.indices...) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index ef7f7edeea..bdbebdc60f 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -1,11 +1,12 @@ using BlockArrays: BlockedUnitRange, blockedrange using SplitApplyCombine: groupcount -# TODO: Write a specialized version for `indices::AbstractUnitRange`. -# TODO: Write a generic version for blocked unit ranges (like `GradedAxes`). -function sub_unitrange(a::BlockedUnitRange, indices) - indices = sort(indices) - return blockedrange(collect(groupcount(i -> findblock(a, i), indices))) -end + +## # TODO: Write a specialized version for `indices::AbstractUnitRange`. +## # TODO: Write a generic version for blocked unit ranges (like `GradedAxes`). +## function sub_unitrange(a::BlockedUnitRange, indices) +## indices = sort(indices) +## return blockedrange(collect(groupcount(i -> findblock(a, i), indices))) +## end using Adapt: Adapt, WrappedArray @@ -18,6 +19,8 @@ const BlockSparseArrayLike{T,N} = Union{ } # AbstractArray interface +# TODO: Use `BlockSparseArrayLike`. +# TODO: Need to handle block indexing. function Base.axes(a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray}) return ntuple(i -> sub_unitrange(axes(parent(a), i), a.indices[i]), ndims(a)) end @@ -33,6 +36,23 @@ end blocktype(a::BlockSparseArrayLike) = eltype(blocks(a)) blocktype(arraytype::Type{<:BlockSparseArrayLike}) = eltype(blockstype(arraytype)) +using ArrayLayouts: ArrayLayouts +## function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, I::Vararg{Int,N}) where {N} +## return ArrayLayouts.layout_getindex(a, I...) +## end +function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, I::CartesianIndices{N}) where {N} + return ArrayLayouts.layout_getindex(a, I) +end +function Base.getindex( + a::BlockSparseArrayLike{<:Any,N}, I::Vararg{AbstractUnitRange,N} +) where {N} + return ArrayLayouts.layout_getindex(a, I...) +end +# TODO: Define `AnyBlockSparseMatrix`. +function Base.getindex(a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitRange,2}) + return ArrayLayouts.layout_getindex(a, I...) +end + function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::BlockIndex{N}) where {N} blocksparse_setindex!(a, value, I) return a @@ -70,6 +90,7 @@ function Base.similar( end # Needed by `BlockArrays` matrix multiplication interface +# TODO: Define a `blocksparse_similar` function. function Base.similar( arraytype::Type{<:BlockSparseArrayLike}, elt::Type, axes::Tuple{Vararg{AbstractUnitRange}} ) @@ -78,14 +99,26 @@ function Base.similar( return BlockSparseArray{elt}(undef, axes) end +# TODO: Define a `blocksparse_similar` function. function Base.similar( a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{AbstractUnitRange}} ) # TODO: Make generic for GPU, maybe using `blocktype`. # TODO: For non-block axes this should output `Array`. - return BlockSparseArray{eltype(a)}(undef, axes) + return BlockSparseArray{elt}(undef, axes) +end + +# TODO: Define a `blocksparse_similar` function. +# Fixes ambiguity error with `BlockArrays`. +function Base.similar( + a::BlockSparseArrayLike, elt::Type, axes::Tuple{BlockedUnitRange,Vararg{BlockedUnitRange}} +) + # TODO: Make generic for GPU, maybe using `blocktype`. + # TODO: For non-block axes this should output `Array`. + return BlockSparseArray{elt}(undef, axes) end +# TODO: Define a `blocksparse_similar` function. # Fixes ambiguity error with `OffsetArrays`. function Base.similar( a::BlockSparseArrayLike, @@ -94,5 +127,15 @@ function Base.similar( ) # TODO: Make generic for GPU, maybe using `blocktype`. # TODO: For non-block axes this should output `Array`. - return BlockSparseArray{eltype(a)}(undef, axes) + return BlockSparseArray{elt}(undef, axes) +end + +# TODO: Define a `blocksparse_similar` function. +# Fixes ambiguity error with `StaticArrays`. +function Base.similar( + a::BlockSparseArrayLike, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}} +) + # TODO: Make generic for GPU, maybe using `blocktype`. + # TODO: For non-block axes this should output `Array`. + return BlockSparseArray{elt}(undef, axes) end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl index b22ee08a72..4ecde33381 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl @@ -16,6 +16,9 @@ struct BlockSparseArray{ axes::Axes end +const BlockSparseMatrix{T,A,Blocks,Axes} = BlockSparseArray{T,2,A,Blocks,Axes} +const BlockSparseVector{T,A,Blocks,Axes} = BlockSparseArray{T,1,A,Blocks,Axes} + function BlockSparseArray( block_data::Dictionary{<:Block{N},<:AbstractArray{<:Any,N}}, axes::Tuple{Vararg{AbstractUnitRange,N}}, @@ -108,10 +111,10 @@ blocksparse_blocks(a::BlockSparseArray) = a.blocks # TODO: Use `TypeParameterAccessors`. blockstype(::Type{<:BlockSparseArray{<:Any,<:Any,<:Any,B}}) where {B} = B -# Base interface -function Base.similar( - a::AbstractBlockSparseArray, elt::Type, axes::Tuple{Vararg{BlockedUnitRange}} -) - # TODO: Preserve GPU data! - return BlockSparseArray{elt}(undef, axes) -end +## # Base interface +## function Base.similar( +## a::AbstractBlockSparseArray, elt::Type, axes::Tuple{Vararg{BlockedUnitRange}} +## ) +## # TODO: Preserve GPU data! +## return BlockSparseArray{elt}(undef, axes) +## end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 0d8de63757..77d00e214b 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -9,7 +9,7 @@ using BlockArrays: blocklengths, findblockindex using ..SparseArrayInterface: perm, iperm, nstored -using MappedArrays: mappedarray +## using MappedArrays: mappedarray blocksparse_blocks(a::AbstractArray) = error("Not implemented") @@ -90,15 +90,33 @@ end # BlockArrays +using ..SparseArrayInterface: SparseArrayInterface, AbstractSparseArray + +# Represents the array of arrays of a `SubArray` +# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`. +struct SparsePermutedDimsArrayBlocks{T,N,Array<:PermutedDimsArray{T,N}} <: + AbstractSparseArray{T,N} + array::Array +end function blocksparse_blocks(a::PermutedDimsArray) - blocks_parent = blocks(parent(a)) - # Lazily permute each block - blocks_parent_mapped = mappedarray( - Base.Fix2(PermutedDimsArray, perm(a)), - Base.Fix2(PermutedDimsArray, iperm(a)), - blocks_parent, + return SparsePermutedDimsArrayBlocks(a) +end +_perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P +_getindices(t::Tuple, indices) = map(i -> t[i], indices) +_getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), indices)) +function SparseArrayInterface.stored_indices(a::SparsePermutedDimsArrayBlocks) + return map(I -> _getindices(I, _perm(a.array)), stored_indices(blocks(parent(a.array)))) +end +function Base.size(a::SparsePermutedDimsArrayBlocks) + return _getindices(size(blocks(parent(a.array))), _perm(a.array)) +end +function Base.getindex(a::SparsePermutedDimsArrayBlocks, index::Vararg{Int}) + return PermutedDimsArray( + blocks(parent(a.array))[_getindices(index, _perm(a.array))...], _perm(a.array) ) - return PermutedDimsArray(blocks_parent_mapped, perm(a)) +end +function SparseArrayInterface.sparse_storage(a::SparsePermutedDimsArrayBlocks) + return error("Not implemented") end # TODO: Move to `BlockArraysExtensions`. @@ -121,47 +139,48 @@ function blockindices(axis::AbstractUnitRange, block::Block, indices) return only(blockindexrange(axis, indices_within_block).indices) end -using ..SparseArrayInterface: SparseArrayInterface, AbstractSparseArray - # Represents the array of arrays of a `SubArray` # wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`. -struct SubBlocks{T,N,Array<:SubArray{T,N}} <: AbstractSparseArray{T,N} +struct SparseSubArrayBlocks{T,N,Array<:SubArray{T,N}} <: AbstractSparseArray{T,N} array::Array end -function Base.axes(a::SubBlocks) +# TODO: Define this as `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`. +function blockrange(a::SparseSubArrayBlocks) blockranges = blockrange.(axes(parent(a.array)), a.array.indices) return map(blockrange -> Int.(blockrange), blockranges) end -function SparseArrayInterface.stored_indices(a::SubBlocks) +function Base.axes(a::SparseSubArrayBlocks) + return Base.OneTo.(length.(blockrange(a))) +end +function Base.size(a::SparseSubArrayBlocks) + return length.(axes(a)) +end +function SparseArrayInterface.stored_indices(a::SparseSubArrayBlocks) return stored_indices(view(blocks(parent(a.array)), axes(a)...)) end -function Base.getindex(a::SubBlocks{<:Any,N}, I::CartesianIndex{N}) where {N} +function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N} + return a[Tuple(I)...] +end +function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N} + parent_blocks = @view blocks(parent(a.array))[axes(a)...] + parent_block = parent_blocks[I...] + # TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`. + block = Block(ntuple(i -> blockrange(a)[i][I[i]], ndims(a))) + return @view parent_block[blockindices(parent(a.array), block, a.array.indices)...] +end +function Base.setindex!(a::SparseSubArrayBlocks{<:Any,N}, value, I::Vararg{Int,N}) where {N} parent_blocks = view(blocks(parent(a.array)), axes(a)...) - return parent_blocks[blockindices(parent(a.array), Block(Tuple(I)), a.array.indices)...] + return parent_blocks[I...][blockindices(parent(a.array), Block(I), a.array.indices)...] = + value +end +function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N} + # TODO: Implement this properly. + return true end -function SparseArrayInterface.sparse_storage(a::SubBlocks) - error() - # TODO: This also needs to slice the blocks! - return view(blocks(parent(a.array)), axes(a)...) +function SparseArrayInterface.sparse_storage(a::SparseSubArrayBlocks) + return error("Not implemented") end function blocksparse_blocks(a::SubArray) - return SubBlocks(a) - error() - # First slice blockwise. - blockranges = blockrange.(axes(parent(a)), a.indices) - # Then slice the blocks. - sliced_blocks = map(stored_indices(blocks(parent(a)))) do index - tuple_index = Tuple(index) - block = Block(tuple_index) - return view( - blocks(parent(a))[tuple_index...], blockindices(parent(a), block, a.indices)... - ) - end - # TODO: Use a `set_data` function, or some kind of `similar` or `zero` method? - blocks_a_sub = SparseArrayDOK( - sliced_blocks, size(blocks(parent(a))), blocks(parent(a)).zero - ) - # TODO: Avoid copying, use a view? - return blocks_a_sub[map(blockrange -> Int.(blockrange), blockranges)...] + return SparseSubArrayBlocks(a) end From 03e709e42ece3e139bb8bbffc065c97d7ab072a2 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 1 Mar 2024 11:57:19 -0500 Subject: [PATCH 07/17] Remove printing --- .../BlockArraysExtensions/BlockArraysExtensions.jl | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index f258703db0..0cf0375e24 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -17,8 +17,6 @@ using ..SparseArrayInterface: stored_indices # Outputs a `BlockUnitRange`. function sub_unitrange(a::AbstractUnitRange, indices) - @show indices - @show typeof(indices) return error("Not implemented") end @@ -26,24 +24,16 @@ end # TODO: Write a generic version for blocked unit ranges (like `GradedAxes`). # Outputs a `BlockUnitRange`. function sub_unitrange(a::AbstractUnitRange, indices::AbstractUnitRange) - @show typeof(indices) indices = sort(indices) - br = blockedrange(collect(groupcount(i -> findblock(a, i), indices))) - @show typeof(br) - return br + return blockedrange(collect(groupcount(i -> findblock(a, i), indices))) end # Outputs a `BlockUnitRange`. function sub_unitrange(a::AbstractUnitRange, indices::Vector{<:Block}) - @show indices - @show [a[index] for index in indices] return error("Not implemented") end function sub_unitrange(a::AbstractUnitRange, indices::BlockVector{<:Block}) - println("TEST") - @show indices - @show [a[index] for index in indices] return blockedrange([length(a[index]) for index in indices]) end @@ -136,8 +126,6 @@ function blockrange(axis::AbstractUnitRange, r::Int) end function blockrange(axis::AbstractUnitRange, r) - @show r - @show typeof(r) return error("Not implemented") end From 6fb858327e82fb6e0fb7a10b8419eedafbe18f3e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 8 Mar 2024 14:18:47 -0500 Subject: [PATCH 08/17] Fix some issues with slicing BlockSparseArrays with GradedAxes, some remain --- .../test/runtests.jl | 46 +++++++++++-------- .../BlockArraysExtensions.jl | 18 ++------ .../src/abstractblocksparsearray/map.jl | 9 +++- .../blocksparsearrayinterface/broadcast.jl | 6 +-- 4 files changed, 41 insertions(+), 38 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index b1e1f41ed7..578f9d6857 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -6,26 +6,36 @@ using NDTensors.GradedAxes: gradedrange using NDTensors.Sectors: U1 using NDTensors.TensorAlgebra: fusedims, splitdims using Random: randn! +function blockdiagonal!(f, a::AbstractArray) + for i in 1:minimum(blocksize(a)) + b = Block(ntuple(Returns(i), ndims(a))) + a[b] = f(a[b]) + end + return a +end const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "BlockSparseArraysGradedAxesExt (eltype=$elt)" for elt in elts - d1 = gradedrange([U1(0) => 1, U1(1) => 1]) - d2 = gradedrange([U1(1) => 1, U1(0) => 1]) - - ## using BlockArrays: blockedrange - ## d1 = blockedrange([1, 1]) - ## d2 = blockedrange([2, 2]) - - a = BlockSparseArray{elt}(d1, d2, d1, d2) - for i in 1:minimum(blocksize(a)) - b = Block(i, i, i, i) - a[b] = randn!(a[b]) + @testset "map" begin + d1 = gradedrange([U1(0) => 1, U1(1) => 1]) + d2 = gradedrange([U1(1) => 1, U1(0) => 1]) + a = BlockSparseArray{elt}(d1, d2, d1, d2) + blockdiagonal!(randn!, a) + @test Array(a) isa Array{elt} + @test Array(a) == a + @test 2 * Array(a) == 2a + end + @testset "fusedims" begin + d1 = gradedrange([U1(0) => 1, U1(1) => 1]) + d2 = gradedrange([U1(1) => 1, U1(0) => 1]) + a = BlockSparseArray{elt}(d1, d2, d1, d2) + blockdiagonal!(randn!, a) + m = fusedims(a, (1, 2), (3, 4)) + @test a[1, 1, 1, 1] == m[2, 2] + @test a[2, 2, 2, 2] == m[3, 3] + # TODO: Current `fusedims` doesn't merge + # common sectors, need to fix. + @test_broken blocksize(m) == (3, 3) + @test a == splitdims(m, (d1, d2), (d1, d2)) end - m = fusedims(a, (1, 2), (3, 4)) - @test a[1, 1, 1, 1] == m[2, 2] - @test a[2, 2, 2, 2] == m[3, 3] - # TODO: Current `fusedims` doesn't merge - # common sectors, need to fix. - @test_broken blocksize(m) == (3, 3) - @test a == splitdims(m, (d1, d2), (d1, d2)) end end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 0cf0375e24..ce0dee6861 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -79,7 +79,7 @@ function cartesianindices(axes::Tuple, b::Block) end # Get the range within a block. -function blockindexrange(axis::BlockedUnitRange, r::UnitRange) +function blockindexrange(axis::AbstractUnitRange, r::UnitRange) bi1 = findblockindex(axis, first(r)) bi2 = findblockindex(axis, last(r)) b = block(bi1) @@ -90,13 +90,8 @@ function blockindexrange(axis::BlockedUnitRange, r::UnitRange) return b[i1:i2] end -# Fallback for non-blocked ranges. -function blockindexrange(axis::AbstractUnitRange, r::UnitRange) - return r -end - function blockindexrange( - axes::Tuple{Vararg{BlockedUnitRange,N}}, I::CartesianIndices{N} + axes::Tuple{Vararg{AbstractUnitRange,N}}, I::CartesianIndices{N} ) where {N} brs = blockindexrange.(axes, I.indices) b = Block(block.(brs)) @@ -104,13 +99,6 @@ function blockindexrange( return b[rs...] end -# Fallback for non-blocked ranges. -function blockindexrange( - axes::Tuple{Vararg{AbstractUnitRange,N}}, I::CartesianIndices{N} -) where {N} - return I -end - function blockindexrange(a::AbstractArray, I::CartesianIndices) return blockindexrange(axes(a), I) end @@ -126,7 +114,7 @@ function blockrange(axis::AbstractUnitRange, r::Int) end function blockrange(axis::AbstractUnitRange, r) - return error("Not implemented") + return error("Slicing not implemented for range of type `$(typeof(r))`.") end function cartesianindices(a::AbstractArray, b::Block) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl index 0420fdf80e..4084dc1edc 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl @@ -60,6 +60,10 @@ function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray}) return ∪(stored_blocked_cartesianindices_as...) end +# This is used by `map` to get the output axes. +# This is type piracy, try to avoid this, maybe requires defining `map`. +## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2) + function SparseArrayInterface.sparse_map!( ::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray} ) @@ -83,12 +87,15 @@ end # function SparseArrayInterface.sparse_mapreduce(::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}) # end -# Map function Base.map!(f, a_dest::AbstractArray, a_srcs::Vararg{BlockSparseArrayLike}) sparse_map!(f, a_dest, a_srcs...) return a_dest end +function Base.map(f, as::Vararg{BlockSparseArrayLike}) + return f.(as...) +end + function Base.copy!(a_dest::AbstractArray, a_src::BlockSparseArrayLike) sparse_copy!(a_dest, a_src) return a_dest diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl index 751e5c6c09..7ce8d024ef 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl @@ -23,11 +23,9 @@ function Broadcast.BroadcastStyle( return DefaultArrayStyle{N}() end -# TODO: Use `allocate_output`, share logic with `map`. function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type) - # TODO: Is this a good definition? Probably should check that - # they have consistent axes. - return similar(first(map_args(bc)), elt) + # TODO: Make sure this handles GPU arrays properly. + return similar(first(map_args(bc)), elt, combine_axes(axes.(map_args(bc))...)) end # Broadcasting implementation From 3bfa911db8c126a9204e14cbba5e8c499bc9948f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 8 Mar 2024 17:46:49 -0500 Subject: [PATCH 09/17] Fix some more slicing --- .../test/runtests.jl | 1 + .../BlockArraysExtensions.jl | 17 +++++++- .../src/abstractblocksparsearray/broadcast.jl | 43 +++++++++++++++++++ .../src/abstractblocksparsearray/map.jl | 39 ----------------- .../src/abstractblocksparsearray/view.jl | 29 +++++++++++-- .../wrappedabstractblocksparsearray.jl | 13 ++++++ .../blocksparsearrayinterface.jl | 26 ++++++++++- .../GradedAxes/src/abstractgradedunitrange.jl | 10 ++++- 8 files changed, 131 insertions(+), 47 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 578f9d6857..4736bfd8b3 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -24,6 +24,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test Array(a) == a @test 2 * Array(a) == 2a end + # TODO: Add tests for various slicing operations. @testset "fusedims" begin d1 = gradedrange([U1(0) => 1, U1(1) => 1]) d2 = gradedrange([U1(1) => 1, U1(0) => 1]) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index ce0dee6861..3bd4966092 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -7,6 +7,7 @@ using BlockArrays: BlockedUnitRange, BlockVector, block, + blockaxes, blockedrange, blockindex, blocks, @@ -29,8 +30,8 @@ function sub_unitrange(a::AbstractUnitRange, indices::AbstractUnitRange) end # Outputs a `BlockUnitRange`. -function sub_unitrange(a::AbstractUnitRange, indices::Vector{<:Block}) - return error("Not implemented") +function sub_unitrange(a::AbstractUnitRange, indices::AbstractVector{<:Block}) + return blockedrange([length(a[index]) for index in indices]) end function sub_unitrange(a::AbstractUnitRange, indices::BlockVector{<:Block}) @@ -113,6 +114,18 @@ function blockrange(axis::AbstractUnitRange, r::Int) return findblock(axis, r) end +function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}}) + for b in r + @assert b ∈ blockaxes(axis, 1) + end + return r +end + +using BlockArrays: BlockSlice +function blockrange(axis::AbstractUnitRange, r::BlockSlice) + return blockrange(axis, r.block) +end + function blockrange(axis::AbstractUnitRange, r) return error("Slicing not implemented for range of type `$(typeof(r))`.") end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl index 0d1c942d18..50faf109dc 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl @@ -1,5 +1,48 @@ +using BlockArrays: BlockedUnitRange, BlockSlice using Base.Broadcast: Broadcast function Broadcast.BroadcastStyle(arraytype::Type{<:BlockSparseArrayLike}) return BlockSparseArrayStyle{ndims(arraytype)}() end + +# Fix ambiguity error with `BlockArrays`. +function Broadcast.BroadcastStyle( + arraytype::Type{ + <:SubArray{ + <:Any, + <:Any, + <:AbstractBlockSparseArray, + <:Tuple{BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}}, + }, + }, +) + return BlockSparseArrayStyle{ndims(arraytype)}() +end +function Broadcast.BroadcastStyle( + arraytype::Type{ + <:SubArray{ + <:Any, + <:Any, + <:AbstractBlockSparseArray, + <:Tuple{ + BlockSlice{<:Any,<:BlockedUnitRange}, + BlockSlice{<:Any,<:BlockedUnitRange}, + Vararg{Any}, + }, + }, + }, +) + return BlockSparseArrayStyle{ndims(arraytype)}() +end +function Broadcast.BroadcastStyle( + arraytype::Type{ + <:SubArray{ + <:Any, + <:Any, + <:AbstractBlockSparseArray, + <:Tuple{Any,BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}}, + }, + }, +) + return BlockSparseArrayStyle{ndims(arraytype)}() +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl index 4084dc1edc..2d22efd277 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl @@ -11,45 +11,6 @@ using ..SparseArrayInterface: sparse_iszero, sparse_isreal -## using BlockArrays: BlockArrays, BlockRange, block - -## _block(indices) = block(indices) -## _block(indices::CartesianIndices) = Block(ntuple(Returns(1), ndims(indices))) -## -## function combine_axes(as::Vararg{Tuple}) -## @assert allequal(length.(as)) -## ndims = length(first(as)) -## return ntuple(ndims) do dim -## dim_axes = map(a -> a[dim], as) -## return reduce(BlockArrays.combine_blockaxes, dim_axes) -## end -## end -## -## # Returns `BlockRange` -## # Convert the block of the axes to blocks of the subaxes. -## function subblocks(axes::Tuple, subaxes::Tuple, block::Block) -## @assert length(axes) == length(subaxes) -## return BlockRange( -## ntuple(length(axes)) do dim -## findblocks(subaxes[dim], axes[dim][Tuple(block)[dim]]) -## end, -## ) -## end -## -## # Returns `Vector{<:Block}` -## function subblocks(axes::Tuple, subaxes::Tuple, blocks) -## return mapreduce(vcat, blocks; init=eltype(blocks)[]) do block -## return vec(subblocks(axes, subaxes, block)) -## end -## end -## -## # Returns `Vector{<:CartesianIndices}` -## function stored_blocked_cartesianindices(a::AbstractArray, subaxes::Tuple) -## return map(subblocks(axes(a), subaxes, block_stored_indices(a))) do block -## return cartesianindices(subaxes, block) -## end -## end - # Returns `Vector{<:CartesianIndices}` function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray}) stored_blocked_cartesianindices_as = map(as) do a diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl index a9d389f04c..e2e5c8acb9 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl @@ -1,12 +1,35 @@ -using BlockArrays: BlockIndexRange, block +using BlockArrays: BlockIndexRange, BlockRange, BlockSlice, block -function Base.view(a::BlockSparseArrayLike{<:Any,N}, index::Block{N}) where {N} +function blocksparse_view(a::AbstractArray, index::Block) return blocks(a)[Int.(Tuple(index))...] end +# TODO: Define `AnyBlockSparseVector`. +function Base.view(a::BlockSparseArrayLike{<:Any,N}, index::Block{N}) where {N} + return blocksparse_view(a, index) +end + +# Fix ambiguity error with `BlockArrays`. +function Base.view( + a::SubArray{ + <:Any, + N, + <:AbstractBlockSparseArray, + <:Tuple{ + Vararg{ + Union{Base.Slice,BlockSlice{<:BlockRange{1,<:Tuple{AbstractUnitRange{Int}}}}},N + }, + }, + }, + index::Block{N}, +) where {N} + return blocksparse_view(a, index) +end + +# Fix ambiguity error with `BlockArrays`. # TODO: Define `AnyBlockSparseVector`. function Base.view(a::BlockSparseArrayLike{<:Any,1}, index::Block{1}) - return blocks(a)[Int.(Tuple(index))...] + return blocksparse_view(a, index) end function Base.view(a::BlockSparseArrayLike, indices::BlockIndexRange) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index bdbebdc60f..3bbb7166b3 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -14,6 +14,7 @@ const WrappedAbstractBlockSparseArray{T,N,A} = WrappedArray{ T,N,<:AbstractBlockSparseArray,<:AbstractBlockSparseArray{T,N} } +# TODO: Rename `AnyBlockSparseArray`. const BlockSparseArrayLike{T,N} = Union{ <:AbstractBlockSparseArray{T,N},<:WrappedAbstractBlockSparseArray{T,N} } @@ -28,6 +29,14 @@ end # BlockArrays `AbstractBlockArray` interface BlockArrays.blocks(a::BlockSparseArrayLike) = blocksparse_blocks(a) +# Fix ambiguity error with `BlockArrays` +using BlockArrays: BlockSlice +function BlockArrays.blocks( + a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{BlockSlice}}} +) + return blocksparse_blocks(a) +end + # TODO: Use `TypeParameterAccessors`. function blockstype(arraytype::Type{<:WrappedAbstractBlockSparseArray}) return blockstype(Adapt.parent_type(arraytype)) @@ -53,6 +62,10 @@ function Base.getindex(a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitR return ArrayLayouts.layout_getindex(a, I...) end +function Base.isassigned(a::BlockSparseArrayLike, index::Vararg{Block}) + return isassigned(blocks(a), Int.(index)...) +end + function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::BlockIndex{N}) where {N} blocksparse_setindex!(a, value, I) return a diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 77d00e214b..d50513f34e 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -120,6 +120,8 @@ function SparseArrayInterface.sparse_storage(a::SparsePermutedDimsArrayBlocks) end # TODO: Move to `BlockArraysExtensions`. +# This takes a range of indices `indices` of array `a` +# and maps it to the range of indices within block `block`. function blockindices(a::AbstractArray, block::Block, indices::Tuple) return blockindices(axes(a), block, indices) end @@ -130,7 +132,9 @@ function blockindices(axes::Tuple, block::Block, indices::Tuple) end # TODO: Move to `BlockArraysExtensions`. -function blockindices(axis::AbstractUnitRange, block::Block, indices) +function blockindices(axis::AbstractUnitRange, block::Block, indices::AbstractUnitRange) + @show indices + error() indices_within_block = intersect(indices, axis[block]) if iszero(length(indices_within_block)) # Falls outside of block @@ -139,6 +143,21 @@ function blockindices(axis::AbstractUnitRange, block::Block, indices) return only(blockindexrange(axis, indices_within_block).indices) end +# This catches the case of `Vector{<:Block{1}}`. +# `BlockRange` gets wrapped in a `BlockSlice`, which is handled properly +# by the version with `indices::AbstractUnitRange`. +# TODO: This should get fixed in a better way inside of `BlockArrays`. +function blockindices( + axis::AbstractUnitRange, block::Block, indices::AbstractVector{<:Block{1}} +) + error() + if block ∉ indices + # Falls outside of block + return 1:0 + end + return Base.OneTo(length(axis[block])) +end + # Represents the array of arrays of a `SubArray` # wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`. struct SparseSubArrayBlocks{T,N,Array<:SubArray{T,N}} <: AbstractSparseArray{T,N} @@ -162,7 +181,7 @@ function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) w return a[Tuple(I)...] end function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N} - parent_blocks = @view blocks(parent(a.array))[axes(a)...] + parent_blocks = @view blocks(parent(a.array))[blockrange(a)...] parent_block = parent_blocks[I...] # TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`. block = Block(ntuple(i -> blockrange(a)[i][I[i]], ndims(a))) @@ -174,6 +193,9 @@ function Base.setindex!(a::SparseSubArrayBlocks{<:Any,N}, value, I::Vararg{Int,N value end function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N} + if CartesianIndex(I) ∉ CartesianIndices(a) + return false + end # TODO: Implement this properly. return true end diff --git a/NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl index 1953230045..224874b377 100644 --- a/NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl @@ -117,12 +117,20 @@ function blockmergesortperm(a::AbstractGradedUnitRange) return Block.(groupsortperm(nondual_sectors(a); rev=isdual(a))) end -function Base.getindex(a::AbstractGradedUnitRange, I::AbstractVector{<:Block}) +function block_getindex(a::AbstractGradedUnitRange, I::AbstractVector{<:Block{1}}) nondual_sectors_sub = map(b -> nondual_sector(a, b), I) blocklengths_sub = map(b -> length(a, b), I) return gradedrange(nondual_sectors_sub, blocklengths_sub, isdual(a)) end +function Base.getindex(a::AbstractGradedUnitRange, I::AbstractVector{<:Block{1}}) + return block_getindex(a, I) +end + +function Base.getindex(a::AbstractGradedUnitRange, I::BlockRange{1}) + return block_getindex(a, I) +end + function Base.getindex( a::AbstractGradedUnitRange, grouped_perm::AbstractBlockVector{<:Block} ) From 280c37a881b9e9f8bd6fdbeeadf8800f526b9e41 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 21 Mar 2024 16:27:43 -0400 Subject: [PATCH 10/17] Fix some tests --- .../blocksparsearrayinterface/blocksparsearrayinterface.jl | 3 --- NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl | 7 +++++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 4f4c42a168..f7ebc9750e 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -129,8 +129,6 @@ end # TODO: Move to `BlockArraysExtensions`. function blockindices(axis::AbstractUnitRange, block::Block, indices::AbstractUnitRange) - @show indices - error() indices_within_block = intersect(indices, axis[block]) if iszero(length(indices_within_block)) # Falls outside of block @@ -146,7 +144,6 @@ end function blockindices( axis::AbstractUnitRange, block::Block, indices::AbstractVector{<:Block{1}} ) - error() if block ∉ indices # Falls outside of block return 1:0 diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl index 18c6859500..40d655985c 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl @@ -11,6 +11,13 @@ labelled(object::AbstractUnitRange, label) = LabelledUnitRange(object, label) unlabel(lobject::LabelledUnitRange) = lobject.value unlabel_type(::Type{<:LabelledUnitRange{Value}}) where {Value} = Value +# Used by `CartesianIndices` constructor. +# TODO: Maybe reconsider this definition? Also, this should preserve +# the label if possible, currently it drops the label. +function Base.AbstractUnitRange{T}(a::LabelledUnitRange) where {T} + return AbstractUnitRange{T}(unlabel(a)) +end + for f in [:first, :getindex, :last, :length, :step] @eval Base.$f(a::LabelledUnitRange, args...) = labelled($f(unlabel(a), args...), label(a)) end From 19160b32c1997c149fb2e7b4457ca61de009af22 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 21 Mar 2024 17:16:25 -0400 Subject: [PATCH 11/17] Fix some tests --- .../ext/BlockSparseArraysGradedAxesExt/test/runtests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 4736bfd8b3..e86911ebe2 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -17,7 +17,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "BlockSparseArraysGradedAxesExt (eltype=$elt)" for elt in elts @testset "map" begin d1 = gradedrange([U1(0) => 1, U1(1) => 1]) - d2 = gradedrange([U1(1) => 1, U1(0) => 1]) + d2 = gradedrange([U1(0) => 1, U1(1) => 1]) a = BlockSparseArray{elt}(d1, d2, d1, d2) blockdiagonal!(randn!, a) @test Array(a) isa Array{elt} @@ -27,12 +27,12 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # TODO: Add tests for various slicing operations. @testset "fusedims" begin d1 = gradedrange([U1(0) => 1, U1(1) => 1]) - d2 = gradedrange([U1(1) => 1, U1(0) => 1]) + d2 = gradedrange([U1(0) => 1, U1(1) => 1]) a = BlockSparseArray{elt}(d1, d2, d1, d2) blockdiagonal!(randn!, a) m = fusedims(a, (1, 2), (3, 4)) - @test a[1, 1, 1, 1] == m[2, 2] - @test a[2, 2, 2, 2] == m[3, 3] + @test a[1, 1, 1, 1] == m[1, 1] + @test a[2, 2, 2, 2] == m[4, 4] # TODO: Current `fusedims` doesn't merge # common sectors, need to fix. @test_broken blocksize(m) == (3, 3) From bea467c96509d5fc75683bfb56c9e81e19e478d3 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 21 Mar 2024 17:44:53 -0400 Subject: [PATCH 12/17] Julia 1.6 compatibility --- .../src/BlockArraysExtensions/BlockArraysExtensions.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 3bd4966092..1ddcaf4678 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -13,6 +13,7 @@ using BlockArrays: blocks, findblock, findblockindex +using Compat: allequal using Dictionaries: Dictionary, Indices using ..SparseArrayInterface: stored_indices From e7cb936d46e949d3e454420fed38c2413567917a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 21 Mar 2024 21:22:59 -0400 Subject: [PATCH 13/17] Fix some slicing, add some tests --- .../test/runtests.jl | 39 ++++++++++++++++--- .../BlockArraysExtensions.jl | 14 ++++--- .../src/lib/GradedAxes/src/gradedunitrange.jl | 19 ++++++--- 3 files changed, 56 insertions(+), 16 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index e86911ebe2..adf7f60131 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -1,9 +1,11 @@ @eval module $(gensym()) using Test: @test, @testset, @test_broken using BlockArrays: Block, blocksize -using NDTensors.BlockSparseArrays: BlockSparseArray -using NDTensors.GradedAxes: gradedrange +using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored +using NDTensors.GradedAxes: GradedUnitRange, gradedrange +using NDTensors.LabelledNumbers: label using NDTensors.Sectors: U1 +using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: fusedims, splitdims using Random: randn! function blockdiagonal!(f, a::AbstractArray) @@ -16,13 +18,38 @@ end const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "BlockSparseArraysGradedAxesExt (eltype=$elt)" for elt in elts @testset "map" begin - d1 = gradedrange([U1(0) => 1, U1(1) => 1]) - d2 = gradedrange([U1(0) => 1, U1(1) => 1]) + d1 = gradedrange([U1(0) => 2, U1(1) => 2]) + d2 = gradedrange([U1(0) => 2, U1(1) => 2]) a = BlockSparseArray{elt}(d1, d2, d1, d2) blockdiagonal!(randn!, a) + + for b in (a + a, 2 * a) + @test size(b) == (4, 4, 4, 4) + @test blocksize(b) == (2, 2, 2, 2) + @test nstored(b) == 32 + @test block_nstored(b) == 2 + for i in 1:ndims(a) + @test axes(b, i) isa GradedUnitRange + end + @test label(axes(b, 1)[Block(1)]) == U1(0) + @test label(axes(b, 1)[Block(2)]) == U1(1) + @test Array(a) isa Array{elt} + @test Array(a) == a + @test 2 * Array(a) == b + end + + b = a[2:3, 2:3, 2:3, 2:3] + @test size(b) == (2, 2, 2, 2) + @test blocksize(b) == (2, 2, 2, 2) + @test nstored(b) == 2 + @test block_nstored(b) == 2 + for i in 1:ndims(a) + @test axes(b, i) isa GradedUnitRange + end + @test label(axes(b, 1)[Block(1)]) == U1(0) + @test label(axes(b, 1)[Block(2)]) == U1(1) @test Array(a) isa Array{elt} @test Array(a) == a - @test 2 * Array(a) == 2a end # TODO: Add tests for various slicing operations. @testset "fusedims" begin @@ -31,6 +58,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a = BlockSparseArray{elt}(d1, d2, d1, d2) blockdiagonal!(randn!, a) m = fusedims(a, (1, 2), (3, 4)) + @test axes(m, 1) isa GradedUnitRange + @test axes(m, 2) isa GradedUnitRange @test a[1, 1, 1, 1] == m[1, 1] @test a[2, 2, 2, 2] == m[4, 4] # TODO: Current `fusedims` doesn't merge diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 1ddcaf4678..ca59099b44 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -15,6 +15,7 @@ using BlockArrays: findblockindex using Compat: allequal using Dictionaries: Dictionary, Indices +using ..GradedAxes: blockedunitrange_getindices using ..SparseArrayInterface: stored_indices # Outputs a `BlockUnitRange`. @@ -22,21 +23,24 @@ function sub_unitrange(a::AbstractUnitRange, indices) return error("Not implemented") end -# TODO: Write a specialized version for `indices::AbstractUnitRange`. -# TODO: Write a generic version for blocked unit ranges (like `GradedAxes`). +# TODO: Use `GradedAxes.blockedunitrange_getindices`. # Outputs a `BlockUnitRange`. function sub_unitrange(a::AbstractUnitRange, indices::AbstractUnitRange) - indices = sort(indices) - return blockedrange(collect(groupcount(i -> findblock(a, i), indices))) + return only(axes(blockedunitrange_getindices(a, indices))) end +# TODO: Use `GradedAxes.blockedunitrange_getindices`. # Outputs a `BlockUnitRange`. function sub_unitrange(a::AbstractUnitRange, indices::AbstractVector{<:Block}) return blockedrange([length(a[index]) for index in indices]) end +# TODO: Use `GradedAxes.blockedunitrange_getindices`. +# TODO: Merge blocks. function sub_unitrange(a::AbstractUnitRange, indices::BlockVector{<:Block}) - return blockedrange([length(a[index]) for index in indices]) + # `collect` is needed here, otherwise a `PseudoBlockVector` is + # constructed. + return blockedrange([length(a[index]) for index in collect(indices)]) end # TODO: Use `Tuple` conversion once diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 6251cc19ca..c10d5052e3 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -25,17 +25,19 @@ function blockedunitrange(a::AbstractUnitRange, blocklengths) return BlockArrays._BlockedUnitRange(first(a), blocklasts) end -# Circumvents issue in `findblock` that assumes the `BlockedUnitRange` -# starts at 1. -# TODO: Raise an issue with `BlockArrays`. +# TODO: Move this to a `BlockArraysExtensions` library. +# TODO: Rename this. `BlockArrays.findblock(a, k)` finds the +# block of the value `k`, while this finds the block of the index `k`. +# This could make use of the `BlockIndices` object, i.e. `block(BlockIndices(a)[index])`. function blockedunitrange_findblock(a::BlockedUnitRange, index::Integer) @boundscheck index in 1:length(a) || throw(BoundsError(a, index)) return @inbounds findblock(a, index + first(a) - 1) end -# Circumvents issue in `findblockindex` that assumes the `BlockedUnitRange` -# starts at 1. -# TODO: Raise an issue with `BlockArrays`. +# TODO: Move this to a `BlockArraysExtensions` library. +# TODO: Rename this. `BlockArrays.findblockindex(a, k)` finds the +# block index of the value `k`, while this finds the block index of the index `k`. +# This could make use of the `BlockIndices` object, i.e. `BlockIndices(a)[index]`. function blockedunitrange_findblockindex(a::BlockedUnitRange, index::Integer) @boundscheck index in 1:length(a) || throw(BoundsError()) return @inbounds findblockindex(a, index + first(a) - 1) @@ -169,6 +171,7 @@ function blockedunitrange_getindex(a::GradedUnitRange, index) return labelled(unlabel_blocks(a)[index], get_label(a, index)) end +# TODO: Move this to a `BlockArraysExtensions` library. # Like `a[indices]` but preserves block structure. using BlockArrays: block, blockindex function blockedunitrange_getindices( @@ -194,20 +197,24 @@ function blockedunitrange_getindices( return blockedunitrange(indices .+ (first(a) - 1), blocklengths) end +# TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices(a::BlockedUnitRange, indices::BlockIndexRange) return a[block(indices)][only(indices.indices)] end +# TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices(a::BlockedUnitRange, indices::Vector{<:Integer}) return map(index -> a[index], indices) end +# TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices( a::BlockedUnitRange, indices::Vector{<:Union{Block{1},BlockIndexRange{1}}} ) return mortar(map(index -> a[index], indices)) end +# TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices(a::BlockedUnitRange, indices) return error("Not implemented.") end From bf5e92c80932fd3e288f002943f4b70991c4182f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 21 Mar 2024 21:44:02 -0400 Subject: [PATCH 14/17] Rename sub_unitrange to sub_axis --- .../src/BlockArraysExtensions/BlockArraysExtensions.jl | 8 ++++---- .../wrappedabstractblocksparsearray.jl | 9 +-------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index ca59099b44..04f37f0f18 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -19,25 +19,25 @@ using ..GradedAxes: blockedunitrange_getindices using ..SparseArrayInterface: stored_indices # Outputs a `BlockUnitRange`. -function sub_unitrange(a::AbstractUnitRange, indices) +function sub_axis(a::AbstractUnitRange, indices) return error("Not implemented") end # TODO: Use `GradedAxes.blockedunitrange_getindices`. # Outputs a `BlockUnitRange`. -function sub_unitrange(a::AbstractUnitRange, indices::AbstractUnitRange) +function sub_axis(a::AbstractUnitRange, indices::AbstractUnitRange) return only(axes(blockedunitrange_getindices(a, indices))) end # TODO: Use `GradedAxes.blockedunitrange_getindices`. # Outputs a `BlockUnitRange`. -function sub_unitrange(a::AbstractUnitRange, indices::AbstractVector{<:Block}) +function sub_axis(a::AbstractUnitRange, indices::AbstractVector{<:Block}) return blockedrange([length(a[index]) for index in indices]) end # TODO: Use `GradedAxes.blockedunitrange_getindices`. # TODO: Merge blocks. -function sub_unitrange(a::AbstractUnitRange, indices::BlockVector{<:Block}) +function sub_axis(a::AbstractUnitRange, indices::BlockVector{<:Block}) # `collect` is needed here, otherwise a `PseudoBlockVector` is # constructed. return blockedrange([length(a[index]) for index in collect(indices)]) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 3bbb7166b3..73248ffc83 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -1,13 +1,6 @@ using BlockArrays: BlockedUnitRange, blockedrange using SplitApplyCombine: groupcount -## # TODO: Write a specialized version for `indices::AbstractUnitRange`. -## # TODO: Write a generic version for blocked unit ranges (like `GradedAxes`). -## function sub_unitrange(a::BlockedUnitRange, indices) -## indices = sort(indices) -## return blockedrange(collect(groupcount(i -> findblock(a, i), indices))) -## end - using Adapt: Adapt, WrappedArray const WrappedAbstractBlockSparseArray{T,N,A} = WrappedArray{ @@ -23,7 +16,7 @@ const BlockSparseArrayLike{T,N} = Union{ # TODO: Use `BlockSparseArrayLike`. # TODO: Need to handle block indexing. function Base.axes(a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray}) - return ntuple(i -> sub_unitrange(axes(parent(a), i), a.indices[i]), ndims(a)) + return ntuple(i -> sub_axis(axes(parent(a), i), a.indices[i]), ndims(a)) end # BlockArrays `AbstractBlockArray` interface From 44c27a47b51701705ab42f66dc80ebb78a32a0e4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 22 Mar 2024 07:55:51 -0400 Subject: [PATCH 15/17] Fixes for Julia v1.6 --- .../BlockSparseArraysGradedAxesExt/test/runtests.jl | 13 +++++++++---- .../wrappedabstractblocksparsearray.jl | 4 ++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index adf7f60131..c3820b4ec8 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -1,4 +1,5 @@ @eval module $(gensym()) +using Compat: Returns using Test: @test, @testset, @test_broken using BlockArrays: Block, blocksize using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored @@ -28,11 +29,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test blocksize(b) == (2, 2, 2, 2) @test nstored(b) == 32 @test block_nstored(b) == 2 - for i in 1:ndims(a) - @test axes(b, i) isa GradedUnitRange + if VERSION >= v"1.7" + # TODO: Have to investigate why this fails + # on Julia v1.6, or drop support for v1.6. + for i in 1:ndims(a) + @test axes(b, i) isa GradedUnitRange + end + @test label(axes(b, 1)[Block(1)]) == U1(0) + @test label(axes(b, 1)[Block(2)]) == U1(1) end - @test label(axes(b, 1)[Block(1)]) == U1(0) - @test label(axes(b, 1)[Block(2)]) == U1(1) @test Array(a) isa Array{elt} @test Array(a) == a @test 2 * Array(a) == b diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 73248ffc83..dc40010526 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -30,9 +30,9 @@ function BlockArrays.blocks( return blocksparse_blocks(a) end -# TODO: Use `TypeParameterAccessors`. +using ..TypeParameterAccessors: parenttype function blockstype(arraytype::Type{<:WrappedAbstractBlockSparseArray}) - return blockstype(Adapt.parent_type(arraytype)) + return blockstype(parenttype(arraytype)) end blocktype(a::BlockSparseArrayLike) = eltype(blocks(a)) From 3b0efef7b3daa7ed25183571c3c611b564506964 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 22 Mar 2024 12:54:16 -0400 Subject: [PATCH 16/17] Fix tests on Julia v1.6 --- .../test/runtests.jl | 14 ++++----- .../LabelledNumbers/src/labelled_interface.jl | 14 ++++++--- .../LabelledNumbers/src/labelledinteger.jl | 30 ++++++++++++------- .../lib/LabelledNumbers/src/labellednumber.jl | 10 ------- .../LabelledNumbers/src/labelledunitrange.jl | 6 ++++ .../src/lib/LabelledNumbers/test/runtests.jl | 19 +++++++++++- 6 files changed, 60 insertions(+), 33 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index c3820b4ec8..a7a339bea0 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -29,15 +29,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test blocksize(b) == (2, 2, 2, 2) @test nstored(b) == 32 @test block_nstored(b) == 2 - if VERSION >= v"1.7" - # TODO: Have to investigate why this fails - # on Julia v1.6, or drop support for v1.6. - for i in 1:ndims(a) - @test axes(b, i) isa GradedUnitRange - end - @test label(axes(b, 1)[Block(1)]) == U1(0) - @test label(axes(b, 1)[Block(2)]) == U1(1) + # TODO: Have to investigate why this fails + # on Julia v1.6, or drop support for v1.6. + for i in 1:ndims(a) + @test axes(b, i) isa GradedUnitRange end + @test label(axes(b, 1)[Block(1)]) == U1(0) + @test label(axes(b, 1)[Block(2)]) == U1(1) @test Array(a) isa Array{elt} @test Array(a) == a @test 2 * Array(a) == b diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl b/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl index f695db9980..6cc65c46d6 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl @@ -33,10 +33,16 @@ labelled_oneunit(x) = set_value(x, one(x)) # encoded in the type. labelled_oneunit(type::Type) = error("Not implemented.") -labelled_mul(x, y) = labelled_mul(LabelledStyle(x), x, LabelledStyle(y), y) -labelled_mul(::IsLabelled, x, ::IsLabelled, y) = unlabel(x) * unlabel(y) -labelled_mul(::IsLabelled, x, ::NotLabelled, y) = set_value(x, unlabel(x) * y) -labelled_mul(::NotLabelled, x, ::IsLabelled, y) = set_value(y, x * unlabel(y)) +labelled_mul(x, y) = labelled_binary_op(*, x, y) +labelled_add(x, y) = labelled_binary_op(+, x, y) #labelled_add(LabelledStyle(x), x, LabelledStyle(y), y) +labelled_minus(x, y) = labelled_binary_op(-, x, y) #labelled_add(LabelledStyle(x), x, LabelledStyle(y), y) + +function labelled_binary_op(f, x, y) + return labelled_binary_op(f, LabelledStyle(x), x, LabelledStyle(y), y) +end +labelled_binary_op(f, ::IsLabelled, x, ::IsLabelled, y) = f(unlabel(x), unlabel(y)) +labelled_binary_op(f, ::IsLabelled, x, ::NotLabelled, y) = set_value(x, f(unlabel(x), y)) +labelled_binary_op(f, ::NotLabelled, x, ::IsLabelled, y) = set_value(y, f(x, unlabel(y))) # TODO: This is only needed for older Julia versions, like Julia 1.6. # Delete once we drop support for older Julia versions. diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl b/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl index 65eab8b37c..f5e2d58f3d 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl @@ -23,16 +23,6 @@ Base.convert(type::Type{<:Number}, x::LabelledInteger) = type(unlabel(x)) function Base.convert(type::Type{<:LabelledInteger}, x::LabelledInteger) return type(unlabel(x), label(x)) end -# TODO: Define `labelled_promote_type`. -function Base.promote_type(type1::Type{T}, type2::Type{T}) where {T<:LabelledInteger} - return promote_type(unlabel_type(type1), unlabel_type(type2)) -end -function Base.promote_rule(type1::Type{<:LabelledInteger}, type2::Type{<:LabelledInteger}) - return promote_type(unlabel_type(type1), unlabel_type(type2)) -end -function Base.promote_rule(type1::Type{<:LabelledInteger}, type2::Type{<:Number}) - return promote_type(unlabel_type(type1), type2) -end # Used by `Base.hash(::Integer)`. # TODO: Define `labelled_trailing_zeros` to be used by other @@ -45,6 +35,8 @@ Base.trailing_zeros(x::LabelledInteger) = trailing_zeros(unlabel(x)) Base.:>>(x::LabelledInteger, y::Int) = >>(unlabel(x), y) Base.:(==)(x::LabelledInteger, y::LabelledInteger) = labelled_isequal(x, y) +Base.:(==)(x::LabelledInteger, y::Number) = labelled_isequal(x, y) +Base.:(==)(x::Number, y::LabelledInteger) = labelled_isequal(x, y) Base.:<(x::LabelledInteger, y::LabelledInteger) = labelled_isless(x, y) # TODO: Define `labelled_colon`. (::Base.Colon)(start::LabelledInteger, stop::LabelledInteger) = unlabel(start):unlabel(stop) @@ -56,6 +48,24 @@ Base.oneunit(type::Type{<:LabelledInteger}) = error("Not implemented.") Base.Int(x::LabelledInteger) = Int(unlabel(x)) +Base.:+(x::LabelledInteger, y::LabelledInteger) = labelled_add(x, y) +Base.:+(x::LabelledInteger, y::Number) = labelled_add(x, y) +Base.:+(x::Number, y::LabelledInteger) = labelled_add(x, y) +# Fix ambiguity error with `+(::Integer, ::Integer)`. +Base.:+(x::LabelledInteger, y::Integer) = labelled_add(x, y) +Base.:+(x::Integer, y::LabelledInteger) = labelled_add(x, y) + +Base.:-(x::LabelledInteger, y::LabelledInteger) = labelled_minus(x, y) +Base.:-(x::LabelledInteger, y::Number) = labelled_minus(x, y) +Base.:-(x::Number, y::LabelledInteger) = labelled_minus(x, y) +# Fix ambiguity error with `-(::Integer, ::Integer)`. +Base.:-(x::LabelledInteger, y::Integer) = labelled_minus(x, y) +Base.:-(x::Integer, y::LabelledInteger) = labelled_minus(x, y) + +function Base.sub_with_overflow(x::LabelledInteger, y::LabelledInteger) + return labelled_binary_op(Base.sub_with_overflow, x, y) +end + Base.:*(x::LabelledInteger, y::LabelledInteger) = labelled_mul(x, y) Base.:*(x::LabelledInteger, y::Number) = labelled_mul(x, y) Base.:*(x::Number, y::LabelledInteger) = labelled_mul(x, y) diff --git a/NDTensors/src/lib/LabelledNumbers/src/labellednumber.jl b/NDTensors/src/lib/LabelledNumbers/src/labellednumber.jl index 19887c9482..09a30a456b 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labellednumber.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labellednumber.jl @@ -12,16 +12,6 @@ unlabel_type(::Type{<:LabelledNumber{Value}}) where {Value} = Value # TODO: Define `labelled_convert`. Base.convert(type::Type{<:Number}, x::LabelledNumber) = type(unlabel(x)) -# TODO: Define `labelled_promote_type`. -function Base.promote_type(type1::Type{T}, type2::Type{T}) where {T<:LabelledNumber} - return promote_type(unlabel_type(type1), unlabel_type(type2)) -end -function Base.promote_rule(type1::Type{<:LabelledNumber}, type2::Type{<:LabelledNumber}) - return promote_type(unlabel_type(type1), unlabel_type(type2)) -end -function Base.promote_rule(type1::Type{<:LabelledNumber}, type2::Type{<:Number}) - return promote_type(unlabel_type(type1), type2) -end Base.:(==)(x::LabelledNumber, y::LabelledNumber) = labelled_isequal(x, y) Base.:<(x::LabelledNumber, y::LabelledNumber) = labelled_isless(x < y) diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl index 40d655985c..0a4ffa2f5b 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl @@ -17,6 +17,12 @@ unlabel_type(::Type{<:LabelledUnitRange{Value}}) where {Value} = Value function Base.AbstractUnitRange{T}(a::LabelledUnitRange) where {T} return AbstractUnitRange{T}(unlabel(a)) end +# Used by `CartesianIndices` constructor. +# TODO: Seems to only be needed for Julia v1.6, maybe remove once we +# drop Julia v1.6 support. +function Base.OrdinalRange{T1,T2}(a::LabelledUnitRange) where {T1,T2<:Integer} + return OrdinalRange{T1,T2}(unlabel(a)) +end for f in [:first, :getindex, :last, :length, :step] @eval Base.$f(a::LabelledUnitRange, args...) = labelled($f(unlabel(a), args...), label(a)) diff --git a/NDTensors/src/lib/LabelledNumbers/test/runtests.jl b/NDTensors/src/lib/LabelledNumbers/test/runtests.jl index b77d7fd968..bfb6983e79 100644 --- a/NDTensors/src/lib/LabelledNumbers/test/runtests.jl +++ b/NDTensors/src/lib/LabelledNumbers/test/runtests.jl @@ -12,9 +12,26 @@ using Test: @test, @testset @test !islabelled(unlabel(x)) @test x * 2 == 4 - @test 2 * x == 4 @test label(x * 2) == "x" + @test 2 * x == 4 @test label(2 * x) == "x" + @test x * x == 4 + @test !islabelled(x * x) + + @test x + 3 == 5 + @test label(x + 3) == "x" + @test 3 + x == 5 + @test label(3 + x) == "x" + @test x + x == 4 + @test !islabelled(x + x) + + @test x - 3 == -1 + @test label(x - 3) == "x" + @test 3 - x == 1 + @test label(3 - x) == "x" + @test x - x == 0 + @test !islabelled(x - x) + @test x / 2 == 1 @test label(x / 2) == "x" @test x ÷ 2 == 1 From f137b8dff936faae446a0856f33d6d9aa4a49ca4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 22 Mar 2024 14:05:38 -0400 Subject: [PATCH 17/17] Fix issue in Julia v1.10 --- NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl index 0a4ffa2f5b..62a0ddebdf 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl @@ -23,6 +23,10 @@ end function Base.OrdinalRange{T1,T2}(a::LabelledUnitRange) where {T1,T2<:Integer} return OrdinalRange{T1,T2}(unlabel(a)) end +# Fix ambiguity error in Julia v1.10. +function Base.OrdinalRange{T,T}(a::LabelledUnitRange) where {T<:Integer} + return OrdinalRange{T,T}(unlabel(a)) +end for f in [:first, :getindex, :last, :length, :step] @eval Base.$f(a::LabelledUnitRange, args...) = labelled($f(unlabel(a), args...), label(a))