From ad48a28d545d520df29a4d34bdc585ff0a7e3a69 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Feb 2025 18:07:46 -0500 Subject: [PATCH 01/15] Construct BlockSparseArray when slicing with graded unit ranges --- Project.toml | 1 + .../BlockSparseArraysGradedUnitRangesExt.jl | 44 +++++++++++++++++++ src/abstractblocksparsearray/map.jl | 32 +++++++++++--- src/blocksparsearrayinterface/broadcast.jl | 42 +++++++++++++++++- 4 files changed, 112 insertions(+), 7 deletions(-) create mode 100644 ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl diff --git a/Project.toml b/Project.toml index 1c7e459c..0963a425 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [extensions] +BlockSparseArraysGradedUnitRangesExt = "GradedUnitRanges" BlockSparseArraysTensorAlgebraExt = ["LabelledNumbers", "TensorAlgebra"] [compat] diff --git a/ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl b/ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl new file mode 100644 index 00000000..4709c0ae --- /dev/null +++ b/ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl @@ -0,0 +1,44 @@ +module BlockSparseArraysGradedUnitRangesExt + +using BlockSparseArrays: BlockSparseArray +using GradedUnitRanges: AbstractGradedUnitRange + +# A block spare array similar to the input (dense) array. +# TODO: Make `BlockSparseArrays.blocksparse_similar` more general and use that, +# and also turn it into an DerivableInterfaces.jl-based interface function. +function similar_blocksparse( + a::AbstractArray, + elt::Type, + axes::Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}}, +) + # TODO: Probably need to unwrap the type of `a` in certain cases + # to make a proper block type. + return BlockSparseArray{elt,length(axes),typeof(a)}(axes) +end + +function Base.similar( + a::AbstractArray, + elt::Type, + axes::Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}}, +) + return similar_blocksparse(a, elt, axes) +end + +# Fix ambiguity error with `BlockArrays.jl`. +function Base.similar( + a::StridedArray, + elt::Type, + axes::Tuple{ + AbstractGradedUnitRange,AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange} + }, +) + return similar_blocksparse(a, elt, axes) +end + +function Base.getindex(a::AbstractArray, I::AbstractGradedUnitRange...) + a′ = similar(a, only.(axes.(I))...) + a′ .= a + return a′ +end + +end diff --git a/src/abstractblocksparsearray/map.jl b/src/abstractblocksparsearray/map.jl index fd5cec71..ad789e06 100644 --- a/src/abstractblocksparsearray/map.jl +++ b/src/abstractblocksparsearray/map.jl @@ -69,11 +69,19 @@ end @interface interface::AbstractBlockSparseArrayInterface function Base.map!( f, a_dest::AbstractArray, a_srcs::AbstractArray... ) + if isempty(a_srcs) + # Broadcast expressions of the form `a .= 2`. + error("Not implemented.") + end if iszero(ndims(a_dest)) @interface interface map_zero_dim!(f, a_dest, a_srcs...) return a_dest end + # TODO: This assumes element types are numbers, generalize this logic. + elt = promote_type(eltype.(a_srcs)...) + f_preserves_zeros = f(zero(elt)) == zero(elt) + a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs) for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...) BI_dest = blockindexrange(a_dest, I) @@ -88,11 +96,13 @@ end end subblock_dest = @view block_dest[BI_dest.indices...] subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs)) - # TODO: Use `map!!` to handle immutable blocks. - map!(f, subblock_dest, subblock_srcs...) - # Replace the entire block, handles initializing new blocks - # or if blocks are immutable. - blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] = block_dest + if f_preserves_zeros && any(!iszero, subblock_srcs) + # TODO: Use `map!!` to handle immutable blocks. + map!(f, subblock_dest, subblock_srcs...) + # Replace the entire block, handles initializing new blocks + # or if blocks are immutable. + blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] = block_dest + end end return a_dest end @@ -120,7 +130,17 @@ end end function Base.map!(f, a_dest::AbstractArray, a_srcs::AnyAbstractBlockSparseArray...) - @interface interface(a_srcs...) map!(f, a_dest, a_srcs...) + @interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...) + return a_dest +end +function Base.map!(f, a_dest::AnyAbstractBlockSparseArray, a_srcs::AbstractArray...) + @interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...) + return a_dest +end +function Base.map!( + f, a_dest::AnyAbstractBlockSparseArray, a_srcs::AnyAbstractBlockSparseArray... +) + @interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...) return a_dest end diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index bee4bda1..3031591b 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -44,6 +44,46 @@ function Base.copyto!( # convert to map # flatten and only keep the AbstractArray arguments m = Mapped(bc) - @interface interface(bc) map!(m.f, dest, m.args...) + @interface interface(dest, bc) map!(m.f, dest, m.args...) + return dest +end + +# Broadcasting implementation +# TODO: Delete this in favor of `DerivableInterfaces` version. +function Base.copyto!(dest::AnyAbstractBlockSparseArray, bc::Broadcasted) + # convert to map + # flatten and only keep the AbstractArray arguments + m = Mapped(bc) + # TODO: Include `bc` when determining interface, currently + # `interface(::Type{<:Base.Broadcast.DefaultArrayStyle})` + # isn't defined. + @interface interface(dest) map!(m.f, dest, m.args...) + return dest +end + +# Broadcasting implementation +# TODO: Delete this in favor of `DerivableInterfaces` version. +function Base.copyto!( + dest::AnyAbstractBlockSparseArray, bc::Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}} +) + # convert to map + # flatten and only keep the AbstractArray arguments + m = Mapped(bc) + # TODO: Include `bc` when determining interface, currently + # `interface(::Type{<:Base.Broadcast.DefaultArrayStyle})` + # isn't defined. + @interface interface(dest) map!(m.f, dest, m.args...) + return dest +end + +# Broadcasting implementation +# TODO: Delete this in favor of `DerivableInterfaces` version. +function Base.copyto!( + dest::AnyAbstractBlockSparseArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}} +) where {N} + # convert to map + # flatten and only keep the AbstractArray arguments + m = Mapped(bc) + @interface interface(dest, bc) map!(m.f, dest, m.args...) return dest end From 6d48e73ab1715676e9605526806fd869876f041a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Feb 2025 18:29:07 -0500 Subject: [PATCH 02/15] Make broadcasting more general --- src/abstractblocksparsearray/broadcast.jl | 20 +++++++++++ src/blocksparsearrayinterface/broadcast.jl | 41 ++-------------------- 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/src/abstractblocksparsearray/broadcast.jl b/src/abstractblocksparsearray/broadcast.jl index 96841be6..3527831c 100644 --- a/src/abstractblocksparsearray/broadcast.jl +++ b/src/abstractblocksparsearray/broadcast.jl @@ -46,3 +46,23 @@ function Broadcast.BroadcastStyle( ) return BlockSparseArrayStyle{ndims(arraytype)}() end + +# These catch cases that aren't caught by the standard +# `BlockSparseArrayStyle` definition, and also fix +# ambiguity issues. +function Base.copyto!(dest::AnyAbstractBlockSparseArray, bc::Broadcasted) + copyto_blocksparse!(dest, bc) + return dest +end +function Base.copyto!( + dest::AnyAbstractBlockSparseArray, bc::Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}} +) + copyto_blocksparse!(dest, bc) + return dest +end +function Base.copyto!( + dest::AnyAbstractBlockSparseArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}} +) where {N} + copyto_blocksparse!(dest, bc) + return dest +end diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index 3031591b..ac0a8556 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -38,9 +38,7 @@ end # Broadcasting implementation # TODO: Delete this in favor of `DerivableInterfaces` version. -function Base.copyto!( - dest::AbstractArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}} -) where {N} +function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted) # convert to map # flatten and only keep the AbstractArray arguments m = Mapped(bc) @@ -48,42 +46,9 @@ function Base.copyto!( return dest end -# Broadcasting implementation -# TODO: Delete this in favor of `DerivableInterfaces` version. -function Base.copyto!(dest::AnyAbstractBlockSparseArray, bc::Broadcasted) - # convert to map - # flatten and only keep the AbstractArray arguments - m = Mapped(bc) - # TODO: Include `bc` when determining interface, currently - # `interface(::Type{<:Base.Broadcast.DefaultArrayStyle})` - # isn't defined. - @interface interface(dest) map!(m.f, dest, m.args...) - return dest -end - -# Broadcasting implementation -# TODO: Delete this in favor of `DerivableInterfaces` version. function Base.copyto!( - dest::AnyAbstractBlockSparseArray, bc::Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}} -) - # convert to map - # flatten and only keep the AbstractArray arguments - m = Mapped(bc) - # TODO: Include `bc` when determining interface, currently - # `interface(::Type{<:Base.Broadcast.DefaultArrayStyle})` - # isn't defined. - @interface interface(dest) map!(m.f, dest, m.args...) - return dest -end - -# Broadcasting implementation -# TODO: Delete this in favor of `DerivableInterfaces` version. -function Base.copyto!( - dest::AnyAbstractBlockSparseArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}} + dest::AbstractArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}} ) where {N} - # convert to map - # flatten and only keep the AbstractArray arguments - m = Mapped(bc) - @interface interface(dest, bc) map!(m.f, dest, m.args...) + copyto_blocksparse!(dest, bc) return dest end From 09d1591700de173fef1e48a1e2ed1f66649cb8ca Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Feb 2025 21:20:38 -0500 Subject: [PATCH 03/15] Try fixing some tests --- .../BlockSparseArraysGradedUnitRangesExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl b/ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl index 4709c0ae..95672bae 100644 --- a/ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl +++ b/ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl @@ -35,7 +35,9 @@ function Base.similar( return similar_blocksparse(a, elt, axes) end -function Base.getindex(a::AbstractArray, I::AbstractGradedUnitRange...) +function Base.getindex( + a::AbstractArray, I1::AbstractGradedUnitRange, I_rest::AbstractGradedUnitRange... +) a′ = similar(a, only.(axes.(I))...) a′ .= a return a′ From 3d7f094ba14dd79a69fdcbbf9539b1ad631666dd Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Feb 2025 21:59:30 -0500 Subject: [PATCH 04/15] Fix map --- src/abstractblocksparsearray/map.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/abstractblocksparsearray/map.jl b/src/abstractblocksparsearray/map.jl index ad789e06..6cebbe9f 100644 --- a/src/abstractblocksparsearray/map.jl +++ b/src/abstractblocksparsearray/map.jl @@ -96,13 +96,18 @@ end end subblock_dest = @view block_dest[BI_dest.indices...] subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs)) - if f_preserves_zeros && any(!iszero, subblock_srcs) - # TODO: Use `map!!` to handle immutable blocks. - map!(f, subblock_dest, subblock_srcs...) - # Replace the entire block, handles initializing new blocks - # or if blocks are immutable. - blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] = block_dest + I_dest = CartesianIndex(Int.(Tuple(_block(BI_dest)))) + # If the function preserves zero values and all of the source blocks are zero, + # the output block will be zero. In that case, if the block isn't stored yet, + # don't do anything. + if f_preserves_zeros && all(iszero, subblock_srcs) && !isstored(blocks(a_dest), I_dest) + continue end + # TODO: Use `map!!` to handle immutable blocks. + map!(f, subblock_dest, subblock_srcs...) + # Replace the entire block, handles initializing new blocks + # or if blocks are immutable. + blocks(a_dest)[I_dest] = block_dest end return a_dest end From 0856230937afae67b93ea9ea2c302ca4a294414f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Feb 2025 23:30:26 -0500 Subject: [PATCH 05/15] Fix some broken tests --- src/abstractblocksparsearray/map.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/abstractblocksparsearray/map.jl b/src/abstractblocksparsearray/map.jl index 6cebbe9f..9ffff15e 100644 --- a/src/abstractblocksparsearray/map.jl +++ b/src/abstractblocksparsearray/map.jl @@ -77,11 +77,8 @@ end @interface interface map_zero_dim!(f, a_dest, a_srcs...) return a_dest end - # TODO: This assumes element types are numbers, generalize this logic. - elt = promote_type(eltype.(a_srcs)...) - f_preserves_zeros = f(zero(elt)) == zero(elt) - + f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest)) a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs) for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...) BI_dest = blockindexrange(a_dest, I) From 6da5d4fa94e3cc3a2b52f252f040bedd2abad80b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Feb 2025 00:18:39 -0500 Subject: [PATCH 06/15] Fix more tests --- .../BlockSparseArraysGradedUnitRangesExt.jl | 44 +++++++++++++++++-- src/abstractblocksparsearray/cat.jl | 7 ++- src/abstractblocksparsearray/map.jl | 3 +- 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl b/ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl index 95672bae..31a1fec3 100644 --- a/ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl +++ b/ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl @@ -1,7 +1,8 @@ module BlockSparseArraysGradedUnitRangesExt -using BlockSparseArrays: BlockSparseArray +using BlockSparseArrays: AnyAbstractBlockSparseArray, BlockSparseArray, blocktype using GradedUnitRanges: AbstractGradedUnitRange +using TypeParameterAccessors: set_ndims, unwrap_array_type # A block spare array similar to the input (dense) array. # TODO: Make `BlockSparseArrays.blocksparse_similar` more general and use that, @@ -13,7 +14,11 @@ function similar_blocksparse( ) # TODO: Probably need to unwrap the type of `a` in certain cases # to make a proper block type. - return BlockSparseArray{elt,length(axes),typeof(a)}(axes) + return BlockSparseArray{ + elt,length(axes),set_ndims(unwrap_array_type(blocktype(a)), length(axes)) + }( + axes + ) end function Base.similar( @@ -35,12 +40,43 @@ function Base.similar( return similar_blocksparse(a, elt, axes) end -function Base.getindex( - a::AbstractArray, I1::AbstractGradedUnitRange, I_rest::AbstractGradedUnitRange... +# Fix ambiguity error with `BlockSparseArrays.jl`. +function Base.similar( + a::AnyAbstractBlockSparseArray, + elt::Type, + axes::Tuple{ + AbstractGradedUnitRange,AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange} + }, ) + return similar_blocksparse(a, elt, axes) +end + +function getindex_blocksparse(a::AbstractArray, I::AbstractUnitRange...) a′ = similar(a, only.(axes.(I))...) a′ .= a return a′ end +function Base.getindex( + a::AbstractArray, I1::AbstractGradedUnitRange, I_rest::AbstractGradedUnitRange... +) + return getindex_blocksparse(a, I1, I_rest...) +end + +# Fix ambiguity errors. +function Base.getindex( + a::AnyAbstractBlockSparseArray, + I1::AbstractGradedUnitRange, + I_rest::AbstractGradedUnitRange..., +) + return getindex_blocksparse(a, I1, I_rest...) +end +function Base.getindex( + a::AnyAbstractBlockSparseArray{<:Any,2}, + I1::AbstractGradedUnitRange, + I2::AbstractGradedUnitRange, +) + return getindex_blocksparse(a, I1, I2) +end + end diff --git a/src/abstractblocksparsearray/cat.jl b/src/abstractblocksparsearray/cat.jl index eefb1a59..66626c9f 100644 --- a/src/abstractblocksparsearray/cat.jl +++ b/src/abstractblocksparsearray/cat.jl @@ -1,6 +1,9 @@ using DerivableInterfaces: @interface, interface -# TODO: Define with `@derive`. -function Base.cat(as::AnyAbstractBlockSparseArray...; dims) +function Base._cat(dims, as::AnyAbstractBlockSparseArray...) + # TODO: Call `DerivableInterfaces.cat_along(dims, as...)` instead, + # for better inferability. See: + # https://github.com/ITensor/DerivableInterfaces.jl/pull/13 + # https://github.com/ITensor/DerivableInterfaces.jl/pull/17 return @interface interface(as...) cat(as...; dims) end diff --git a/src/abstractblocksparsearray/map.jl b/src/abstractblocksparsearray/map.jl index 9ffff15e..811171e3 100644 --- a/src/abstractblocksparsearray/map.jl +++ b/src/abstractblocksparsearray/map.jl @@ -71,7 +71,8 @@ end ) if isempty(a_srcs) # Broadcast expressions of the form `a .= 2`. - error("Not implemented.") + @interface interface fill!(a_dest, f()) + return a_dest end if iszero(ndims(a_dest)) @interface interface map_zero_dim!(f, a_dest, a_srcs...) From 94240580d65ab1309b88afbf162593c5fafae261 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Feb 2025 11:46:40 -0500 Subject: [PATCH 07/15] Reorganize map code, stricter block sparse `map!` --- src/abstractblocksparsearray/map.jl | 131 --------------------------- src/blocksparsearrayinterface/map.jl | 127 ++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 131 deletions(-) diff --git a/src/abstractblocksparsearray/map.jl b/src/abstractblocksparsearray/map.jl index 811171e3..522cc67a 100644 --- a/src/abstractblocksparsearray/map.jl +++ b/src/abstractblocksparsearray/map.jl @@ -1,136 +1,5 @@ using ArrayLayouts: LayoutArray -using BlockArrays: blockisequal -using DerivableInterfaces: @interface, AbstractArrayInterface, interface -using GPUArraysCore: @allowscalar using LinearAlgebra: Adjoint, Transpose -using SparseArraysBase: SparseArraysBase, SparseArrayStyle - -# Returns `Vector{<:CartesianIndices}` -function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray}) - combined_axes = combine_axes(axes.(as)...) - stored_blocked_cartesianindices_as = map(as) do a - return blocked_cartesianindices(axes(a), combined_axes, eachblockstoredindex(a)) - end - return ∪(stored_blocked_cartesianindices_as...) -end - -# This is used by `map` to get the output axes. -# This is type piracy, try to avoid this, maybe requires defining `map`. -## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2) - -reblock(a) = a - -# If the blocking of the slice doesn't match the blocking of the -# parent array, reblock according to the blocking of the parent array. -function reblock( - a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}} -) - # TODO: This relies on the behavior that slicing a block sparse - # array with a UnitRange inherits the blocking of the underlying - # block sparse array, we might change that default behavior - # so this might become something like `@blocked parent(a)[...]`. - return @view parent(a)[UnitRange{Int}.(parentindices(a))...] -end - -function reblock( - a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}} -) - return @view parent(a)[map(I -> I.array, parentindices(a))...] -end - -function reblock( - a::SubArray{ - <:Any, - <:Any, - <:AbstractBlockSparseArray, - <:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}}, - }, -) - # Remove the blocking. - return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...] -end - -# `map!` specialized to zero-dimensional inputs. -function map_zero_dim! end - -@interface ::AbstractArrayInterface function map_zero_dim!( - f, a_dest::AbstractArray, a_srcs::AbstractArray... -) - @allowscalar a_dest[] = f.(map(a_src -> a_src[], a_srcs)...) - return a_dest -end - -# TODO: Move to `blocksparsearrayinterface/map.jl`. -# TODO: Rewrite this so that it takes the blocking structure -# made by combining the blocking of the axes (i.e. the blocking that -# is used to determine `union_stored_blocked_cartesianindices(...)`). -# `reblock` is a partial solution to that, but a bit ad-hoc. -## TODO: Make this an `@interface AbstractBlockSparseArrayInterface` function. -@interface interface::AbstractBlockSparseArrayInterface function Base.map!( - f, a_dest::AbstractArray, a_srcs::AbstractArray... -) - if isempty(a_srcs) - # Broadcast expressions of the form `a .= 2`. - @interface interface fill!(a_dest, f()) - return a_dest - end - if iszero(ndims(a_dest)) - @interface interface map_zero_dim!(f, a_dest, a_srcs...) - return a_dest - end - # TODO: This assumes element types are numbers, generalize this logic. - f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest)) - a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs) - for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...) - BI_dest = blockindexrange(a_dest, I) - BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs) - # TODO: Investigate why this doesn't work: - # block_dest = @view a_dest[_block(BI_dest)] - block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...] - # TODO: Investigate why this doesn't work: - # block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs)) - block_srcs = ntuple(length(a_srcs)) do i - return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...] - end - subblock_dest = @view block_dest[BI_dest.indices...] - subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs)) - I_dest = CartesianIndex(Int.(Tuple(_block(BI_dest)))) - # If the function preserves zero values and all of the source blocks are zero, - # the output block will be zero. In that case, if the block isn't stored yet, - # don't do anything. - if f_preserves_zeros && all(iszero, subblock_srcs) && !isstored(blocks(a_dest), I_dest) - continue - end - # TODO: Use `map!!` to handle immutable blocks. - map!(f, subblock_dest, subblock_srcs...) - # Replace the entire block, handles initializing new blocks - # or if blocks are immutable. - blocks(a_dest)[I_dest] = block_dest - end - return a_dest -end - -# TODO: Move to `blocksparsearrayinterface/map.jl`. -@interface ::AbstractBlockSparseArrayInterface function Base.mapreduce( - f, op, as::AbstractArray...; kwargs... -) - # TODO: Define an `init` value based on the element type. - return @interface interface(blocks.(as)...) mapreduce( - block -> mapreduce(f, op, block), op, blocks.(as)...; kwargs... - ) -end - -# TODO: Move to `blocksparsearrayinterface/map.jl`. -@interface ::AbstractBlockSparseArrayInterface function Base.iszero(a::AbstractArray) - # TODO: Just call `iszero(blocks(a))`? - return @interface interface(blocks(a)) iszero(blocks(a)) -end - -# TODO: Move to `blocksparsearrayinterface/map.jl`. -@interface ::AbstractBlockSparseArrayInterface function Base.isreal(a::AbstractArray) - # TODO: Just call `isreal(blocks(a))`? - return @interface interface(blocks(a)) isreal(blocks(a)) -end function Base.map!(f, a_dest::AbstractArray, a_srcs::AnyAbstractBlockSparseArray...) @interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...) diff --git a/src/blocksparsearrayinterface/map.jl b/src/blocksparsearrayinterface/map.jl index ebbfba40..92011836 100644 --- a/src/blocksparsearrayinterface/map.jl +++ b/src/blocksparsearrayinterface/map.jl @@ -1,3 +1,130 @@ +using DerivableInterfaces: @interface, AbstractArrayInterface, interface +using GPUArraysCore: @allowscalar + +# TODO: Rewrite this so that it takes the blocking structure +# made by combining the blocking of the axes (i.e. the blocking that +# is used to determine `union_stored_blocked_cartesianindices(...)`). +# `reblock` is a partial solution to that, but a bit ad-hoc. +## TODO: Make this an `@interface AbstractBlockSparseArrayInterface` function. +@interface interface::AbstractBlockSparseArrayInterface function Base.map!( + f, a_dest::AbstractArray, a_srcs::AbstractArray... +) + if isempty(a_srcs) + error("Can't call `map!` with zero source terms.") + end + if iszero(ndims(a_dest)) + @interface interface map_zero_dim!(f, a_dest, a_srcs...) + return a_dest + end + # TODO: This assumes element types are numbers, generalize this logic. + f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest)) + a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs) + for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...) + BI_dest = blockindexrange(a_dest, I) + BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs) + # TODO: Investigate why this doesn't work: + # block_dest = @view a_dest[_block(BI_dest)] + block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...] + # TODO: Investigate why this doesn't work: + # block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs)) + block_srcs = ntuple(length(a_srcs)) do i + return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...] + end + subblock_dest = @view block_dest[BI_dest.indices...] + subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs)) + I_dest = CartesianIndex(Int.(Tuple(_block(BI_dest)))) + # If the function preserves zero values and all of the source blocks are zero, + # the output block will be zero. In that case, if the block isn't stored yet, + # don't do anything. + if f_preserves_zeros && all(iszero, subblock_srcs) && !isstored(blocks(a_dest), I_dest) + continue + end + # TODO: Use `map!!` to handle immutable blocks. + map!(f, subblock_dest, subblock_srcs...) + # Replace the entire block, handles initializing new blocks + # or if blocks are immutable. + blocks(a_dest)[I_dest] = block_dest + end + return a_dest +end + +@interface ::AbstractBlockSparseArrayInterface function Base.mapreduce( + f, op, as::AbstractArray...; kwargs... +) + # TODO: Define an `init` value based on the element type. + return @interface interface(blocks.(as)...) mapreduce( + block -> mapreduce(f, op, block), op, blocks.(as)...; kwargs... + ) +end + +@interface ::AbstractBlockSparseArrayInterface function Base.iszero(a::AbstractArray) + # TODO: Just call `iszero(blocks(a))`? + return @interface interface(blocks(a)) iszero(blocks(a)) +end + +@interface ::AbstractBlockSparseArrayInterface function Base.isreal(a::AbstractArray) + # TODO: Just call `isreal(blocks(a))`? + return @interface interface(blocks(a)) isreal(blocks(a)) +end + +# Helper functions for block sparse map. + +# Returns `Vector{<:CartesianIndices}` +function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray}) + combined_axes = combine_axes(axes.(as)...) + stored_blocked_cartesianindices_as = map(as) do a + return blocked_cartesianindices(axes(a), combined_axes, eachblockstoredindex(a)) + end + return ∪(stored_blocked_cartesianindices_as...) +end + +# This is used by `map` to get the output axes. +# This is type piracy, try to avoid this, maybe requires defining `map`. +## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2) + +reblock(a) = a + +# If the blocking of the slice doesn't match the blocking of the +# parent array, reblock according to the blocking of the parent array. +function reblock( + a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}} +) + # TODO: This relies on the behavior that slicing a block sparse + # array with a UnitRange inherits the blocking of the underlying + # block sparse array, we might change that default behavior + # so this might become something like `@blocked parent(a)[...]`. + return @view parent(a)[UnitRange{Int}.(parentindices(a))...] +end + +function reblock( + a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}} +) + return @view parent(a)[map(I -> I.array, parentindices(a))...] +end + +function reblock( + a::SubArray{ + <:Any, + <:Any, + <:AbstractBlockSparseArray, + <:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}}, + }, +) + # Remove the blocking. + return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...] +end + +# `map!` specialized to zero-dimensional inputs. +function map_zero_dim! end + +@interface ::AbstractArrayInterface function map_zero_dim!( + f, a_dest::AbstractArray, a_srcs::AbstractArray... +) + @allowscalar a_dest[] = f.(map(a_src -> a_src[], a_srcs)...) + return a_dest +end + +# TODO: Decide what to do with these. function map_stored_blocks(f, a::AbstractArray) # TODO: Implement this as: # ```julia From 65d5b1a04d57dfa018ea5607212ca2b0d5eddd9d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Feb 2025 12:57:50 -0500 Subject: [PATCH 08/15] More general definition of blocktype --- .../blocksparsearrayinterface.jl | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index e52c4ac7..95350f58 100644 --- a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -40,6 +40,31 @@ function eachstoredblock(a::AbstractArray) return storedvalues(blocks(a)) end +# TODO: Generalize this, this catches simple cases +# where the more general definition isn't specific enough. +blocktype(a::Array) = typeof(a) +# TODO: Maybe unwrap SubArrays? +function blocktype(a::AbstractArray) + # TODO: Unfortunately, this doesn't always give + # a concrete type, even when it could be concrete, i.e. + #= + ```julia + julia> eltype(blocks(BlockArray(randn(2, 2), [1, 1], [1, 1]))) + Matrix{Float64} (alias for Array{Float64, 2}) + + julia> eltype(blocks(BlockedArray(randn(2, 2), [1, 1], [1, 1]))) + AbstractMatrix{Float64} (alias for AbstractArray{Float64, 2}) + + julia> eltype(blocks(randn(2, 2))) + AbstractMatrix{Float64} (alias for AbstractArray{Float64, 2}) + ``` + =# + if isempty(blocks(a)) + return eltype(blocks(a)) + end + return eltype(first(blocks(a))) +end + abstract type AbstractBlockSparseArrayInterface <: AbstractSparseArrayInterface end # TODO: Also support specifying the `blocktype` along with the `eltype`. From b17d41c10a655a62f5b8a0dbfc19da922784b49f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Feb 2025 13:01:38 -0500 Subject: [PATCH 09/15] Try fixing tests --- src/abstractblocksparsearray/map.jl | 34 ++++++++++++++++++++++++++++ src/blocksparsearrayinterface/map.jl | 30 ------------------------ 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/src/abstractblocksparsearray/map.jl b/src/abstractblocksparsearray/map.jl index 522cc67a..724b0ffc 100644 --- a/src/abstractblocksparsearray/map.jl +++ b/src/abstractblocksparsearray/map.jl @@ -1,6 +1,40 @@ using ArrayLayouts: LayoutArray +using BlockArrays: AbstractBlockVector, Block using LinearAlgebra: Adjoint, Transpose +# TODO: Make this more general, independent of `AbstractBlockSparseArray`. +# If the blocking of the slice doesn't match the blocking of the +# parent array, reblock according to the blocking of the parent array. +function reblock( + a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}} +) + # TODO: This relies on the behavior that slicing a block sparse + # array with a UnitRange inherits the blocking of the underlying + # block sparse array, we might change that default behavior + # so this might become something like `@blocked parent(a)[...]`. + return @view parent(a)[UnitRange{Int}.(parentindices(a))...] +end + +# TODO: Make this more general, independent of `AbstractBlockSparseArray`. +function reblock( + a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}} +) + return @view parent(a)[map(I -> I.array, parentindices(a))...] +end + +# TODO: Make this more general, independent of `AbstractBlockSparseArray`. +function reblock( + a::SubArray{ + <:Any, + <:Any, + <:AbstractBlockSparseArray, + <:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}}, + }, +) + # Remove the blocking. + return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...] +end + function Base.map!(f, a_dest::AbstractArray, a_srcs::AnyAbstractBlockSparseArray...) @interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...) return a_dest diff --git a/src/blocksparsearrayinterface/map.jl b/src/blocksparsearrayinterface/map.jl index 92011836..b97bff89 100644 --- a/src/blocksparsearrayinterface/map.jl +++ b/src/blocksparsearrayinterface/map.jl @@ -84,36 +84,6 @@ end reblock(a) = a -# If the blocking of the slice doesn't match the blocking of the -# parent array, reblock according to the blocking of the parent array. -function reblock( - a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}} -) - # TODO: This relies on the behavior that slicing a block sparse - # array with a UnitRange inherits the blocking of the underlying - # block sparse array, we might change that default behavior - # so this might become something like `@blocked parent(a)[...]`. - return @view parent(a)[UnitRange{Int}.(parentindices(a))...] -end - -function reblock( - a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}} -) - return @view parent(a)[map(I -> I.array, parentindices(a))...] -end - -function reblock( - a::SubArray{ - <:Any, - <:Any, - <:AbstractBlockSparseArray, - <:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}}, - }, -) - # Remove the blocking. - return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...] -end - # `map!` specialized to zero-dimensional inputs. function map_zero_dim! end From 36cd30a4df7688e56c70fe2000013208715553a1 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Feb 2025 13:03:12 -0500 Subject: [PATCH 10/15] Missing import --- src/blocksparsearrayinterface/blocksparsearrayinterface.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 95350f58..b907a6ad 100644 --- a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -11,6 +11,7 @@ using BlockArrays: BlockedVector, block, blockcheckbounds, + blockisequal, blocklengths, blocks, findblockindex From 090c202e156b7fa0408401a8c232750391ccec95 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Feb 2025 13:12:28 -0500 Subject: [PATCH 11/15] Add test for converting dense to block sparse with graded unit ranges --- test/test_gradedunitrangesext.jl | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/test/test_gradedunitrangesext.jl b/test/test_gradedunitrangesext.jl index 4822d45f..bccb692e 100644 --- a/test/test_gradedunitrangesext.jl +++ b/test/test_gradedunitrangesext.jl @@ -2,7 +2,7 @@ using Test: @test, @testset using BlockArrays: AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize -using BlockSparseArrays: BlockSparseArray, blockstoredlength +using BlockSparseArrays: BlockSparseArray, BlockSparseMatrix, blockstoredlength using GradedUnitRanges: GradedUnitRanges, GradedOneTo, @@ -318,5 +318,20 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test Array(a1' * a2) ≈ Array(a1') * Array(a2) @test Array(a1 * a2') ≈ Array(a1) * Array(a2') end + @testset "Construct from dense" begin + r = gradedrange([U1(0) => 2, U1(1) => 3]) + a1 = randn(elt, 2, 2) + a2 = randn(elt, 3, 3) + a = cat(a1, a2; dims=(1, 2)) + b = a[r, dual(r)] + @test eltype(b) === elt + @test b isa BlockSpareMatrix{elt} + @test blockstoredlength(b) == 2 + @test b[Block(1, 1)] == a1 + @test iszero(b[Block(2, 1)]) + @test iszero(b[Block(1, 2)]) + @test b[Block(2, 2)] == a2 + @test all(GradedUnitRanges.space_isequal.(axes(b), (r, dual(r)))) + end end end From a9dd3b0f6704f1d120695a803744a08b0790cfe7 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Feb 2025 13:58:48 -0500 Subject: [PATCH 12/15] Special case scalar broadcasting --- src/blocksparsearrayinterface/broadcast.jl | 8 ++++++++ test/test_gradedunitrangesext.jl | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index ac0a8556..8287b6d0 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -36,6 +36,14 @@ function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type) return similar(first(m.args), elt, combine_axes(axes.(m.args)...)) end +# Catches cases like `dest .= value` or `dest .= value1 .+ value2`. +# If the RHS is zero, this makes sure that the storage is emptied, +# which is logic that is handled by `fill!`. +function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}}) + value = bc.f(bc.args...) + return fill!(dest, value) +end + # Broadcasting implementation # TODO: Delete this in favor of `DerivableInterfaces` version. function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted) diff --git a/test/test_gradedunitrangesext.jl b/test/test_gradedunitrangesext.jl index bccb692e..67b069e2 100644 --- a/test/test_gradedunitrangesext.jl +++ b/test/test_gradedunitrangesext.jl @@ -325,7 +325,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a = cat(a1, a2; dims=(1, 2)) b = a[r, dual(r)] @test eltype(b) === elt - @test b isa BlockSpareMatrix{elt} + @test b isa BlockSparseMatrix{elt} @test blockstoredlength(b) == 2 @test b[Block(1, 1)] == a1 @test iszero(b[Block(2, 1)]) From 55ff45ae912faeb091802b21580030c48a7b2403 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Feb 2025 14:06:19 -0500 Subject: [PATCH 13/15] Try fixing tests --- src/blocksparsearrayinterface/broadcast.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index 8287b6d0..f2beee6d 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -41,7 +41,7 @@ end # which is logic that is handled by `fill!`. function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}}) value = bc.f(bc.args...) - return fill!(dest, value) + return @interface BlockSparseArrayInterface() fill!(dest, value) end # Broadcasting implementation From 4702b5707a16dcbd0152c00bc22528a10b00a22e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Feb 2025 14:14:25 -0500 Subject: [PATCH 14/15] Fix corner case in zero-dim broadcast --- src/blocksparsearrayinterface/broadcast.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index f2beee6d..f79aa433 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -40,7 +40,8 @@ end # If the RHS is zero, this makes sure that the storage is emptied, # which is logic that is handled by `fill!`. function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}}) - value = bc.f(bc.args...) + # `[]` is used to unwrap zero-dimensional arrays. + value = bc.f(bc.args...)[] return @interface BlockSparseArrayInterface() fill!(dest, value) end From 82abcf794262c1205155fc687f367645c10e0007 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Feb 2025 14:14:34 -0500 Subject: [PATCH 15/15] Fix corner case in zero-dim broadcast --- src/blocksparsearrayinterface/broadcast.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index f79aa433..57ebe783 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -1,4 +1,5 @@ using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted +using GPUArraysCore: @allowscalar using MapBroadcast: Mapped using DerivableInterfaces: DerivableInterfaces, @interface @@ -41,7 +42,7 @@ end # which is logic that is handled by `fill!`. function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}}) # `[]` is used to unwrap zero-dimensional arrays. - value = bc.f(bc.args...)[] + value = @allowscalar bc.f(bc.args...)[] return @interface BlockSparseArrayInterface() fill!(dest, value) end