Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow D_init and nothing tolerances #4

Merged
merged 4 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,7 @@ jobs:
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
- uses: actions/cache@v3
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
Expand Down
9 changes: 5 additions & 4 deletions src/KSVD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;
ksvd_update_method = BatchedParallelKSVD{false, T}(; shuffle_indices=true, batch_size_per_thread=1),
sparse_coding_method = ParallelMatchingPursuit(; max_nnz, rtol=5e-2),
minibatch_size=nothing,
D_init::Union{Nothing, <:AbstractMatrix{T}} = nothing,
# termination conditions
maxiters::Int=100, #: The maximum number of iterations to perform. Defaults to 100.
maxtime::Union{Nothing, <:Real}=nothing,# : The maximum time for solving the nonlinear system of equations. Defaults to nothing which means no time limit. Note that setting a time limit does have a small overhead.
abstol::Number=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5), #: The absolute tolerance. Defaults to real(oneunit(T)) * (eps(real(one(T))))^(1 // 2).
reltol::Number=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5), #: The relative tolerance. Defaults to real(oneunit(T)) * (eps(real(one(T))))^(1 // 2).
abstol::Union{Nothing, <:Real}=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5), #: The absolute tolerance. Defaults to real(oneunit(T)) * (eps(real(one(T))))^(1 // 2).
reltol::Union{Nothing, <:Real}=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5), #: The relative tolerance. Defaults to real(oneunit(T)) * (eps(real(one(T))))^(1 // 2).
nnz_per_col_target::Number=0.0,
# tracing options
show_trace::Bool=false,
Expand All @@ -104,7 +105,7 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;

# D is a dictionary matrix that contains atoms for columns.
@timeit_debug timer "Init dict" begin
D = init_dictionary(T, emb_dim, n_atoms) # size(D) == (n, K)
D = (isnothing(D_init) ? init_dictionary(T, emb_dim, n_atoms) : D_init) # size(D) == (n, K)
@assert all(≈(1.0), norm.(eachcol(D)))
end
X = sparse_coding(sparse_coding_method, Y, D; timer)
Expand Down Expand Up @@ -152,7 +153,7 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;
termination_condition = :maxiter; break
elseif !isnothing(maxtime) && (time() - tic) > maxtime
termination_condition = :maxtime; break
elseif length(norm_results) > 1 && isapprox(norm_results[end], norm_results[end-1]; atol=abstol, rtol=reltol)
elseif (!isnothing(abstol) && !isnothing(reltol)) && length(norm_results) > 1 && isapprox(norm_results[end], norm_results[end-1]; atol=abstol, rtol=reltol)
termination_condition = :converged; break
elseif !isempty(nnz_per_col_results) && last(nnz_per_col_results) <= nnz_per_col_target
termination_condition = :nnz_per_col_target; break
Expand Down
6 changes: 3 additions & 3 deletions src/matching_pursuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ products[t+1] = dictionary' * (residual[t] - dictionary[idx] * a)
products_abs = abs.(products) # prealloc

for i in 1:max_iter
!isfinite(norm(residual)) && @show norm(residual), residual
!isfinite(norm_data) && @show norm_data
# @assert(norm(residual)) && @show norm(residual), residual
# @assert(isfinite(norm_data))
if norm(residual)/norm_data < rtol
return sparsevec(xdict, n_atoms)
end
Expand All @@ -261,7 +261,7 @@ products[t+1] = dictionary' * (residual[t] - dictionary[idx] * a)

a = products[maxindex]
atom = @view dictionary[:, maxindex]
@assert norm(atom) ≈ 1. norm(atom)
# @assert norm(atom) ≈ 1. norm(atom)

residual .-= a .* atom
products .-= a .* @view DtD[:, maxindex]
Expand Down
Loading