diff --git a/stdlib/LinearAlgebra/src/dense.jl b/stdlib/LinearAlgebra/src/dense.jl index 27fa515e70a02..94926805bb387 100644 --- a/stdlib/LinearAlgebra/src/dense.jl +++ b/stdlib/LinearAlgebra/src/dense.jl @@ -617,7 +617,6 @@ function exp!(A::StridedMatrix{T}) where T<:BlasFloat end ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A nA = opnorm(A, 1) - Inn = Matrix{T}(I, n, n) ## For sufficiently small nA, use lower order Padé-Approximations if (nA <= 2.1) if nA > 0.95 @@ -634,17 +633,21 @@ function exp!(A::StridedMatrix{T}) where T<:BlasFloat C = T[120.,60.,12.,1.] end A2 = A * A - P = copy(Inn) - U = C[2] * P - V = C[1] * P - for k in 1:(div(size(C, 1), 2) - 1) + # Compute U and V: Even/odd terms in Padé numerator & denom + # Expansion of k=1 in for loop + P = A2 + U = C[2]*I + C[4]*P + V = C[1]*I + C[3]*P + for k in 2:(div(size(C, 1), 2) - 1) k2 = 2 * k P *= A2 - U += C[k2 + 2] * P - V += C[k2 + 1] * P + mul!(U, C[k2 + 2], P, true, true) # U += C[k2+2]*P + mul!(V, C[k2 + 1], P, true, true) # V += C[k2+1]*P end + U = A * U X = V + U + # Padé approximant: (V-U)\(V+U) LAPACK.gesv!(V-U, X) else s = log2(nA/5.4) # power of 2 later reversed by squaring @@ -660,10 +663,27 @@ function exp!(A::StridedMatrix{T}) where T<:BlasFloat A2 = A * A A4 = A2 * A2 A6 = A2 * A4 - U = A * (A6 * (CC[14].*A6 .+ CC[12].*A4 .+ CC[10].*A2) .+ - CC[8].*A6 .+ CC[6].*A4 .+ CC[4].*A2 .+ CC[2].*Inn) - V = A6 * (CC[13].*A6 .+ CC[11].*A4 .+ CC[9].*A2) .+ - CC[7].*A6 .+ CC[5].*A4 .+ CC[3].*A2 .+ CC[1].*Inn + Ut = CC[4]*A2 + Ut[diagind(Ut)] .+= CC[2] + # Allocation economical version of: + #U = A * (A6 * (CC[14].*A6 .+ CC[12].*A4 .+ CC[10].*A2) .+ + # CC[8].*A6 .+ CC[6].*A4 .+ Ut) + U = mul!(CC[8].*A6 .+ CC[6].*A4 .+ Ut, + A6, + CC[14].*A6 .+ CC[12].*A4 .+ CC[10].*A2, + true, true) + U = A*U + + # Allocation economical version of: Vt = CC[3]*A2 (recycle Ut) + Vt = mul!(Ut, CC[3], A2, true, false) + Vt[diagind(Vt)] .+= CC[1] + # Allocation economical version of: + #V = A6 * (CC[13].*A6 .+ CC[11].*A4 .+ CC[9].*A2) .+ + # CC[7].*A6 .+ CC[5].*A4 .+ Vt + V = mul!(CC[7].*A6 .+ CC[5].*A4 .+ Vt, + A6, + CC[13].*A6 .+ CC[11].*A4 .+ CC[9].*A2, + true, true) X = V + U LAPACK.gesv!(V-U, X)