Skip to content

Commit

Permalink
Allow passing a D_init
Browse files Browse the repository at this point in the history
  • Loading branch information
RomeoV committed Jul 19, 2024
1 parent 71f77da commit 756593a
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/KSVD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;
ksvd_update_method = BatchedParallelKSVD{false, T}(; shuffle_indices=true, batch_size_per_thread=1),
sparse_coding_method = ParallelMatchingPursuit(; max_nnz, rtol=5e-2),
minibatch_size=nothing,
D_init::Union{Nothing, <:AbstractMatrix{T}} = nothing,
# termination conditions
maxiters::Int=100, #: The maximum number of iterations to perform. Defaults to 100.
maxtime::Union{Nothing, <:Real}=nothing,# : The maximum time for solving the nonlinear system of equations. Defaults to nothing which means no time limit. Note that setting a time limit does have a small overhead.
Expand All @@ -104,7 +105,7 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;

# D is a dictionary matrix that contains atoms for columns.
@timeit_debug timer "Init dict" begin
D = init_dictionary(T, emb_dim, n_atoms) # size(D) == (n, K)
D = (isnothing(D_init) ? init_dictionary(T, emb_dim, n_atoms) : D_init) # size(D) == (n, K)
@assert all((1.0), norm.(eachcol(D)))
end
X = sparse_coding(sparse_coding_method, Y, D; timer)
Expand Down

0 comments on commit 756593a

Please sign in to comment.