1
- using VectorizedRNG
2
- using LinearAlgebra: Diagonal, I
3
- using LoopVectorization
4
- using RecursiveFactorization
5
- using SparseArrays
1
+ using LinearAlgebra, . Threads
6
2
7
3
struct SparseBandedMatrix{T} <: AbstractMatrix{T}
8
4
size :: Tuple{Int, Int}
@@ -45,9 +41,9 @@ function Base.getindex(M :: SparseBandedMatrix{T}, i :: Int, j :: Int, I :: Int.
45
41
zero (T)
46
42
end
47
43
48
- function Base. setindex! (M :: SparseBandedMatrix{T} , val, i :: Int , j :: Int , I :: Int... ) where T # TODO IF VAL ISNT OF TYPE T
44
+ function Base. setindex! (M :: SparseBandedMatrix{T} , val, i :: Int , j :: Int , I :: Int... ) where T
49
45
@boundscheck checkbounds (M, i, j, I... )
50
- rows, cols = size (M)
46
+ rows = size (M, 1 )
51
47
wanted_ind = rows - i + j
52
48
ind = searchsortedfirst (M. indices, wanted_ind)
53
49
if (ind > length (M. indices) || M. indices[ind] != wanted_ind)
@@ -85,9 +81,6 @@ function Base.setindex!(M :: SparseBandedMatrix{T}, val, i :: Int, j :: Int, I :
85
81
diagvals
86
82
end
87
83
88
- using LinearAlgebra
89
- using . Threads
90
-
91
84
# C = Cb + aAB
92
85
function LinearAlgebra. mul! (C :: Matrix{T} , A:: SparseBandedMatrix{T} , B :: Matrix{T} , a :: Number , b :: Number ) where T
93
86
@assert size (A, 2 ) == size (B, 1 )
@@ -111,7 +104,7 @@ function LinearAlgebra.mul!(C :: Matrix{T}, A:: SparseBandedMatrix{T}, B :: Matr
111
104
index_j = location - cols + i
112
105
end
113
106
# A[index_i, index_j] * B[index_j, j] = C[index_i, j]
114
- @simd for j in 1 : size (B, 2 )
107
+ for j in 1 : size (B, 2 )
115
108
C[index_i, j] = fma (val, B[index_j, j], C[index_i, j])
116
109
end
117
110
end
@@ -146,6 +139,50 @@ function LinearAlgebra.mul!(C :: Matrix{T}, A:: Matrix{T}, B :: SparseBandedMatr
146
139
C
147
140
end
148
141
142
"""
    mul!(C::SparseBandedMatrix, A::SparseBandedMatrix, B::SparseBandedMatrix, a, b)

Update `C` in place to `C*b + a*A*B` and return `C`.

Each stored diagonal is identified by a "location" `rows - i + j` (the same
quantity `setindex!` searches for), so locations `<= rows` lie on or below the
main diagonal and larger locations lie above it.

The column/offset arithmetic below mixes `rows` and `cols` of the operands in
the same way as the other `mul!` methods in this file, which makes it exact
for square operands. NOTE(review): behaviour for non-square operands has not
been verified — confirm before relying on it.
"""
function LinearAlgebra.mul!(C::SparseBandedMatrix{T}, A::SparseBandedMatrix{T}, B::SparseBandedMatrix{T}, a::Number, b::Number) where T
    @assert size(A, 2) == size(B, 1)
    @assert size(A, 1) == size(C, 1)
    @assert size(B, 2) == size(C, 2)

    # Scale the existing contents of C first; products are accumulated on top.
    C .*= b

    rows_a, cols_a = size(A)
    rows_b, cols_b = size(B)
    for (ind_a, location_a) in enumerate(A.indices)
        diag_a = A.diags[ind_a]
        # This loop is deliberately serial. Every iteration stores into C via
        # `setindex!`, which searches C's shared `indices` vector and has a
        # "diagonal not present" branch; running iterations on several threads
        # would race on that shared storage.
        for i in eachindex(diag_a)
            val_a = diag_a[i] * a

            # Recover the (row, column) of the i-th entry of this diagonal of A.
            if location_a < rows_a
                index_ia = rows_a - location_a + i
                index_ja = i
            else
                index_ia = i
                index_ja = location_a - cols_a + i
            end

            # Only diagonals of B that intersect row `index_ja` can contribute;
            # their locations fall in [min_loc, max_loc].
            min_loc = rows_b - index_ja + 1
            max_loc = 2 * rows_b - index_ja
            for (ind_b, location_b) in enumerate(B.indices)
                if min_loc <= location_b <= rows_b
                    # On/below the main diagonal: the position along the
                    # diagonal equals the column index.
                    k = index_ja - rows_b + location_b
                    index_jb = k
                elseif rows_b < location_b <= max_loc
                    # Above the main diagonal: the position along the diagonal
                    # equals the row index.
                    k = index_ja
                    index_jb = location_b - cols_b + k
                else
                    continue
                end
                val_b = B.diags[ind_b][k]
                C[index_ia, index_jb] = muladd(val_a, val_b, C[index_ia, index_jb])
            end
        end
    end
    return C
end
185
+
149
186
function LinearAlgebra. mul! (C :: Matrix{T} , A:: SparseBandedMatrix{T} , B :: SparseBandedMatrix{T} , a :: Number , b :: Number ) where T
150
187
@assert size (A, 2 ) == size (B, 1 )
151
188
@assert size (A, 1 ) == size (C, 1 )
@@ -190,6 +227,12 @@ function LinearAlgebra.mul!(C :: Matrix{T}, A:: SparseBandedMatrix{T}, B :: Spar
190
227
C
191
228
end
192
229
230
+ using VectorizedRNG
231
+ using LinearAlgebra: Diagonal, I
232
+ using LoopVectorization
233
+ using RecursiveFactorization
234
+ using SparseArrays
235
+
193
236
@inline exphalf (x) = exp (x) * oftype (x, 0.5 )
194
237
function 🦋! (wv, :: Val{SEED} = Val (888 )) where {SEED}
195
238
T = eltype (wv)
@@ -207,13 +250,13 @@ function 🦋generate_random!(A, ::Val{SEED} = Val(888)) where {SEED}
207
250
end
208
251
209
252
"""
    🦋workspace(A, ::Val{SEED} = Val(888))

Build the butterfly workspace for `A` and return `(U, V, F)`.

A scratch copy `B` of `A` is transformed with `🦋mul!` using the weights
produced by `🦋generate_random!`; `U` and `V` come from `materializeUV`, and
`F` is the unpivoted LU factorization of the transformed `B`. `A` itself is
only read.

`SEED` is forwarded to `🦋generate_random!`, so a non-default seed actually
changes the generated weights.
"""
function 🦋workspace(A, ::Val{SEED} = Val(888)) where {SEED}
    B = similar(A)
    # Forward the caller's seed instead of silently using the callee's default.
    ws = 🦋generate_random!(B, Val(SEED))
    🦋mul!(copyto!(B, A), ws)
    # materializeUV must run before lu!, which overwrites B with its factors.
    U, V = materializeUV(B, ws)
    F = RecursiveFactorization.lu!(B, Val(false))

    return U, V, F
end
218
261
219
262
const butterfly_workspace = 🦋workspace;
@@ -284,30 +327,7 @@ function diagnegbottom(x)
284
327
Diagonal (y), Diagonal (z)
285
328
end
286
329
287
- # 🦋(A, B) = [A B
288
- # A -B]
289
-
290
- # Bu2 = [🦋(U₁u, U₁l) 0*I
291
- # 0*I 🦋(U₂u, U₂l)]
292
- # U1u U1l 0 0
293
- # U1u -U1l 0 0
294
- #=
295
- function 🦋!(C, A, B)
296
- A1, A2 = size(A)
297
- B1, B2 = size(B)
298
- @assert A1 == B1
299
- for j in 1 : A2, i in 1 : A1
300
- C[i, j] = A[i, j]
301
- C[i + A1, j] = A[i, j]
302
- end
303
- for j in A2 + 1 : A2 + B2, i in 1 : A1
304
- C[i, j] = B[i, j - A2]
305
- C[i + A1, j] = -B[i, j - A2]
306
- end
307
- C
308
- end
309
- =#
310
- function 🦋! (C, A:: Diagonal , B:: Diagonal )
330
+ function 🦋2 !(C, A:: Diagonal , B:: Diagonal )
311
331
@assert size (A) == size (B)
312
332
A1 = size (A, 1 )
313
333
@@ -321,19 +341,32 @@ function 🦋!(C, A::Diagonal, B::Diagonal)
321
341
C
322
342
end
323
343
324
- function 🦋! (C:: SparseBandedMatrix , A:: Diagonal , B:: Diagonal )
325
- @assert size (A) == size (B)
344
"""
    🦋!(A::Matrix, C::SparseBandedMatrix, X::Diagonal, Y::Diagonal)

Fill `C` with three diagonals built from `X` and `Y` and return `C`.

`A` is used only for its row count: when `size(A, 1)` differs from the
combined size of `X` and `Y`, the gap in the first diagonal is filled with
`rand` values placed between `X.diag` and `-Y.diag`.
"""
function 🦋!(A::Matrix, C::SparseBandedMatrix, X::Diagonal, Y::Diagonal)
    @assert size(X) == size(Y)

    # Number of rows of A not covered by X and Y together.
    gap = size(A, 1) - size(X, 1) - size(Y, 1)
    first_diag = if gap != 0
        [X.diag; rand(gap); -Y.diag]
    else
        [X.diag; -Y.diag]
    end

    # NOTE(review): the Bool flag is passed through to `setdiagonal!` as-is;
    # its meaning is defined there — confirm against that function.
    setdiagonal!(C, first_diag, true)
    setdiagonal!(C, X.diag, true)
    setdiagonal!(C, Y.diag, false)

    return C
end
359
+
360
"""
    🦋2!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal)

Fill `C` with three diagonals built from `A` and `B` and return `C`.

`A` and `B` must have the same size; the other butterfly constructors in this
file enforce the same check.
"""
function 🦋2!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal)
    # Without this check, mismatched inputs would silently produce diagonals
    # of inconsistent length.
    @assert size(A) == size(B)

    # NOTE(review): the Bool flag is passed through to `setdiagonal!` as-is;
    # its meaning is defined there — confirm against that function.
    setdiagonal!(C, [A.diag; -B.diag], true)
    setdiagonal!(C, A.diag, true)
    setdiagonal!(C, B.diag, false)
    return C
end
332
366
333
-
334
367
function materializeUV (A, (uv,))
335
368
M, N = size (A)
336
- Mh = M >>> 1
369
+ Mh = M >>> 1
337
370
Nh = N >>> 1
338
371
339
372
U₁u, U₁l = diagnegbottom (@view (uv[1 : Mh]))
@@ -346,30 +379,46 @@ function materializeUV(A, (uv,))
346
379
# WRITE OUT MERGINGS EXPLICITLY
347
380
# Bu2 = [🦋(U₁u, U₁l) 0*I
348
381
# 0*I 🦋(U₂u, U₂l)]
349
- # show size(Bu2)[1] #808
350
- # @show size(🦋(V₁u, V₁l))[1] #404
351
382
352
383
# Bu2 = spzeros(M, N)
384
+
385
+ mrng = VectorizedRNG. MutableXoshift (888 )
386
+ T = typeof (uv[1 ])
387
+
353
388
Bu2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
354
389
355
- 🦋! (view (Bu2, 1 : M ÷ 2 , 1 : N ÷ 2 ), U₁u, U₁l)
356
- 🦋! (view (Bu2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), U₂u, U₂l)
390
+ 🦋2 !(view (Bu2, 1 : (M ÷ 4 ) * 2 , 1 : (N ÷ 4 ) * 2 ), U₁u, U₁l)
391
+ 🦋2 !(view (Bu2, M - M ÷ 4 * 2 + 1 : M, N - N ÷ 4 * 2 + 1 : N), U₂u, U₂l)
392
+ rand! (mrng, diag (view (Bu2, 1 : (M ÷ 4 ) * 2 , 1 : (N ÷ 4 ) * 2 )), static (0 ), T (- 0.05 ), T (0.1 ))
393
+
357
394
358
395
# Bu1 = spzeros(M, N)
359
396
Bu1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
360
- 🦋! (Bu1, Uu, Ul)
397
+ 🦋! (A, Bu1, Uu, Ul)
361
398
362
399
# Bv2 = spzeros(M, N)
363
400
Bv2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
364
401
365
- 🦋! (view (Bv2, 1 : M ÷ 2 , 1 : N ÷ 2 ), V₁u, V₁l)
366
- 🦋! (view (Bv2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), V₂u, V₂l)
402
+ 🦋2 !(view (Bv2, 1 : (M ÷ 4 ) * 2 , 1 : (N ÷ 4 ) * 2 ), V₁u, V₁l)
403
+ 🦋2 !(view (Bv2, M - M ÷ 4 * 2 + 1 : M, N - N ÷ 4 * 2 + 1 : N), V₂u, V₂l)
404
+ rand! (mrng, diag (view (Bv2, 1 : (M ÷ 4 ) * 2 , 1 : (N ÷ 4 ) * 2 )), static (0 ), T (- 0.05 ), T (0.1 ))
367
405
368
406
# Bv1 = spzeros(M, N)
369
407
Bv1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
370
- 🦋! (Bv1, Vu, Vl)
408
+ 🦋! (A, Bv1, Vu, Vl)
371
409
372
- (Bu2 * Bu1)' , Bv2 * Bv1, A
410
+ # U = similar(A)
411
+ # U = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
412
+
413
+ # mul!(U, Bu2, Bu1, 1, 0)
414
+
415
+ # V = similar(A)
416
+ # V = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
417
+ # mul!(V, Bv2, Bv1, 1, 0)
418
+ # U = sparse(U)
419
+ # V = sparse(V)
420
+
421
+ (Bu2 * Bu1)' , Bv2 * Bv1
373
422
end
374
423
375
424
function pad! (A)
@@ -389,4 +438,85 @@ function pad!(A)
389
438
@inbounds A_new[i,j] = i == j
390
439
end
391
440
A_new
392
- end
441
+ end
442
+
443
+
444
+
445
+
446
+
447
+
448
+
449
+
450
+
451
+ #=
452
+ using SparseArrays, BenchmarkTools, Random
453
+
454
+ function get_data1()
455
+ dim = 5000
456
+ x = rand(10:75)
457
+ diag_vals = Vector{Vector{Float64}}(undef, x)
458
+ diag_locs = randperm(dim * 2 - 1)[1:x]
459
+ for j in 1:x
460
+ diag_vals[j] = rand(min(diag_locs[j], 2 * dim - diag_locs[j]))
461
+ end
462
+
463
+ x_butterfly = SparseBandedMatrix{Float64}(diag_locs, diag_vals, dim, dim)
464
+ x_dense = copy(x_butterfly)
465
+
466
+ y = rand(dim, dim)
467
+ z = zeros(dim, dim)
468
+
469
+ @show norm(x_dense*y - x_butterfly * y)
470
+
471
+ println("Timing dense multiplication.")
472
+ println("(left-side mul)")
473
+ @btime x_dense*y;
474
+ println("(right-side mul)")
475
+ @btime y*x_dense;
476
+ println("\nTiming butterfly multiplication.")
477
+ println("(left-side mul)")
478
+ @btime x_butterfly*y;
479
+ println("(right-side mul)")
480
+ @btime y*x_butterfly;
481
+
482
+ nothing
483
+ end
484
+
485
+ function get_data2()
486
+ dim = 1000
487
+ x = rand(10:40)
488
+ diag_vals = Vector{Vector{Float64}}(undef, x)
489
+ diag_locs = randperm(dim * 2 - 1)[1:x]
490
+ for j in 1:x
491
+ diag_vals[j] = rand(min(diag_locs[j], 2 * dim - diag_locs[j]))
492
+ end
493
+
494
+ x_butterfly = SparseBandedMatrix{Float64}(diag_locs, diag_vals, dim, dim)
495
+ x_dense = copy(x_butterfly)
496
+ x_sparse = sparse(x_dense)
497
+
498
+ y = rand(10:40)
499
+ diag_vals = Vector{Vector{Float64}}(undef, y)
500
+ diag_locs = randperm(dim * 2 - 1)[1:y]
501
+ for j in 1:y
502
+ diag_vals[j] = rand(min(diag_locs[j], 2 * dim - diag_locs[j]))
503
+ end
504
+
505
+ y_butterfly = SparseBandedMatrix{Float64}(diag_locs, diag_vals, dim, dim)
506
+ y_dense = copy(y_butterfly)
507
+ y_sparse = sparse(y_dense)
508
+
509
+ a = true
510
+ b = false
511
+ @assert isapprox(x_butterfly * y_butterfly, x_dense * y_dense)
512
+ println("Timing butterfly multiplication.")
513
+ @btime mul!(zeros(dim, dim), x_butterfly, y_butterfly, a, b);
514
+ println("\nTiming sparse multiplication.")
515
+ @btime mul!(zeros(dim, dim), x_sparse, y_sparse, a, b);
516
+ println("\nTiming dense multiplication.")
517
+ @btime mul!(zeros(dim, dim), x_dense, y_dense, a, b);
518
+
519
+ nothing
520
+ end
521
+ =#
522
+
0 commit comments