diff --git a/Project.toml b/Project.toml
index 3091505d7..e8d7ddde1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -120,7 +120,7 @@ PrecompileTools = "1.2"
 Preferences = "1.4"
 Random = "1.10"
 RecursiveArrayTools = "3.37"
-RecursiveFactorization = "0.2.23"
+RecursiveFactorization = "0.2.25"
 Reexport = "1.2.2"
 SafeTestsets = "0.1"
 SciMLBase = "2.70"
diff --git a/benchmarks/lu.jl b/benchmarks/lu.jl
index 896ee952e..d75e838f3 100644
--- a/benchmarks/lu.jl
+++ b/benchmarks/lu.jl
@@ -1,7 +1,8 @@
 using BenchmarkTools, Random, VectorizationBase
 using LinearAlgebra, LinearSolve, MKL_jll
+using RecursiveFactorization
+
 nc = min(Int(VectorizationBase.num_cores()), Threads.nthreads())
-BLAS.set_num_threads(nc)
 BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5
 
 function luflop(m, n = m; innerflop = 2)
@@ -24,10 +25,10 @@ algs = [
     RFLUFactorization(),
     MKLLUFactorization(),
     FastLUFactorization(),
-    SimpleLUFactorization()
+    SimpleLUFactorization(),
+    ButterflyFactorization(Val(true))
 ]
 res = [Float64[] for i in 1:length(algs)]
-
 ns = 4:8:500
 for i in 1:length(ns)
     n = ns[i]
@@ -65,3 +66,4 @@ p
 
 savefig("lubench.png")
 savefig("lubench.pdf")
+
diff --git a/ext/LinearSolveRecursiveFactorizationExt.jl b/ext/LinearSolveRecursiveFactorizationExt.jl
index 947dd8020..c3f25e3d8 100644
--- a/ext/LinearSolveRecursiveFactorizationExt.jl
+++ b/ext/LinearSolveRecursiveFactorizationExt.jl
@@ -1,7 +1,7 @@
 module LinearSolveRecursiveFactorizationExt
 
 using LinearSolve: LinearSolve, userecursivefactorization, LinearCache, @get_cacheval,
-                   RFLUFactorization, RF32MixedLUFactorization, default_alias_A,
+                   RFLUFactorization, ButterflyFactorization, RF32MixedLUFactorization, default_alias_A,
                    default_alias_b
 using LinearSolve.LinearAlgebra, LinearSolve.ArrayInterface, RecursiveFactorization
 using SciMLBase: SciMLBase, ReturnCode
@@ -19,7 +19,6 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::RFLUFactorization
     end
     fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T), check = false)
     cache.cacheval = (fact, ipiv)
-
     if !LinearAlgebra.issuccess(fact)
         return SciMLBase.build_linear_solution(
             alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
@@ -105,4 +104,28 @@ function SciMLBase.solve!(
         alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
 end
 
+function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::ButterflyFactorization;
+        kwargs...)
+    A = cache.A
+    A = convert(AbstractMatrix, A)
+    b = cache.b
+    M, N = size(A)
+    if cache.isfresh
+        @assert M==N "A must be square"
+        # Build the 🦋 workspace once and cache it so repeat solves reuse it.
+        cache.cacheval = RecursiveFactorization.🦋workspace(A, b)
+        cache.isfresh = false
+    end
+    # Read the cached workspace; a bare `ws` local here is undefined whenever `isfresh` is false.
+    ws = cache.cacheval
+    out = RecursiveFactorization.🦋lu!(ws, M, alg.thread)
+    SciMLBase.build_linear_solution(alg, out, nothing, cache)
 end
+
+function LinearSolve.init_cacheval(alg::ButterflyFactorization, A, b, u, Pl, Pr, maxiters::Int,
+        abstol, reltol, verbose::Bool, assumptions::LinearSolve.OperatorAssumptions)
+    ws = RecursiveFactorization.🦋workspace(A, b)
+end
+
+end
+
diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl
index c87dc2dc6..ebd6df78b 100644
--- a/src/LinearSolve.jl
+++ b/src/LinearSolve.jl
@@ -422,7 +422,7 @@ for kralg in (Krylov.lsmr!, Krylov.craigmr!)
end for alg in (:LUFactorization, :FastLUFactorization, :SVDFactorization, :GenericFactorization, :GenericLUFactorization, :SimpleLUFactorization, - :RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization, + :RFLUFactorization, :ButterflyFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization, :DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization, :CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization, :MKLLUFactorization, :MetalLUFactorization, :CUSOLVERRFFactorization) @@ -464,7 +464,7 @@ cudss_loaded(A) = false is_cusparse(A) = false export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization, - GenericLUFactorization, SimpleLUFactorization, RFLUFactorization, + GenericLUFactorization, SimpleLUFactorization, RFLUFactorization, ButterflyFactorization, NormalCholeskyFactorization, NormalBunchKaufmanFactorization, UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization, SparspakFactorization, DiagonalFactorization, CholeskyFactorization, diff --git a/src/extension_algs.jl b/src/extension_algs.jl index 51cdb901f..70d373ad2 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -254,6 +254,29 @@ function RFLUFactorization(; pivot = Val(true), thread = Val(true), throwerror = RFLUFactorization(pivot, thread; throwerror) end +""" +`ButterflyFactorization()` + +A fast pure Julia LU-factorization implementation +using RecursiveFactorization.jl. This method utilizes a butterly +factorization approach rather than pivoting. +""" +struct ButterflyFactorization{T} <: AbstractDenseFactorization + thread::Val{T} + function ButterflyFactorization(::Val{T}; throwerror = true) where {T} + if !userecursivefactorization(nothing) + throwerror && + error("ButterflyFactorization requires that RecursiveFactorization.jl is loaded, i.e. 
+`using RecursiveFactorization`")
+        end
+        new{T}(Val(T))
+    end
+end
+
+function ButterflyFactorization(; thread = Val(true), throwerror = true)
+    ButterflyFactorization(thread; throwerror)
+end
+
 # There's no options like pivot here.
 # But I'm not sure it makes sense as a GenericFactorization
 # since it just uses `LAPACK.getrf!`.
diff --git a/test/butterfly.jl b/test/butterfly.jl
new file mode 100644
index 000000000..9e10ae43d
--- /dev/null
+++ b/test/butterfly.jl
@@ -0,0 +1,35 @@
+using LinearAlgebra, LinearSolve
+using Test
+using RecursiveFactorization
+
+@testset "Random Matrices" begin
+    for i in 490 : 510
+        A = rand(i, i)
+        b = rand(i)
+        prob = LinearProblem(A, b)
+        x = solve(prob, ButterflyFactorization())
+        @test norm(A * x .- b) <= 1e-6
+    end
+end
+
+function wilkinson(N)
+    A = zeros(N, N)
+    A[1:(N+1):N*N] .= 1
+    A[:, end] .= 1
+    for n in 1:(N - 1)
+        for r in (n + 1):N
+            @inbounds A[r, n] = -1
+        end
+    end
+    A
+end
+
+@testset "Wilkinson" begin
+    for i in 790 : 810
+        A = wilkinson(i)
+        b = rand(i)
+        prob = LinearProblem(A, b)
+        x = solve(prob, ButterflyFactorization())
+        @test norm(A * x .- b) <= 1e-10
+    end
+end