[BlockSparseArrays] Initial support for more general blocks, such as GPU blocks #1560
Conversation
Here is a demonstration:

```julia
using BlockArrays: Block, blockedrange
using Metal: mtl
using NDTensors.BlockSparseArrays: BlockSparseArray

# Construct a block sparse matrix with two dense diagonal blocks.
function randn_block_diag(n)
  return BlockSparseArray(
    [Block(1, 1), Block(2, 2)],
    [randn(Float32, n, n), randn(Float32, n, n)],
    blockedrange.(([n, n], [n, n])),
  )
end

n = 4000
a = randn_block_diag(n)
b = randn_block_diag(n)
@time a * b

# Move the blocks to the GPU with Metal.jl and multiply there.
a_gpu = mtl(a)
b_gpu = mtl(b)
@time a_gpu * b_gpu
```

which outputs:

```
0.478664 seconds (69 allocations: 122.073 MiB, 4.58% gc time)
0.000443 seconds (442 allocations: 13.984 KiB)
```
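As a sanity check on the demo above, one could compare the GPU result against the CPU result by adapting the blocks back to `Array`. This is a hedged sketch: it assumes the Adapt.jl support for block sparse arrays discussed in this PR, i.e. that `adapt` maps over the stored blocks.

```julia
using Adapt: adapt

c = a * b              # CPU result
c_gpu = a_gpu * b_gpu  # Metal result

# Move the GPU blocks back to host memory and compare.
adapt(Array, c_gpu) ≈ c
```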
Some basic operations also work with distributed blocks (using Dagger.jl):

```julia
using Distributed: @everywhere, addprocs
addprocs(2)

using Adapt: Adapt, adapt
using BlockArrays: Block, blockedrange
@everywhere using Dagger: Dagger, AutoBlocks, DArray, distribute
using NDTensors.BlockSparseArrays: BlockSparseArray, BlockZero, block_size

# Distribute a block sparse array by mapping `distribute` over its blocks.
struct DArrayAdaptor end
function Adapt.adapt_storage(::DArrayAdaptor, a::AbstractArray)
  return distribute(a)
end
function Dagger.distribute(a::BlockSparseArray)
  return adapt(DArrayAdaptor(), a)
end

function randn_block_diag(n)
  return BlockSparseArray(
    [Block(1, 1), Block(2, 2)],
    [randn(Float32, n, n), randn(Float32, n, n)],
    blockedrange.(([n, n], [n, n])),
  )
end

# Define how unallocated output blocks are constructed during matrix
# multiplication when the blocks are `DArray`s.
function (f::BlockZero)(arraytype::Type{<:DArray}, I::CartesianIndex)
  blck_size = block_size(f.axes, Block(Tuple(I)))
  return zeros(AutoBlocks(), eltype(arraytype), blck_size...)
end

n = 4
a = randn_block_diag(n)
b = randn_block_diag(n)
c = a * b

a_d = distribute(a)
b_d = distribute(b)
c_d = a_d * b_d
c ≈ c_d
```
I've converted some of the BlockSparseArray tests to also (optionally) run on GPU backends, and on CPU to run with the JLArray backend. It caught a few scalar indexing bugs which we can investigate in future PRs (@kmp5VT). But basic slicing, scalar multiplication, addition, and matrix multiplication operations work on GPU. I'll merge this once tests pass, and it can be used as a starting point for future work.
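As a rough illustration of the kind of operations those tests exercise, here is a hedged sketch using JLArray blocks on CPU. The helper `randn_block_diag_jl` is just an adaptation of the constructor pattern from the demos above, and whether each of these operations works unchanged with JLArray blocks is exactly what the converted tests check.

```julia
using BlockArrays: Block, blockedrange
using JLArrays: JLArray
using NDTensors.BlockSparseArrays: BlockSparseArray

# Block-diagonal matrix whose stored blocks are JLArrays, a CPU-backed
# array type that mimics a GPU backend and is handy for testing.
function randn_block_diag_jl(n)
  return BlockSparseArray(
    [Block(1, 1), Block(2, 2)],
    [JLArray(randn(Float32, n, n)), JLArray(randn(Float32, n, n))],
    blockedrange.(([n, n], [n, n])),
  )
end

n = 4
a = randn_block_diag_jl(n)
b = randn_block_diag_jl(n)

# The kinds of operations mentioned above: block slicing, scalar
# multiplication, addition, and matrix multiplication.
a[Block(1, 1)]
2 * a
a + b
a * b
```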
This fixes broken functionality for BlockSparseArrays that have blocks that aren't just Array, such as blocks that are GPU arrays. Before this PR, the library supported constructing block sparse arrays with more general blocks, but functionality like adding or multiplying them was broken or implicitly moved data to CPU.

To-do:
- Adapt.jl overloads for AbstractBlockSparseArray in terms of mapping adapt over nonzero/stored blocks (a toy sketch of this pattern is included at the end of this description).
- similartype to TypeParameterAccessors (completed in [TypeParameterAccessors] similartype #1561).

Future work:
- BlockSparseArray with Diagonal blocks on GPU.
- BlockSparseMatrix with Diagonal blocks on GPU; some operations are broken for blocks that are Diagonal (for example, accessing non-allocated blocks is currently broken).

@kmp5VT this should also help with making block sparse arrays that have distributed blocks, though I haven't tested that. But this PR should also give you some guidance on where you might look to fix issues that come up with that, such as where in the code the output of matrix multiplication is defined, so that it can be customized if needed.
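To make the Adapt.jl to-do item above more concrete, here is a toy sketch of the "map `adapt` over the stored blocks" pattern. The `ToyBlockSparse` type is purely illustrative and is not the BlockSparseArrays implementation or API.

```julia
using Adapt: Adapt, adapt

# A minimal stand-in for a block sparse container: only the stored
# (nonzero) blocks are kept, keyed by their block position.
struct ToyBlockSparse{B}
  blocks::Dict{Tuple{Int,Int},B}
end

# Adapting the container just adapts each stored block, so calls like
# `adapt(Array, a)`, `mtl(a)`, or a `distribute`-style adaptor only touch
# the allocated blocks and preserve the sparsity structure.
function Adapt.adapt_structure(to, a::ToyBlockSparse)
  return ToyBlockSparse(Dict(I => adapt(to, b) for (I, b) in a.blocks))
end
```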