Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.30"
version = "0.3.31"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down Expand Up @@ -56,7 +56,7 @@ AMDGPU = "0.9"
Accessors = "0.1.33"
Adapt = "3.7, 4"
ArrayLayouts = "1.4"
BlockArrays = "1"
BlockArrays = "1.1"
CUDA = "5"
Compat = "4.9"
Dictionaries = "0.4"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ end

function blockrange(
axis::AbstractUnitRange,
r::BlockVector{BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
)
return map(b -> Block(b), blocks(r))
end
Expand Down Expand Up @@ -271,7 +271,7 @@ end
function blockindices(
a::AbstractUnitRange,
b::Block,
r::BlockVector{BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
)
# TODO: Change to iterate over `BlockRange(r)`
# once https://github.com/JuliaArrays/BlockArrays.jl/issues/404
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,3 @@ end
function Base.view(a::BlockSparseArrayLike{<:Any,1}, index::Block{1})
return blocksparse_view(a, index)
end

function Base.view(a::BlockSparseArrayLike, indices::BlockIndexRange)
return view(view(a, block(indices)), indices.indices...)
end
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,18 @@ blocktype(a::BlockSparseArrayLike) = eltype(blocks(a))
blocktype(arraytype::Type{<:BlockSparseArrayLike}) = eltype(blockstype(arraytype))

using ArrayLayouts: ArrayLayouts
## function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, I::Vararg{Int,N}) where {N}
## return ArrayLayouts.layout_getindex(a, I...)
## end
function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, I::CartesianIndices{N}) where {N}
return ArrayLayouts.layout_getindex(a, I)
end
function Base.getindex(
a::BlockSparseArrayLike{<:Any,N}, I::Vararg{AbstractUnitRange,N}
a::BlockSparseArrayLike{<:Any,N}, I::Vararg{AbstractUnitRange{<:Integer},N}
) where {N}
return ArrayLayouts.layout_getindex(a, I...)
end
# TODO: Define `AnyBlockSparseMatrix`.
function Base.getindex(a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitRange,2})
function Base.getindex(
a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitRange{<:Integer},2}
)
return ArrayLayouts.layout_getindex(a, I...)
end

Expand Down Expand Up @@ -199,7 +198,7 @@ end

# Needed by `BlockArrays` matrix multiplication interface
function Base.similar(
arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{AbstractUnitRange}}
arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
)
return similar(arraytype, eltype(arraytype), axes)
end
Expand All @@ -210,53 +209,45 @@ end
# Delete once we drop support for older versions of Julia.
function Base.similar(
arraytype::Type{<:BlockSparseArrayLike},
axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}},
)
return similar(arraytype, eltype(arraytype), axes)
end

# Needed by `BlockArrays` matrix multiplication interface
# Fixes ambiguity error with `BlockArrays.jl`.
function Base.similar(
arraytype::Type{<:BlockSparseArrayLike},
axes::Tuple{AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}}},
axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
return similar(arraytype, eltype(arraytype), axes)
end

# Needed by `BlockArrays` matrix multiplication interface
# Fixes ambiguity error with `BlockArrays.jl`.
# Fixes ambiguity error with `BlockArrays`.
function Base.similar(
arraytype::Type{<:BlockSparseArrayLike},
axes::Tuple{
AbstractBlockedUnitRange,AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}}
},
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
return similar(arraytype, eltype(arraytype), axes)
end

# Needed by `BlockArrays` matrix multiplication interface
# Fixes ambiguity error with `BlockArrays.jl`.
# Fixes ambiguity error with `BlockArrays`.
function Base.similar(
arraytype::Type{<:BlockSparseArrayLike},
axes::Tuple{
AbstractUnitRange{Int},AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}}
AbstractUnitRange{<:Integer},
AbstractBlockedUnitRange{<:Integer},
Vararg{AbstractUnitRange{<:Integer}},
},
)
return similar(arraytype, eltype(arraytype), axes)
end

# Needed for disambiguation
function Base.similar(
arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{AbstractBlockedUnitRange}}
arraytype::Type{<:BlockSparseArrayLike},
axes::Tuple{Vararg{AbstractBlockedUnitRange{<:Integer}}},
)
return similar(arraytype, eltype(arraytype), axes)
end

# Needed by `BlockArrays` matrix multiplication interface
# TODO: Define a `blocksparse_similar` function.
function Base.similar(
arraytype::Type{<:BlockSparseArrayLike}, elt::Type, axes::Tuple{Vararg{AbstractUnitRange}}
arraytype::Type{<:BlockSparseArrayLike},
elt::Type,
axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Make generic for GPU, maybe using `blocktype`.
# TODO: For non-block axes this should output `Array`.
Expand All @@ -265,7 +256,7 @@ end

# TODO: Define a `blocksparse_similar` function.
function Base.similar(
a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{AbstractUnitRange}}
a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
)
# TODO: Make generic for GPU, maybe using `blocktype`.
# TODO: For non-block axes this should output `Array`.
Expand All @@ -277,7 +268,9 @@ end
function Base.similar(
a::BlockSparseArrayLike,
elt::Type,
axes::Tuple{AbstractBlockedUnitRange,Vararg{AbstractBlockedUnitRange}},
axes::Tuple{
AbstractBlockedUnitRange{<:Integer},Vararg{AbstractBlockedUnitRange{<:Integer}}
},
)
# TODO: Make generic for GPU, maybe using `blocktype`.
# TODO: For non-block axes this should output `Array`.
Expand All @@ -289,13 +282,37 @@ end
function Base.similar(
a::BlockSparseArrayLike,
elt::Type,
axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}},
axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Make generic for GPU, maybe using `blocktype`.
# TODO: For non-block axes this should output `Array`.
return BlockSparseArray{elt}(undef, axes)
end

# Fixes ambiguity error with `BlockArrays`.
function Base.similar(
a::BlockSparseArrayLike,
elt::Type,
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Make generic for GPU, maybe using `blocktype`.
# TODO: For non-block axes this should output `Array`.
return BlockSparseArray{elt}(undef, axes)
end

# Fixes ambiguity errors with BlockArrays.
function Base.similar(
a::BlockSparseArrayLike,
elt::Type,
axes::Tuple{
AbstractUnitRange{<:Integer},
AbstractBlockedUnitRange{<:Integer},
Vararg{AbstractUnitRange{<:Integer}},
},
)
return BlockSparseArray{elt}(undef, axes)
end

# TODO: Define a `blocksparse_similar` function.
# Fixes ambiguity error with `StaticArrays`.
function Base.similar(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,24 @@ function blocksparse_fill!(a::AbstractArray, value)
return a
end
for b in BlockRange(a)
a[b] .= value
# We can't use:
# ```julia
# a[b] .= value
# ```
# since that would lead to a stack overflow,
# because broadcasting calls `fill!`.

# TODO: Ideally we would use:
# ```julia
# @view!(a[b]) .= value
# ```
# but that doesn't work on `SubArray` right now.

# This line is needed to instantiate blocks
# that aren't instantiated yet. Maybe
# we can make this work without this line?
blocks(a)[Int.(Tuple(b))...] = blocks(a)[Int.(Tuple(b))...]
blocks(a)[Int.(Tuple(b))...] .= value
end
return a
end
Expand Down Expand Up @@ -268,6 +285,10 @@ function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) w
end
function Base.setindex!(a::SparseSubArrayBlocks{<:Any,N}, value, I::Vararg{Int,N}) where {N}
parent_blocks = view(blocks(parent(a.array)), axes(a)...)
# TODO: The following line is required to instantiate
# uninstantiated blocks, maybe use `@view!` instead,
# or some other code pattern.
parent_blocks[I...] = parent_blocks[I...]
return parent_blocks[I...][blockindices(parent(a.array), Block(I), a.array.indices)...] =
value
end
Expand Down
Loading