Commit d01c432

edits
1 parent 82cd318 commit d01c432

2 files changed: +211 −53 lines

src/butterflylu.jl

Lines changed: 182 additions & 52 deletions
@@ -1,8 +1,4 @@
-using VectorizedRNG
-using LinearAlgebra: Diagonal, I
-using LoopVectorization
-using RecursiveFactorization
-using SparseArrays
+using LinearAlgebra, .Threads
 
 struct SparseBandedMatrix{T} <: AbstractMatrix{T}
     size :: Tuple{Int, Int}
@@ -45,9 +41,9 @@ function Base.getindex(M :: SparseBandedMatrix{T}, i :: Int, j :: Int, I :: Int.
     zero(T)
 end
 
-function Base.setindex!(M :: SparseBandedMatrix{T}, val, i :: Int, j :: Int, I :: Int...) where T #TODO IF VAL ISNT OF TYPE T
+function Base.setindex!(M :: SparseBandedMatrix{T}, val, i :: Int, j :: Int, I :: Int...) where T
     @boundscheck checkbounds(M, i, j, I...)
-    rows, cols = size(M)
+    rows = size(M, 1)
     wanted_ind = rows - i + j
     ind = searchsortedfirst(M.indices, wanted_ind)
     if (ind > length(M.indices) || M.indices[ind] != wanted_ind)
@@ -85,9 +81,6 @@ function Base.setindex!(M :: SparseBandedMatrix{T}, val, i :: Int, j :: Int, I :
     diagvals
 end
 
-using LinearAlgebra
-using .Threads
-
 # C = Cb + aAB
 function LinearAlgebra.mul!(C :: Matrix{T}, A:: SparseBandedMatrix{T}, B :: Matrix{T}, a :: Number, b :: Number) where T
     @assert size(A, 2) == size(B, 1)
@@ -111,7 +104,7 @@ function LinearAlgebra.mul!(C :: Matrix{T}, A:: SparseBandedMatrix{T}, B :: Matr
                 index_j = location - cols + i
             end
             #A[index_i, index_j] * B[index_j, j] = C[index_i, j]
-            @simd for j in 1 : size(B, 2)
+            for j in 1 : size(B, 2)
                 C[index_i, j] = fma(val, B[index_j, j], C[index_i, j])
             end
         end
@@ -146,6 +139,50 @@ function LinearAlgebra.mul!(C :: Matrix{T}, A:: Matrix{T}, B :: SparseBandedMatr
     C
 end
 
+function LinearAlgebra.mul!(C :: SparseBandedMatrix{T}, A:: SparseBandedMatrix{T}, B :: SparseBandedMatrix{T}, a :: Number, b :: Number) where T
+    @assert size(A, 2) == size(B, 1)
+    @assert size(A, 1) == size(C, 1)
+    @assert size(B, 2) == size(C, 2)
+
+    C .*= b
+
+    rows_a, cols_a = size(A)
+    rows_b, cols_b = size(B)
+    @inbounds for (ind_a, location_a) in enumerate(A.indices)
+        @threads for i in eachindex(A.diags[ind_a])
+            val_a = A.diags[ind_a][i] * a
+            if location_a < rows_a
+                index_ia = rows_a - location_a + i
+                index_ja = i
+            else
+                index_ia = i
+                index_ja = location_a - cols_a + i
+            end
+            min_loc = rows_b - index_ja + 1
+            max_loc = 2 * rows_b - index_ja
+            for (ind_b, location_b) in enumerate(B.indices)
+                #index_ib = index_ja
+                # if ind < rows(A), then index = (rows - loc + i, i)
+                #rows - loc + j = index_ja, j = index_ja - rows + loc
+                # else index = (i, loc - cols + i)
+                # if location < rows(B), then
+                if location_b <= rows_b && location_b >= min_loc
+                    j = index_ja - rows_b + location_b
+                    index_jb = j
+                    val_b = B.diags[ind_b][j]
+                    C[index_ia, index_jb] = muladd(val_a, val_b, C[index_ia, index_jb])
+                elseif location_b > rows_b && location_b <= max_loc
+                    j = index_ja
+                    index_jb = location_b - cols_b + j
+                    val_b = B.diags[ind_b][j]
+                    C[index_ia, index_jb] = muladd(val_a, val_b, C[index_ia, index_jb])
+                end
+            end
+        end
+    end
+    C
+end
+
 function LinearAlgebra.mul!(C :: Matrix{T}, A:: SparseBandedMatrix{T}, B :: SparseBandedMatrix{T}, a :: Number, b :: Number) where T
     @assert size(A, 2) == size(B, 1)
     @assert size(A, 1) == size(C, 1)
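Note on the new SparseBandedMatrix-into-SparseBandedMatrix method above: it scales C by b, then walks the stored diagonals of A and B and accumulates a*A[i,k]*B[k,j] only where both factors have stored bands. A minimal usage sketch, not from the source: it assumes the (diag_locs, diag_vals, rows, cols) constructor used in the commented benchmark at the bottom of this file, and that setindex! inserts missing diagonals as the method near the top of the file suggests.

    dim = 8
    A = SparseBandedMatrix{Float64}([dim], [rand(dim)], dim, dim)              # main diagonal only
    B = SparseBandedMatrix{Float64}([dim - 2], [rand(dim - 2)], dim, dim)      # one lower band
    C = SparseBandedMatrix{Float64}([dim - 2], [zeros(dim - 2)], dim, dim)     # storage for the product's band
    mul!(C, A, B, true, false)               # C = A * B, dispatching to the method above
    Matrix(C) ≈ Matrix(A) * Matrix(B)        # expected to hold up to rounding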
@@ -190,6 +227,12 @@ function LinearAlgebra.mul!(C :: Matrix{T}, A:: SparseBandedMatrix{T}, B :: Spar
     C
 end
 
+using VectorizedRNG
+using LinearAlgebra: Diagonal, I
+using LoopVectorization
+using RecursiveFactorization
+using SparseArrays
+
 @inline exphalf(x) = exp(x) * oftype(x, 0.5)
 function 🦋!(wv, ::Val{SEED} = Val(888)) where {SEED}
     T = eltype(wv)
@@ -207,13 +250,13 @@ function 🦋generate_random!(A, ::Val{SEED} = Val(888)) where {SEED}
 end
 
 function 🦋workspace(A, ::Val{SEED} = Val(888)) where {SEED}
-    A = pad!(A)
     B = similar(A);
     ws = 🦋generate_random!(B)
     🦋mul!(copyto!(B, A), ws)
-    U, V, B = materializeUV(B, ws)
+    U, V = materializeUV(B, ws)
     F = RecursiveFactorization.lu!(B, Val(false))
-    A, U, V, F
+
+    U, V, F
 end
 
 const butterfly_workspace = 🦋workspace;
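For orientation, the updated return convention drops the padded copy of A: the caller now gets only the two butterfly factors and the unpivoted LU of the transformed matrix. A minimal solve sketch, mirroring the new test in test/runtests.jl (A is a square matrix whose size satisfies the routine's blocking assumptions, b a matching right-hand side):

    U, V, F = butterfly_workspace(A)    # alias for 🦋workspace
    x = V * (F \ (U * b))               # approximately solves A * x = b
    norm(A * x .- b)                    # expected to be small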
@@ -284,30 +327,7 @@ function diagnegbottom(x)
     Diagonal(y), Diagonal(z)
 end
 
-#🦋(A, B) = [A B
-# A -B]
-
-#Bu2 = [🦋(U₁u, U₁l) 0*I
-# 0*I 🦋(U₂u, U₂l)]
-# U1u U1l 0 0
-# U1u -U1l 0 0
-#=
-function 🦋!(C, A, B)
-    A1, A2 = size(A)
-    B1, B2 = size(B)
-    @assert A1 == B1
-    for j in 1 : A2, i in 1 : A1
-        C[i, j] = A[i, j]
-        C[i + A1, j] = A[i, j]
-    end
-    for j in A2 + 1 : A2 + B2, i in 1 : A1
-        C[i, j] = B[i, j - A2]
-        C[i + A1, j] = -B[i, j - A2]
-    end
-    C
-end
-=#
-function 🦋!(C, A::Diagonal, B::Diagonal)
+function 🦋2!(C, A::Diagonal, B::Diagonal)
     @assert size(A) == size(B)
     A1 = size(A, 1)
 
@@ -321,19 +341,32 @@ function 🦋!(C, A::Diagonal, B::Diagonal)
     C
 end
 
-function 🦋!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal)
-    @assert size(A) == size(B)
+function 🦋!(A::Matrix, C::SparseBandedMatrix, X::Diagonal, Y::Diagonal)
+    @assert size(X) == size(Y)
+    if (size(X, 1) + size(Y, 1) != size(A, 1))
+        x = size(A, 1) - size(X, 1) - size(Y, 1)
+        setdiagonal!(C, [X.diag; rand(x); -Y.diag], true)
+        setdiagonal!(C, X.diag, true)
+        setdiagonal!(C, Y.diag, false)
+    else
+        setdiagonal!(C, [X.diag; -Y.diag], true)
+        setdiagonal!(C, X.diag, true)
+        setdiagonal!(C, Y.diag, false)
+    end
 
+    C
+end
+
+function 🦋2!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal)
     setdiagonal!(C, [A.diag; -B.diag], true)
     setdiagonal!(C, A.diag, true)
     setdiagonal!(C, B.diag, false)
     C
 end
 
-
 function materializeUV(A, (uv,))
     M, N = size(A)
-    Mh = M >>> 1
+    Mh = M >>> 1
     Nh = N >>> 1
 
     U₁u, U₁l = diagnegbottom(@view(uv[1:Mh]))
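Regarding 🦋2! just above: under the natural reading of setdiagonal!, its three calls lay down the [A B; A -B] butterfly block that the comments removed in the previous hunk describe. A small worked illustration (not from the source): for A = Diagonal([a₁, a₂]) and B = Diagonal([b₁, b₂]) the 4×4 result is

    a₁  0    b₁   0
    0   a₂   0    b₂
    a₁  0   -b₁   0
    0   a₂   0   -b₂

i.e. the main diagonal stores [a₁, a₂, -b₁, -b₂], the lower band stores A.diag, and the upper band stores B.diag.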
@@ -346,30 +379,46 @@ function materializeUV(A, (uv,))
     #WRITE OUT MERGINGS EXPLICITLY
     #Bu2 = [🦋(U₁u, U₁l) 0*I
     # 0*I 🦋(U₂u, U₂l)]
-    #show size(Bu2)[1] #808
-    #@show size(🦋(V₁u, V₁l))[1] #404
 
     #Bu2 = spzeros(M, N)
+
+    mrng = VectorizedRNG.MutableXoshift(888)
+    T = typeof(uv[1])
+
     Bu2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
 
-    🦋!(view(Bu2, 1 : M ÷ 2, 1 : N ÷ 2), U₁u, U₁l)
-    🦋!(view(Bu2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), U₂u, U₂l)
+    🦋2!(view(Bu2, 1 : (M ÷ 4) * 2, 1 : (N ÷ 4) * 2), U₁u, U₁l)
+    🦋2!(view(Bu2, M - M ÷ 4 * 2 + 1: M, N - N ÷ 4 * 2 + 1: N), U₂u, U₂l)
+    rand!(mrng, diag(view(Bu2, 1 : (M ÷ 4) * 2, 1 : (N ÷ 4) * 2)), static(0), T(-0.05), T(0.1))
+
 
     #Bu1 = spzeros(M, N)
     Bu1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
-    🦋!(Bu1, Uu, Ul)
+    🦋!(A, Bu1, Uu, Ul)
 
     #Bv2 = spzeros(M, N)
     Bv2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
 
-    🦋!(view(Bv2, 1 : M ÷ 2, 1 : N ÷ 2), V₁u, V₁l)
-    🦋!(view(Bv2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), V₂u, V₂l)
+    🦋2!(view(Bv2, 1 : (M ÷ 4) * 2, 1 : (N ÷ 4) * 2), V₁u, V₁l)
+    🦋2!(view(Bv2, M - M ÷ 4 * 2 + 1: M, N - N ÷ 4 * 2 + 1: N), V₂u, V₂l)
+    rand!(mrng, diag(view(Bv2, 1 : (M ÷ 4) * 2, 1 : (N ÷ 4) * 2)), static(0), T(-0.05), T(0.1))
 
     #Bv1 = spzeros(M, N)
     Bv1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
-    🦋!(Bv1, Vu, Vl)
+    🦋!(A, Bv1, Vu, Vl)
 
-    (Bu2 * Bu1)', Bv2 * Bv1, A
+    #U = similar(A)
+    #U = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
+
+    #mul!(U, Bu2, Bu1, 1, 0)
+
+    #V = similar(A)
+    #V = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
+    #mul!(V, Bv2, Bv1, 1, 0)
+    #U = sparse(U)
+    #V = sparse(V)
+
+    (Bu2 * Bu1)', Bv2 * Bv1
 end
 
 function pad!(A)
@@ -389,4 +438,85 @@ function pad!(A)
         @inbounds A_new[i,j] = i == j
     end
     A_new
-end
+end
+
+
+
+
+
+
+
+
+
+#=
+using SparseArrays, BenchmarkTools, Random
+
+function get_data1()
+    dim = 5000
+    x = rand(10:75)
+    diag_vals = Vector{Vector{Float64}}(undef, x)
+    diag_locs = randperm(dim * 2 - 1)[1:x]
+    for j in 1:x
+        diag_vals[j] = rand(min(diag_locs[j], 2 * dim - diag_locs[j]))
+    end
+
+    x_butterfly = SparseBandedMatrix{Float64}(diag_locs, diag_vals, dim, dim)
+    x_dense = copy(x_butterfly)
+
+    y = rand(dim, dim)
+    z = zeros(dim, dim)
+
+    @show norm(x_dense*y - x_butterfly * y)
+
+    println("Timing dense multiplication.")
+    println("(left-side mul)")
+    @btime x_dense*y;
+    println("(right-side mul)")
+    @btime y*x_dense;
+    println("\nTiming butterfly multiplication.")
+    println("(left-side mul)")
+    @btime x_butterfly*y;
+    println("(right-side mul)")
+    @btime y*x_butterfly;
+
+    nothing
+end
+
+function get_data2()
+    dim = 1000
+    x = rand(10:40)
+    diag_vals = Vector{Vector{Float64}}(undef, x)
+    diag_locs = randperm(dim * 2 - 1)[1:x]
+    for j in 1:x
+        diag_vals[j] = rand(min(diag_locs[j], 2 * dim - diag_locs[j]))
+    end
+
+    x_butterfly = SparseBandedMatrix{Float64}(diag_locs, diag_vals, dim, dim)
+    x_dense = copy(x_butterfly)
+    x_sparse = sparse(x_dense)
+
+    y = rand(10:40)
+    diag_vals = Vector{Vector{Float64}}(undef, y)
+    diag_locs = randperm(dim * 2 - 1)[1:y]
+    for j in 1:y
+        diag_vals[j] = rand(min(diag_locs[j], 2 * dim - diag_locs[j]))
+    end
+
+    y_butterfly = SparseBandedMatrix{Float64}(diag_locs, diag_vals, dim, dim)
+    y_dense = copy(y_butterfly)
+    y_sparse = sparse(y_dense)
+
+    a = true
+    b = false
+    @assert isapprox(x_butterfly * y_butterfly, x_dense * y_dense)
+    println("Timing butterfly multiplication.")
+    @btime mul!(zeros(dim, dim), x_butterfly, y_butterfly, a, b);
+    println("\nTiming sparse multiplication.")
+    @btime mul!(zeros(dim, dim), x_sparse, y_sparse, a, b);
+    println("\nTiming dense multiplication.")
+    @btime mul!(zeros(dim, dim), x_dense, y_dense, a, b);
+
+    nothing
+end
+=#
+
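The benchmark harness appended above stays commented out in this commit. A hedged sketch of how it might be driven from a fresh session once uncommented (the package list comes from the block itself; the include path is this file):

    using LinearAlgebra, SparseArrays, BenchmarkTools, Random
    include("src/butterflylu.jl")
    get_data1()   # dense vs. SparseBandedMatrix left/right products, dim = 5000
    get_data2()   # five-argument mul!: banded vs. SparseArrays vs. dense, dim = 1000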

test/runtests.jl

Lines changed: 29 additions & 1 deletion
@@ -88,4 +88,32 @@ end
     x = V * (F \ (U * b_ext))
     @test norm(A * x[1:M] .- b) <= 1e-10
 end
-end
+end
+
+i = 800
+A = wilkinson(i)
+b = rand(i)
+U, V, F = RecursiveFactorization.🦋workspace(A)
+x = V * (F \ (U * b))
+@test norm(A * x .- b) <= 1e-10
+
+#=
+
+i = 8
+A = rand(i, i)
+b = rand(i)
+A_ext, U, V, F = RecursiveFactorization.🦋workspace(A)
+tmp = U * b
+F \ tmp
+(F.L * F.U) \ tmp
+#tmp * F.U ^ -1 * F.L^-1
+F.U \ (tmp / F.L)
+isapprox(F \ tmp, F.U \ (F.L \ tmp))
+
+x = V * (F \ (U * b))
+
+x = V * (TriangularSolve.ldiv!(UpperTriangular(F.U), TriangularSolve.ldiv!(LowerTriangular(F.L), U * b)))
+
+norm(V * (TriangularSolve.ldiv!(UpperTriangular(F.U), TriangularSolve.ldiv!(LowerTriangular(F.L), U * b)))- V*(F\(U*b)))
+# goal: solve F.L * F.U * x = b
+=#
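For the record, the reasoning the new test relies on, stated conceptually rather than taken from the source: 🦋workspace copies A, applies the random butterfly transform on both sides, so the factored matrix is B = Ωᵤ' A Ωᵥ with Ωᵤ = Bu2 * Bu1 and Ωᵥ = Bv2 * Bv1 from materializeUV, and F = lu!(B, Val(false)) is computed without pivoting. Then A x = b is equivalent to B (Ωᵥ⁻¹ x) = Ωᵤ' b, so x = Ωᵥ (F \ (Ωᵤ' b)), which is exactly x = V * (F \ (U * b)) with the returned U = (Bu2 * Bu1)' and V = Bv2 * Bv1.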
