Skip to content

Commit

Permalink
some progress with factorisations and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Nov 25, 2024
1 parent 3d66e5d commit 2d5b83a
Show file tree
Hide file tree
Showing 13 changed files with 582 additions and 155 deletions.
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@ version = "0.1.0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
Aqua = "0.6, 0.7, 0.8"
JET = "0.9"
LinearAlgebra = "1"
Test = "1"
TestExtras = "0.2,0.3"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[targets]
test = ["Aqua", "JET", "Test"]
test = ["Aqua", "JET", "Test", "TestExtras"]
5 changes: 5 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@ module MatrixAlgebraKit
using LinearAlgebra: LinearAlgebra
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, triu!

export qr_compact!, qr_full!
export eigh_full!, eigh_vals!, eigh_trunc!
export svd_compact!, svd_full!, svd_vals!, svd_trunc!

include("auxiliary.jl")
include("backend.jl")
include("yalapack.jl")
include("qr.jl")
include("svd.jl")
include("eigh.jl")

Expand Down
2 changes: 1 addition & 1 deletion src/backend.jl
Original file line number Diff line number Diff line change
@@ -1 +1 @@
struct LAPACKBackend end
struct LAPACKBackend end
63 changes: 49 additions & 14 deletions src/eigh.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# `eigh!`` is a simple wrapper for `eigh_full!`
function eigh!(A::AbstractMatrix,
D::AbstractVector=similar(A, real(eltype(A)), size(A, 1)),
V::AbstractMatrix=similar(A, size(A));
kwargs...)
return eigh_full!(A, D, V; kwargs...)
# TODO: do not export but mark as public ?
function eigh!(A::AbstractMatrix, args...; kwargs...)
return eigh_full!(A, args...; kwargs...)
end

function eigh_full!(A::AbstractMatrix,
Expand All @@ -12,6 +9,11 @@ function eigh_full!(A::AbstractMatrix,
kwargs...)
return eigh_full!(A, D, V, default_backend(eigh_full!, A; kwargs...); kwargs...)
end
function eigh_vals!(A::AbstractMatrix,
D::AbstractVector=similar(A, real(eltype(A)), size(A, 1));
kwargs...)
return eigh_vals!(A, D, default_backend(eigh_vals!, A; kwargs...); kwargs...)
end
function eigh_trunc!(A::AbstractMatrix;
kwargs...)
return eigh_trunc!(A, default_backend(eigh_trunc!, A; kwargs...); kwargs...)
Expand All @@ -20,6 +22,9 @@ end
function default_backend(::typeof(eigh_full!), A::AbstractMatrix; kwargs...)
return default_eigh_backend(A; kwargs...)
end
function default_backend(::typeof(eigh_vals!), A::AbstractMatrix; kwargs...)
return default_eigh_backend(A; kwargs...)
end
function default_backend(::typeof(eigh_trunc!), A::AbstractMatrix; kwargs...)
return default_eigh_backend(A; kwargs...)
end
Expand All @@ -37,6 +42,13 @@ function check_eigh_full_input(A, D, V)
throw(DimensionMismatch("Eigenvector matrix `V` must have size equal to A"))
return nothing
end
function check_eigh_vals_input(A, D)
m, n = size(A)
m == n || throw(ArgumentError("Eigenvalue decompsition requires square matrix"))
size(D) == (n,) ||
throw(DimensionMismatch("Eigenvalue vector `D` must have length equal to size(A, 1)"))
return nothing
end

@static if VERSION >= v"1.12-DEV.0"
const RobustRepresentations = LinearAlgebra.RobustRepresentations
Expand All @@ -58,17 +70,37 @@ function eigh_full!(A::AbstractMatrix,
elseif alg == LinearAlgebra.QRIteration()
YALAPACK.heev!(A, D, V; kwargs...)
else
throw(ArgumentError("Unknown algorithm $alg"))
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
end
return D, V
end

# for eigh_trunc!, it doesn't make sense to preallocate U, S, Vᴴ as we don't know their sizes
function eigh_vals!(A::AbstractMatrix,
D::AbstractVector,
backend::LAPACKBackend;
alg=RobustRepresentations(),
kwargs...)
check_eigh_vals_input(A, D)
V = similar(A, (size(A, 1), 0))
if alg == RobustRepresentations()
YALAPACK.heevr!(A, D, V; kwargs...)
elseif alg == LinearAlgebra.DivideAndConquer()
YALAPACK.heevd!(A, D, V; kwargs...)
elseif alg == LinearAlgebra.QRIteration()
YALAPACK.heev!(A, D, V; kwargs...)
else
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
end
return D
end

# for eigh_trunc!, it doesn't make sense to preallocate D and V as we don't know their sizes
function eigh_trunc!(A::AbstractMatrix,
backend::LAPACKBackend;
alg=RobustRepresentations(),
tol=zero(real(eltype(A))),
rank=min(size(A)...),
atol=zero(real(eltype(A))),
rtol=zero(real(eltype(A))),
rank=size(A, 1),
kwargs...)
if alg == RobustRepresentations()
D, V = YALAPACK.heevr!(A; kwargs...)
Expand All @@ -77,12 +109,15 @@ function eigh_trunc!(A::AbstractMatrix,
elseif alg == LinearAlgebra.QRIteration()
D, V = YALAPACK.heev!(A; kwargs...)
else
throw(ArgumentError("Unknown algorithm $alg"))
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
end
# eigenvalues are sorted in ascending order; do we assume that they are positive?
# eigenvalues are sorted in ascending order
# TODO: do we assume that they are positive, or should we check for this?
# or do we want to truncate based on absolute value and thus sort differently?
n = length(D)
s = max(n - rank, findfirst(>=(tol * D[end]), S))
tol = convert(eltype(D), max(atol, rtol * D[n]))
s = max(n - rank + 1, findfirst(>=(tol), D))
# TODO: do we want views here, such that we do not need extra allocations if we later
# copy them into other storage
return D[n:-1:s], V[:, n:-1:s]
end
end
1 change: 1 addition & 0 deletions src/matrixfunctions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

115 changes: 112 additions & 3 deletions src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ function qr_full!(A::AbstractMatrix,
Q::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
R::AbstractMatrix=similar(A, (size(A, 1), size(A, 2)));
kwargs...)
return qr_full!(A, Q, R, default_backend(qr_full!, A; kwargs...))
return qr_full!(A, Q, R, default_backend(qr_full!, A; kwargs...); kwargs...)
end
function qr_compact!(A::AbstractMatrix,
Q::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
R::AbstractMatrix=similar(A, (size(A, 1), size(A, 2)));
kwargs...)
return qr_compact!(A, Q, R, default_backend(qr_compact!, A; kwargs...))
return qr_compact!(A, Q, R, default_backend(qr_compact!, A; kwargs...); kwargs...)
end

function default_backend(::typeof(qr_full!), A::AbstractMatrix; kwargs...)
Expand All @@ -20,4 +20,113 @@ end

function default_qr_backend(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
return LAPACKBackend()
end
end

function check_qr_full_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix)
m, n = size(A)
size(Q) == (m, m) ||
throw(DimensionMismatch("Full unitary matrix `Q` must be square with equal number of rows as A"))
isempty(R) || size(R) == (m, n) ||
throw(DimensionMismatch("Upper triangular matrix `R` must have size equal to A"))
return nothing
end
function check_qr_compact_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix)
m, n = size(A)
if n <= m
size(Q) == (m, n) ||
throw(DimensionMismatch("Isometric `Q` must have size equal to A"))
isempty(R) || size(R) == (n, n) ||
throw(DimensionMismatch("Upper triangular matrix `R` must be square with equal number of columns as A"))
else
check_qr_full_input(A, Q, R)
end
end

function qr_full!(A::AbstractMatrix,
Q::AbstractMatrix,
R::AbstractMatrix,
backend::LAPACKBackend;
positive=false,
pivoted=false,
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))
check_qr_full_input(A, Q, R)
_unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize)
return Q, R
end

function qr_compact!(A::AbstractMatrix,
Q::AbstractMatrix,
R::AbstractMatrix,
backend::LAPACKBackend;
positive=false,
pivoted=false,
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))
check_qr_compact_input(A, Q, R)
_unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize)
return Q, R
end

function _unsafe_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive=false,
pivoted=false,
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))
m, n = size(A)
minmn = min(m, n)
computeR = length(R) > 0
inplaceQ = Q === A

if pivoted && (blocksize > 1)
throw(ArgumentError("LAPACK does not provide a blocked implementation for a pivoted QR decomposition"))
end
if inplaceQ && (computeR || positive || blocksize > 1 || m < n)
throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required, and using the unblocked algorithm (`blocksize=1`) with `positive=false`"))
end

if blocksize > 1
nb = min(minmn, blocksize)
if computeR # first use R as space for T
A, T = YALAPACK.geqrt!(A, view(R, 1:nb, 1:minmn))
else
A, T = YALAPACK.geqrt!(A, similar(A, nb, minmn))
end
Q = YALAPACK.gemqrt!('L', 'N', A, T, one!(Q))
else
if pivoted
A, τ, jpvt = YALAPACK.geqp3!(A)
else
A, τ = YALAPACK.geqrf!(A)
end
if inplaceQ
Q = YALAPACK.orgqr!(A, τ)
else
Q = YALAPACK.ormqr!('L', 'N', A, τ, one!(Q))
end
end

if positive # already fix Q even if we do not need R
@inbounds for j in 1:minmn
s = safesign(A[j, j])
@simd for i in 1:m
Q[i, j] *= s
end
end
end

if computeR
= triu!(view(A, axes(R)...))
if positive
@inbounds for j in n:-1:1
@simd for i in 1:min(minmn, j)
R̃[i, j] = R̃[i, j] * conj(safesign(R̃[i, i]))
end
end
end
if !pivoted
copyto!(R, R̃)
else
# probably very inefficient in terms of memory access
copyto!(view(R, :, jpvt), R̃)
end
end
return Q, R
end
51 changes: 45 additions & 6 deletions src/svd.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# TODO: do not export but mark as public ?
function svd!(A::AbstractMatrix, args...; kwargs...)
return svd_compact!(A, args...; kwargs...)
end

function svd_full!(A::AbstractMatrix,
U::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
S::AbstractVector=similar(A, real(eltype(A)), (min(size(A)...),)),
Expand All @@ -12,6 +17,12 @@ function svd_compact!(A::AbstractMatrix,
kwargs...)
return svd_compact!(A, U, S, Vᴴ, default_backend(svd_compact!, A; kwargs...); kwargs...)
end
function svd_vals!(A::AbstractMatrix,
S::AbstractVector=similar(A, real(eltype(A)), (min(size(A)...),));
kwargs...)
return svd_vals!(A, S, default_backend(svd_vals!, A; kwargs...); kwargs...)
end

function svd_trunc!(A::AbstractMatrix;
kwargs...)
return svd_trunc!(A, default_backend(svd_trunc!, A; kwargs...); kwargs...)
Expand All @@ -23,6 +34,9 @@ end
function default_backend(::typeof(svd_compact!), A::AbstractMatrix; kwargs...)
return default_svd_backend(A; kwargs...)
end
function default_backend(::typeof(svd_vals!), A::AbstractMatrix; kwargs...)
return default_svd_backend(A; kwargs...)
end
function default_backend(::typeof(svd_trunc!), A::AbstractMatrix; kwargs...)
return default_svd_backend(A; kwargs...)
end
Expand Down Expand Up @@ -53,6 +67,13 @@ function check_svd_compact_input(A, U, S, Vᴴ)
throw(DimensionMismatch("`svd_compact!` requires vector S of length min(size(A)..."))
return nothing
end
function check_svd_vals_input(A, S)
m, n = size(A)
minmn = min(m, n)
size(S) == (minmn,) ||
throw(DimensionMismatch("`svd_vals!` requires vector S of length min(size(A)..."))
return nothing
end

function svd_full!(A::AbstractMatrix,
U::AbstractMatrix,
Expand All @@ -66,7 +87,7 @@ function svd_full!(A::AbstractMatrix,
elseif alg == LinearAlgebra.QRIteration()
YALAPACK.gesvd!(A, S, U, Vᴴ)
else
throw(ArgumentError("Unknown algorithm $alg"))
throw(ArgumentError("Unknown LAPACK singular value algorithm $alg"))
end
return U, S, Vᴴ
end
Expand All @@ -82,26 +103,44 @@ function svd_compact!(A::AbstractMatrix,
elseif alg == LinearAlgebra.QRIteration()
YALAPACK.gesvd!(A, S, U, Vᴴ)
else
throw(ArgumentError("Unknown algorithm $alg"))
throw(ArgumentError("Unknown LAPACK singular value algorithm $alg"))
end
return U, S, Vᴴ
end

function svd_vals!(A::AbstractMatrix,
S::AbstractVector,
backend::LAPACKBackend;
alg=LinearAlgebra.DivideAndConquer())
check_svd_vals_input(A, S)
m, n = size(A)
if alg == LinearAlgebra.DivideAndConquer()
YALAPACK.gesdd!(A, S, similar(A, m, 0), similar(A, n, 0))
elseif alg == LinearAlgebra.QRIteration()
YALAPACK.gesvd!(A, S, similar(A, m, 0), similar(A, n, 0))
else
throw(ArgumentError("Unknown LAPACK singular value algorithm $alg"))
end
return S
end

# for svd_trunc!, it doesn't make sense to preallocate U, S, Vᴴ as we don't know their sizes
function svd_trunc!(A::AbstractMatrix,
backend::LAPACKBackend;
alg=LinearAlgebra.DivideAndConquer(),
tol=zero(real(eltype(A))),
atol=zero(real(eltype(A))),
rtol=zero(real(eltype(A))),
rank=min(size(A)...))
if alg == LinearAlgebra.DivideAndConquer()
S, U, Vᴴ = YALAPACK.gesdd!(A)
elseif alg == LinearAlgebra.QRIteration()
S, U, Vᴴ = YALAPACK.gesvd!(A)
else
throw(ArgumentError("Unknown algorithm $alg"))
throw(ArgumentError("Unknown LAPACK singular value algorithm $alg"))
end
r = min(rank, findlast(>=(tol * S[1]), S))
tol = convert(eltype(S), max(atol, rtol * S[1]))
r = min(rank, findlast(>=(tol), S))
# TODO: do we want views here, such that we do not need extra allocations if we later
# copy them into other storage
return U[:, 1:r], S[1:r], Vᴴ[1:r, :]
end
end
Loading

0 comments on commit 2d5b83a

Please sign in to comment.