Skip to content

Commit

Permalink
Implement callback function and add iter to logs
Browse files Browse the repository at this point in the history
  • Loading branch information
RomeoV committed Jul 16, 2024
1 parent 501cf81 commit 82f61c1
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/KSVD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,11 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;
maxtime::Union{Nothing, <:Real}=nothing,# : The maximum time for solving the nonlinear system of equations. Defaults to nothing which means no time limit. Note that setting a time limit does have a small overhead.
abstol::Number=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5), #: The absolute tolerance. Defaults to real(oneunit(T)) * (eps(real(one(T))))^(1 // 2).
reltol::Number=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5), #: The relative tolerance. Defaults to real(oneunit(T)) * (eps(real(one(T))))^(1 // 2).
nnz_per_col_target::Int=0,
nnz_per_col_target::Float64=0.0,
# tracing options
show_trace::Bool=false,
# store_trace::Bool,
callback_fn::Union{Nothing, Function}=nothing,
) where T
timer = TimerOutput()
emb_dim, n_samples = size(Y)
Expand All @@ -113,12 +114,13 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;

norm_results, nnz_per_col_results = Float64[], Float64[];
# if store_trace || show_trace
trace_channel = Channel{Tuple{Matrix{T}, SparseMatrixCSC{T, Int64}}}(; spawn=true) do ch
for (D, X) in ch
trace_channel = Channel{Tuple{Int, Matrix{T}, SparseMatrixCSC{T, Int64}}}(maxiters; spawn=true) do ch
for (iter, D, X) in ch
norm_val = norm(Y - D*X)
nnz_per_col_val = nnz(X) / size(X, 2)
show_trace && @info norm_val, nnz_per_col_val
show_trace && @info (iter, norm_val, nnz_per_col_val)
(push!(norm_results, norm_val); push!(nnz_per_col_results, nnz_per_col_val))
!isnothing(callback_fn) && callback_fn((; iter, Y, D, X, norm_val, nnz_per_col_val))
end
end

Expand All @@ -131,7 +133,7 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;
D, X = ksvd_update(ksvd_update_method, Y, D, X; timer)

# put a task to compute the trace / termination conditions.
push!(trace_channel, (copy(D), copy(X)))
push!(trace_channel, (iter, copy(D), copy(X)))

# Check termination conditions.
# Notice that this is typically not using the most recent results yet. So we might only later realize that we
Expand Down

0 comments on commit 82f61c1

Please sign in to comment.