Skip to content

Commit

Permalink
Wait for trace channel to finish writing
Browse files Browse the repository at this point in the history
  • Loading branch information
RomeoV committed Oct 25, 2024
1 parent f50326a commit 23dfffa
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/KSVD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int;

norm_results, nnz_per_col_results = Float64[], Float64[];
# if store_trace || show_trace
trace_channel = Channel{Tuple{Int, Matrix{T}, SparseMatrixCSC{T, Int64}}}(maxiters; spawn=true) do ch
trace_taskref = Ref{Task}()
CH_T = Tuple{Int, Matrix{T}, SparseMatrixCSC{T, Int64}}
trace_channel = Channel{CH_T}(maxiters; spawn=true, taskref=trace_taskref) do ch
for (iter, D, X) in ch
norm_val = norm.(eachcol(Y - D*X)) |> mean
nnz_per_col_val = nnz(X) / size(X, 2)
Expand Down Expand Up @@ -162,7 +164,9 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int;
termination_condition = :nnz_per_col_target; break
end
end
close(trace_channel)
TimerOutputs.complement!(timer)
wait(trace_taskref[]) # make sure trace has finished
return (; D, X, norm_results, nnz_per_col_results, termination_condition, timer)
end

Expand Down

0 comments on commit 23dfffa

Please sign in to comment.