Skip to content

Commit

Permalink
Add support for weights (#91)
Browse files Browse the repository at this point in the history
* Add support for weights

* Limit to FrequencyWeights

* Improve wording

* Respond to review comments

* Version 0.2.10
  • Loading branch information
timholy authored Nov 28, 2023
1 parent d9d9e9f commit d6b428d
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 75 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CovarianceEstimation"
uuid = "587fd27a-f159-11e8-2dae-1979310e6154"
authors = ["Mateusz Baran <[email protected]>", "Thibaut Lienart"]
version = "0.2.9"
version = "0.2.10"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
133 changes: 78 additions & 55 deletions src/linearshrinkage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,31 @@ LinearShrinkage(;
corrected::Bool=false) = LinearShrinkage(target, shrinkage, corrected=corrected)

"""
cov(lse::LinearShrinkage, X; dims=1)
cov(lse::LinearShrinkage, X, [weights::FrequencyWeights]; dims=1)
Linear shrinkage covariance estimator for matrix `X` along dimension `dims`.
Computed using the method described by `lse`.
Optionally provide `weights` associated with each observation in `X` (see `StatsBase.FrequencyWeights`).
!!! note
Theoretical guidance for the use of weights in shrinkage estimation seems sparse.
`FrequencyWeights` have a straightforward implementation, but support for other `AbstractWeight` subtypes
awaits analytical justification.
"""
function cov(lse::LinearShrinkage, X::AbstractMatrix{<:Real};
function cov(lse::LinearShrinkage, X::AbstractMatrix{<:Real}, weights::FrequencyWeights...;
dims::Int=1, mean=nothing)

dims (1, 2) || throw(ArgumentError("Argument dims can only be 1 or 2 (given: $dims)"))

Xc = (dims == 1) ? copy(X) : copy(transpose(X))
n, p = size(Xc)
# sample covariance of size (p x p)
S = cov(SimpleCovariance(corrected=lse.corrected), X; dims=dims, mean=mean)
S = cov(SimpleCovariance(corrected=lse.corrected), X, weights...; dims=dims, mean=mean)

# NOTE: don't need to check if mean is proper as this is already done above
if mean === nothing
Xc .-= Statistics.mean(Xc, dims=1)
Xc .-= Statistics.mean(Xc, weights...; dims=1)
elseif mean isa AbstractArray
if dims == 1
Xc .-= mean
Expand All @@ -129,7 +136,7 @@ function cov(lse::LinearShrinkage, X::AbstractMatrix{<:Real};
end
end

return linear_shrinkage(lse.target, Xc, S, lse.shrinkage, n, p, lse.corrected)
return linear_shrinkage(lse.target, Xc, S, lse.shrinkage, n, p, lse.corrected, weights...)
end

##############################################################################
Expand Down Expand Up @@ -171,6 +178,7 @@ This operation appears often in the computations of optimal shrinkage λ.
* Time complexity: ``O(2np^2)``
"""
uccov(X::AbstractMatrix) = (X' * X) / size(X, 1)
uccov(X::AbstractMatrix, weights::FrequencyWeights) = (X' * (weights .* X)) / sum(weights)


"""
Expand Down Expand Up @@ -225,10 +233,17 @@ function sum_fij(Xc, S, n, κ)
M .*= sd'
return sumij(M) / (n * κ)
end
function sum_fij(Xc, S, n, κ, weights)
sd = sqrt.(diag(S))
M = ((Xc.^3)' * (weights .* Xc)) ./ sd
M .-= κ .* S .* sd
M .*= sd'
return sumij(M) / (sum(weights) * κ)
end
##############################################################################

"""
linear_shrinkage(target, Xc, S, λ, n, p, corrected)
linear_shrinkage(target, Xc, S, λ, n, p, corrected, [weights])
Performs linear shrinkage with target of type `target` for data matrix `Xc`
of size `n` by `p` with covariance matrix `S` and shrinkage parameter `λ`.
Expand All @@ -240,7 +255,7 @@ linear_shrinkage

function linear_shrinkage(::DiagonalUnitVariance, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Real, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

return linshrink(I, S, λ)
end
Expand All @@ -252,24 +267,25 @@ Compute the shrinkage estimator where the target is a `DiagonalUnitVariance`.
"""
function linear_shrinkage(::DiagonalUnitVariance, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Symbol, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

F = I
T = float(eltype(S))
κ = n - Int(corrected)
γ = T/n)
wn = totalweight(n, weights...)
κ = wn - Int(corrected)
γ = T/wn)
Xc² = Xc.^2
# computing the shrinkage
if λ [:auto, :lw]
ΣS² = sumij2(S, with_diag=true)
λ = sumij(uccov(Xc²), with_diag=true) / γ^2 - ΣS²
λ = sumij(uccov(Xc², weights...), with_diag=true) / γ^2 - ΣS²
λ /= κ * (ΣS² - 2tr(S) + p)
elseif λ == :ss
# use the standardised data matrix
d = one(T) ./ vec(sum(Xc², dims=1))
d = one(T) ./ vec(sum(Xc², weights...; dims=1))
= rescale(S, sqrt.(d)) # this has diagonal 1/κ
ΣS̄² = sumij2(S̄, with_diag=true)
λ = sumij(rescale!(uccov(Xc²), d), with_diag=true) / γ^2 - ΣS̄²
λ = sumij(rescale!(uccov(Xc², weights...), d), with_diag=true) / γ^2 - ΣS̄²
λ /= T* ΣS̄² - p / κ)
else
throw(ArgumentError("Unsupported shrinkage method for target DiagonalUnitVariance: ."))
Expand All @@ -284,7 +300,7 @@ target_B(S::AbstractMatrix, p::Int) = tr(S)/p * I

function linear_shrinkage(::DiagonalCommonVariance, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Real, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

return linshrink(target_B(S, p), S, λ)
end
Expand All @@ -296,39 +312,40 @@ Compute the shrinkage estimator where the target is a `DiagonalCommonVariance`.
"""
function linear_shrinkage(::DiagonalCommonVariance, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Symbol, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

F = target_B(S, p)
T = float(eltype(F))
κ = n - Int(corrected)
γ = T/n)
wn = totalweight(n, weights...)
κ = wn - Int(corrected)
γ = T/wn)
Xc² = Xc.^2
# computing the shrinkage
if λ [:auto, :lw]
v = F.λ # tr(S)/p
ΣS² = sumij2(S, with_diag=true)
λ = sumij(uccov(Xc²), with_diag=true) / γ^2 - ΣS²
λ = sumij(uccov(Xc², weights...), with_diag=true) / γ^2 - ΣS²
λ /= κ * (ΣS² - p*v^2)
elseif λ == :ss
# use the standardised data matrix
d = one(T) ./ vec(sum(Xc², dims=1))
d = one(T) ./ vec(sum(Xc², weights...; dims=1))
= rescale(S, sqrt.(d)) # this has diagonal 1/κ
= κ # tr(S̄)/p
ΣS̄² = sumij2(S̄, with_diag=true)
λ = sumij(rescale!(uccov(Xc²), d), with_diag=true) / γ^2 - ΣS̄²
λ = sumij(rescale!(uccov(Xc², weights...), d), with_diag=true) / γ^2 - ΣS̄²
λ /= T* ΣS̄² - p/κ)
elseif λ == :rblw
# https://arxiv.org/pdf/0907.4698.pdf equations 17, 19
trS² = sum(abs2, S)
tr²S = tr(S)^2
# note: using corrected or uncorrected S does not change λ
λ = T(((n-2)/n * trS² + tr²S) / ((n+2) * (trS² - tr²S/p)))
λ = T(((wn-2)/wn * trS² + tr²S) / ((wn+2) * (trS² - tr²S/p)))
elseif λ == :oas
# https://arxiv.org/pdf/0907.4698.pdf equation 23
trS² = sum(abs2, S)
tr²S = tr(S)^2
# note: using corrected or uncorrected S does not change λ
λ = ((one(T)-T(2.0)/p) * trS² + tr²S) / ((n+one(T)-T(2.0)/p) * (trS² - tr²S/p))
λ = ((one(T)-T(2.0)/p) * trS² + tr²S) / ((wn+one(T)-T(2.0)/p) * (trS² - tr²S/p))
else
throw(ArgumentError("Unsupported shrinkage method for target DiagonalCommonVariance: ."))
end
Expand All @@ -342,7 +359,7 @@ target_D(S::AbstractMatrix) = Diagonal(S)

function linear_shrinkage(::DiagonalUnequalVariance, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Real, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

return linshrink(target_D(S), S, λ)
end
Expand All @@ -354,23 +371,26 @@ Compute the shrinkage estimator where the target is a `DiagonalUnequalVariance`.
"""
function linear_shrinkage(::DiagonalUnequalVariance, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Symbol, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

F = target_D(S)
T = float(eltype(F))
κ = n - Int(corrected)
γ = T/ n)
wn = totalweight(n, weights...)
κ = wn - Int(corrected)
γ = T/ wn)
Xc² = Xc.^2
# computing the shrinkage
if λ [:auto, :lw]
ΣS² = sumij2(S)
λ = sumij(uccov(Xc²)) / γ^2 - ΣS²
λ = sumij(uccov(Xc², weights...)) / γ^2 - ΣS²
λ /= κ * ΣS²
elseif λ == :ss
keep = diag(S) .> zero(T)
Xc² = Xc²[:, keep]
# use the standardised data matrix
d = one(T) ./ vec(sum(Xc², dims=1))
ΣS̄² = sumij2(rescale(S, sqrt.(d)))
λ = sumij(rescale!(uccov(Xc²), d)) / γ^2 - ΣS̄²
d = one(T) ./ vec(sum(Xc², weights...; dims=1))
ΣS̄² = sumij2(rescale(S[keep, keep], sqrt.(d)))
λ = sumij(rescale!(uccov(Xc², weights...), d)) / γ^2 - ΣS̄²
λ /= κ * ΣS̄²
else
throw(ArgumentError("Unsupported shrinkage method for target DiagonalUnequalVariance: ."))
Expand All @@ -392,7 +412,7 @@ end

function linear_shrinkage(::CommonCovariance, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Real, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

F, _, _ = target_C(S, p)
return linshrink!(F, S, λ)
Expand All @@ -405,23 +425,24 @@ Compute the shrinkage estimator where the target is a `CommonCovariance`.
"""
function linear_shrinkage(::CommonCovariance, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Symbol, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

F, v, c = target_C(S, p)
T = float(eltype(F))
κ = n - Int(corrected)
γ = T/n)
wn = totalweight(n, weights...)
κ = wn - Int(corrected)
γ = T/wn)
Xc² = Xc.^2
# computing the shrinkage
if λ [:auto, :lw]
ΣS² = sumij2(S, with_diag=true)
λ = sumij(uccov(Xc²), with_diag=true) / γ^2 - ΣS²
λ = sumij(uccov(Xc², weights...), with_diag=true) / γ^2 - ΣS²
λ /= κ * (ΣS² - p*(p-1)*c^2 - p*v^2)
elseif λ == :ss
d = one(T) ./ vec(sum(Xc², dims=1))
d = one(T) ./ vec(sum(Xc², weights...; dims=1))
= rescale(S, sqrt.(d))
ΣS̄² = sumij2(S̄, with_diag=true)
λ = sumij(rescale!(uccov(Xc²), d), with_diag=true) / γ^2 - ΣS̄²
λ = sumij(rescale!(uccov(Xc², weights...), d), with_diag=true) / γ^2 - ΣS̄²
λ /= κ * ΣS̄² - p/κ - κ * sumij(S̄; with_diag=false)^2 / (p * (p - 1))
else
throw(ArgumentError("Unsupported shrinkage method for target CommonCovariance: ."))
Expand All @@ -439,7 +460,7 @@ end

function linear_shrinkage(::PerfectPositiveCorrelation, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Real, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

return linshrink!(target_E(S), S, λ)
end
Expand All @@ -452,26 +473,27 @@ Compute the shrinkage estimator where the target is a
"""
function linear_shrinkage(::PerfectPositiveCorrelation, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Symbol, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

F = target_E(S)
T = float(eltype(F))
κ = n - Int(corrected)
γ = T/n)
wn = totalweight(n, weights...)
κ = wn - Int(corrected)
γ = T/wn)
Xc² = Xc.^2
# computing the shrinkage
if λ [:auto, :lw]
ΣS² = sumij2(S)
λ = (sumij(uccov(Xc²)) / γ^2 - ΣS²) / κ
λ -= sum_fij(Xc, S, n, κ)
λ = (sumij(uccov(Xc², weights...)) / γ^2 - ΣS²) / κ
λ -= sum_fij(Xc, S, n, κ, weights...)
λ /= sumij2(S - F)
elseif λ == :ss
d = one(T) ./ vec(sum(Xc², dims=1))
d = one(T) ./ vec(sum(Xc², weights...; dims=1))
s = sqrt.(d)
= rescale(S, s)
ΣS̄² = sumij2(S̄)
λ = (sumij(rescale!(uccov(Xc²), d)) / γ^2 - ΣS̄²) / κ
λ -= sum_fij(Xc .* s', S̄, n, κ)
λ = (sumij(rescale!(uccov(Xc², weights...), d)) / γ^2 - ΣS̄²) / κ
λ -= sum_fij(Xc .* s', S̄, n, κ, weights...)
= target_E(S̄)
λ /= sumij2(S̄ - F̄)
else
Expand All @@ -494,7 +516,7 @@ end

function linear_shrinkage(::ConstantCorrelation, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Real, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

F, _ = target_F(S, p)
return linshrink!(F, S, λ)
Expand All @@ -507,27 +529,28 @@ Compute the shrinkage estimator where the target is a `ConstantCorrelation`.
"""
function linear_shrinkage(::ConstantCorrelation, Xc::AbstractMatrix,
S::AbstractMatrix, λ::Symbol, n::Int, p::Int,
corrected::Bool)
corrected::Bool, weights::FrequencyWeights...)

F, r̄ = target_F(S, p)
T = float(eltype(F))
κ = n - Int(corrected)
γ = T/n)
wn = totalweight(n, weights...)
κ = wn - Int(corrected)
γ = T/wn)
Xc² = Xc.^2
# computing the shrinkage
if λ [:auto, :lw]
ΣS² = sumij2(S)
λ = (sumij(uccov(Xc²)) / γ^2 - ΣS²) / κ
λ -=* sum_fij(Xc, S, n, κ)
λ = (sumij(uccov(Xc², weights...)) / γ^2 - ΣS²) / κ
λ -=* sum_fij(Xc, S, n, κ, weights...)
λ /= sumij2(S - F)
elseif λ == :ss
d = one(T) ./ vec(sum(Xc², dims=1))
d = one(T) ./ vec(sum(Xc², weights...; dims=1))
s = sqrt.(d)
= rescale(S, s)
F̄, r̄ = target_F(S̄, p)
ΣS̄² = sumij2(S̄)
λ = (sumij(rescale!(uccov(Xc²), d)) / γ^2 - ΣS̄²) / κ
λ -=* sum_fij(Xc .* s', S̄, n, κ)
λ = (sumij(rescale!(uccov(Xc², weights...), d)) / γ^2 - ΣS̄²) / κ
λ -=* sum_fij(Xc .* s', S̄, n, κ, weights...)
λ /= sumij2(S̄ - F̄)
else
throw(ArgumentError("Unsupported shrinkage method for target ConstantCorrelation: ."))
Expand Down
13 changes: 7 additions & 6 deletions src/nonlinearshrinkage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ function analytical_nonlinear_shrinkage(S::AbstractMatrix{<:Real},
end

"""
cov(ans::AnalyticalNonlinearShrinkage, X; dims=1, mean=nothing)
cov(ans::AnalyticalNonlinearShrinkage, X, [weights]; dims=1, mean=nothing)
Nonlinear covariance estimator derived from the sample covariance estimator `S`
and its eigenvalue decomposition (which can be given through `decomp`).
Expand All @@ -134,19 +134,20 @@ vector (possibly a zero vector) and avoid the use of `mean=0`.
- (p<n): O(np^2 + n^2) with moderate constant
- (p>n): O(p^3) with low constant (dominated by eigendecomposition of S)
"""
function cov(ans::AnalyticalNonlinearShrinkage, X::AbstractMatrix{<:Real};
function cov(ans::AnalyticalNonlinearShrinkage, X::AbstractMatrix{<:Real}, weights::FrequencyWeights...;
dims::Int=1, mean=nothing)

@assert dims [1, 2] "Argument dims can only be 1 or 2 (given: $dims)"

szx = size(X)
(n, p) = ifelse(dims==1, szx, reverse(szx))
wn = floor(Int, totalweight(n, weights...))

# explained in the paper there must be at least 12 samples
(n < 12) && throw(ArgumentError("The number of samples `n` must be at " *
"least 12 (given: $n)."))
(wn < 12) && throw(ArgumentError("The (weighted) number of samples `n` must be at " *
"least 12 (given: $wn)."))

S = cov(SimpleCovariance(corrected=ans.corrected), X; dims=dims, mean=mean)
return analytical_nonlinear_shrinkage(S, n, p, mean === nothing;
S = cov(SimpleCovariance(corrected=ans.corrected), X, weights...; dims=dims, mean=mean)
return analytical_nonlinear_shrinkage(S, wn, p, mean === nothing;
decomp=ans.decomp)
end
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ function linshrink!(F::AbstractMatrix, S::AbstractMatrix, λ::Real)
F .= (one(λ) .- λ).*S .+ λ.*F
return Symmetric(F)
end

totalweight(n) = n
totalweight(_, weights) = sum(weights)
Loading

2 comments on commit d6b428d

@mateuszbaran
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/96064

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.10 -m "<description of version>" d6b428de4b1de17670de3f6f1106336bde637afc
git push origin v0.2.10

Please sign in to comment.