diff --git a/src/KSVD.jl b/src/KSVD.jl index 7e37318..812238a 100644 --- a/src/KSVD.jl +++ b/src/KSVD.jl @@ -88,6 +88,7 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10; ksvd_update_method = BatchedParallelKSVD{false, T}(; shuffle_indices=true, batch_size_per_thread=1), sparse_coding_method = ParallelMatchingPursuit(; max_nnz, rtol=5e-2), minibatch_size=nothing, + D_init::Union{Nothing, <:AbstractMatrix{T}} = nothing, # termination conditions maxiters::Int=100, #: The maximum number of iterations to perform. Defaults to 100. maxtime::Union{Nothing, <:Real}=nothing,# : The maximum time for solving the nonlinear system of equations. Defaults to nothing which means no time limit. Note that setting a time limit does have a small overhead. @@ -104,7 +105,7 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10; # D is a dictionary matrix that contains atoms for columns. @timeit_debug timer "Init dict" begin - D = init_dictionary(T, emb_dim, n_atoms) # size(D) == (n, K) + D = (isnothing(D_init) ? init_dictionary(T, emb_dim, n_atoms) : D_init) # size(D) == (n, K) @assert all(≈(1.0), norm.(eachcol(D))) end X = sparse_coding(sparse_coding_method, Y, D; timer)