Skip to content

Commit

Permalink
some progress
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Nov 19, 2024
1 parent 8a3b550 commit 3d66e5d
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 20 deletions.
4 changes: 3 additions & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ using LinearAlgebra: LinearAlgebra
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, triu!

include("auxiliary.jl")
include("backend.jl")
include("yalapack.jl")
include("orthnull.jl")
include("svd.jl")
include("eigh.jl")

end
88 changes: 88 additions & 0 deletions src/eigh.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# `eigh!`` is a simple wrapper for `eigh_full!`
function eigh!(A::AbstractMatrix,
D::AbstractVector=similar(A, real(eltype(A)), size(A, 1)),
V::AbstractMatrix=similar(A, size(A));
kwargs...)
return eigh_full!(A, D, V; kwargs...)
end

function eigh_full!(A::AbstractMatrix,
D::AbstractVector=similar(A, real(eltype(A)), size(A, 1)),
V::AbstractMatrix=similar(A, size(A));
kwargs...)
return eigh_full!(A, D, V, default_backend(eigh_full!, A; kwargs...); kwargs...)
end
function eigh_trunc!(A::AbstractMatrix;
kwargs...)
return eigh_trunc!(A, default_backend(eigh_trunc!, A; kwargs...); kwargs...)
end

function default_backend(::typeof(eigh_full!), A::AbstractMatrix; kwargs...)
return default_eigh_backend(A; kwargs...)
end
function default_backend(::typeof(eigh_trunc!), A::AbstractMatrix; kwargs...)
return default_eigh_backend(A; kwargs...)
end

function default_eigh_backend(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
return LAPACKBackend()
end

function check_eigh_full_input(A, D, V)
m, n = size(A)
m == n || throw(ArgumentError("Eigenvalue decompsition requires square matrix"))
size(D) == (n,) ||
throw(DimensionMismatch("Eigenvalue vector `D` must have length equal to size(A, 1)"))
size(V) == (n, n) ||
throw(DimensionMismatch("Eigenvector matrix `V` must have size equal to A"))
return nothing
end

@static if VERSION >= v"1.12-DEV.0"
const RobustRepresentations = LinearAlgebra.RobustRepresentations
else
struct RobustRepresentations end
end

function eigh_full!(A::AbstractMatrix,
D::AbstractVector,
V::AbstractMatrix,
backend::LAPACKBackend;
alg=RobustRepresentations(),
kwargs...)
check_eigh_full_input(A, D, V)
if alg == RobustRepresentations()
YALAPACK.heevr!(A, D, V; kwargs...)
elseif alg == LinearAlgebra.DivideAndConquer()
YALAPACK.heevd!(A, D, V; kwargs...)
elseif alg == LinearAlgebra.QRIteration()
YALAPACK.heev!(A, D, V; kwargs...)
else
throw(ArgumentError("Unknown algorithm $alg"))
end
return D, V
end

# for eigh_trunc!, it doesn't make sense to preallocate U, S, Vᴴ as we don't know their sizes
function eigh_trunc!(A::AbstractMatrix,
backend::LAPACKBackend;
alg=RobustRepresentations(),
tol=zero(real(eltype(A))),
rank=min(size(A)...),
kwargs...)
if alg == RobustRepresentations()
D, V = YALAPACK.heevr!(A; kwargs...)
elseif alg == LinearAlgebra.DivideAndConquer()
D, V = YALAPACK.heevd!(A; kwargs...)
elseif alg == LinearAlgebra.QRIteration()
D, V = YALAPACK.heev!(A; kwargs...)
else
throw(ArgumentError("Unknown algorithm $alg"))
end
# eigenvalues are sorted in ascending order; do we assume that they are positive?
n = length(D)
s = max(n - rank, findfirst(>=(tol * D[end]), S))
# TODO: do we want views here, such that we do not need extra allocations if we later
# copy them into other storage
return D[n:-1:s], V[:, n:-1:s]
end
60 changes: 45 additions & 15 deletions src/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,21 @@ function svd_full!(A::AbstractMatrix,
S::AbstractVector=similar(A, real(eltype(A)), (min(size(A)...),)),
Vᴴ::AbstractMatrix=similar(A, (size(A, 2), size(A, 2)));
kwargs...)
return svd_full!(A, U, S, Vᴴ, default_backend(svd_full!, A; kwargs...))
return svd_full!(A, U, S, Vᴴ, default_backend(svd_full!, A; kwargs...); kwargs...)
end
function svd_compact!(A::AbstractMatrix,
U::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
S::AbstractVector=similar(A, real(eltype(A)), (min(size(A)...),)),
Vᴴ::AbstractMatrix=similar(A, (size(A, 2), size(A, 2)));
kwargs...)
return svd_compact!(A, U, S, Vᴴ, default_backend(svd_compact!, A; kwargs...))
return svd_compact!(A, U, S, Vᴴ, default_backend(svd_compact!, A; kwargs...); kwargs...)
end
function svd_trunc!(A::AbstractMatrix,
U::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
S::AbstractVector=similar(A, real(eltype(A)), (min(size(A)...),)),
Vᴴ::AbstractMatrix=similar(A, (size(A, 2), size(A, 2)));
function svd_trunc!(A::AbstractMatrix;
kwargs...)
return svd_trunc!(A, U, S, Vᴴ, default_backend(svd_trunc!, A; kwargs...))
return svd_trunc!(A, default_backend(svd_trunc!, A; kwargs...); kwargs...)
end

function default_backend(::typeof(svd_full!), A::AbstractMatri; kwargs...)
function default_backend(::typeof(svd_full!), A::AbstractMatrix; kwargs...)
return default_svd_backend(A; kwargs...)
end
function default_backend(::typeof(svd_compact!), A::AbstractMatrix; kwargs...)
Expand Down Expand Up @@ -49,29 +46,62 @@ function check_svd_compact_input(A, U, S, Vᴴ)
m, n = size(A)
minmn = min(m, n)
size(U) == (m, minmn) ||
throw(DimensionMismatch("`svd_full!` requires square U matrix with equal number of rows as A"))
throw(DimensionMismatch("`svd_compact!` requires square U matrix with equal number of rows as A"))
size(Vᴴ) == (minmn, n) ||
throw(DimensionMismatch("`svd_full!` requires square Vᴴ matrix with equal number of columns as A"))
throw(DimensionMismatch("`svd_compact!` requires square Vᴴ matrix with equal number of columns as A"))
size(S) == (minmn,) ||
throw(DimensionMismatch("`svd_full!` requires vector S of length min(size(A)..."))
throw(DimensionMismatch("`svd_compact!` requires vector S of length min(size(A)..."))
return nothing
end

function svd_full!(A::AbstractMatrix,
U::AbstractMatrix,
S::AbstractVector,
Vᴴ::AbstractMatrix,
backend::LAPACKBackend)
backend::LAPACKBackend;
alg=LinearAlgebra.DivideAndConquer())
check_svd_full_input(A, U, S, Vᴴ)
YALAPACK.gesdd!(A, S, U, Vᴴ)
if alg == LinearAlgebra.DivideAndConquer()
YALAPACK.gesdd!(A, S, U, Vᴴ)
elseif alg == LinearAlgebra.QRIteration()
YALAPACK.gesvd!(A, S, U, Vᴴ)
else
throw(ArgumentError("Unknown algorithm $alg"))
end
return U, S, Vᴴ
end
function svd_compact!(A::AbstractMatrix,
U::AbstractMatrix,
S::AbstractVector,
Vᴴ::AbstractMatrix,
backend::LAPACKBackend)
backend::LAPACKBackend;
alg=LinearAlgebra.DivideAndConquer())
check_svd_compact_input(A, U, S, Vᴴ)
YALAPACK.gesdd!(A, S, U, Vᴴ)
if alg == LinearAlgebra.DivideAndConquer()
YALAPACK.gesdd!(A, S, U, Vᴴ)
elseif alg == LinearAlgebra.QRIteration()
YALAPACK.gesvd!(A, S, U, Vᴴ)
else
throw(ArgumentError("Unknown algorithm $alg"))
end
return U, S, Vᴴ
end

# for svd_trunc!, it doesn't make sense to preallocate U, S, Vᴴ as we don't know their sizes
function svd_trunc!(A::AbstractMatrix,
backend::LAPACKBackend;
alg=LinearAlgebra.DivideAndConquer(),
tol=zero(real(eltype(A))),
rank=min(size(A)...))
if alg == LinearAlgebra.DivideAndConquer()
S, U, Vᴴ = YALAPACK.gesdd!(A)
elseif alg == LinearAlgebra.QRIteration()
S, U, Vᴴ = YALAPACK.gesvd!(A)
else
throw(ArgumentError("Unknown algorithm $alg"))
end
r = min(rank, findlast(>=(tol * S[1]), S))
# TODO: do we want views here, such that we do not need extra allocations if we later
# copy them into other storage
return U[:, 1:r], S[1:r], Vᴴ[1:r, :]
end
9 changes: 5 additions & 4 deletions src/yalapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,8 @@ for (heev, heevr, heevd, hegvd, elty, relty) in
@eval begin
function heev!(A::AbstractMatrix{$elty},
W::AbstractVector{$relty}=similar(A, $relty, size(A, 1)),
V::AbstractMatrix{$elty}=A)
V::AbstractMatrix{$elty}=A;
uplo::AbstractChar='U') # shouldn't matter but 'U' seems slightly faster than 'L'
require_one_based_indexing(A, V, W)
chkstride1(A, V, W)
n = checksquare(A)
Expand All @@ -458,7 +459,6 @@ for (heev, heevr, heevd, hegvd, elty, relty) in
else
ishermitian(A) || throw(ArgumentError("A must be Hermitian"))
end
uplo = 'U' # shouldn't matter but 'U' seems slightly faster than 'L'
chkuplofinite(A, uplo)
n == length(W) || throw(DimensionMismatch("length mismatch between A and W"))
if length(V) == 0
Expand Down Expand Up @@ -513,6 +513,7 @@ for (heev, heevr, heevd, hegvd, elty, relty) in
function heevr!(A::AbstractMatrix{$elty},
W::AbstractVector{$relty}=similar(A, $relty, size(A, 1)),
V::AbstractMatrix{$elty}=similar(A);
uplo::AbstractChar='U', # shouldn't matter but 'U' seems slightly faster than 'L'
kwargs...)
require_one_based_indexing(A, V, W)
chkstride1(A, V, W)
Expand All @@ -522,7 +523,6 @@ for (heev, heevr, heevd, hegvd, elty, relty) in
else
ishermitian(A) || throw(ArgumentError("A must be Hermitian"))
end
uplo = 'U' # shouldn't matter but 'U' seems slightly faster than 'L'
chkuplofinite(A, uplo)
if haskey(kwargs, :irange)
il = first(irange)
Expand Down Expand Up @@ -623,7 +623,8 @@ for (heev, heevr, heevd, hegvd, elty, relty) in

function heevd!(A::AbstractMatrix{$elty},
W::AbstractVector{$relty}=similar(A, $relty, size(A, 1)),
V::AbstractMatrix{$elty}=A)
V::AbstractMatrix{$elty}=A;
uplo::AbstractChar='U') # shouldn't matter but 'U' seems slightly faster than 'L'
require_one_based_indexing(A, V, W)
chkstride1(A, V, W)
n = checksquare(A)
Expand Down

0 comments on commit 3d66e5d

Please sign in to comment.