Skip to content

Commit

Permalink
Implement callback function and add iter to logs
Browse files Browse the repository at this point in the history
<p dir="auto">We use this e.g. to implement a logging callback for logging to mlflow. An example snippet</p>
<div class="highlight highlight-source-python notranslate position-relative overflow-auto" dir="auto" data-snippet-clipboard-copy-content="julia&gt; using PythonCall
julia&gt; mlflow = pyimport(&quot;mlflow&quot;)
julia&gt; mlflow.set_tracking_uri(&quot;http://mycloud-nixos:8095&quot;)
julia&gt; pywith(mlflow.start_run()) do _
                  for iter in 1:10
                      mlflow.log_metric(&quot;foo&quot;, 1*iter, step=iter)
                      mlflow.log_metric(&quot;bar&quot;, 2)
                  end
               end"><pre class="notranslate"><span class="pl-s1">julia</span><span class="pl-c1">&gt;</span> <span class="pl-s1">using</span> <span class="pl-v">PythonCall</span>
<span class="pl-s1">julia</span><span class="pl-c1">&gt;</span> <span class="pl-s1">mlflow</span> <span class="pl-c1">=</span> <span class="pl-en">pyimport</span>(<span class="pl-s">"mlflow"</span>)
<span class="pl-s1">julia</span><span class="pl-c1">&gt;</span> <span class="pl-s1">mlflow</span>.<span class="pl-en">set_tracking_uri</span>(<span class="pl-s">"http://mycloud-nixos:8095"</span>)
<span class="pl-s1">julia</span><span class="pl-c1">&gt;</span> <span class="pl-en">pywith</span>(<span class="pl-s1">mlflow</span>.<span class="pl-en">start_run</span>()) <span class="pl-s1">do</span> <span class="pl-s1">_</span>
                  <span class="pl-k">for</span> <span class="pl-s1">iter</span> <span class="pl-c1">in</span> <span class="pl-c1">1</span>:<span class="pl-c1">10</span>
                      <span class="pl-s1">mlflow</span>.<span class="pl-en">log_metric</span>(<span class="pl-s">"foo"</span>, <span class="pl-c1">1</span><span class="pl-c1">*</span><span class="pl-s1">iter</span>, <span class="pl-s1">step</span><span class="pl-c1">=</span><span class="pl-s1">iter</span>)
                      <span class="pl-s1">mlflow</span>.<span class="pl-en">log_metric</span>(<span class="pl-s">"bar"</span>, <span class="pl-c1">2</span>)
                  <span class="pl-s1">end</span>
               <span class="pl-s1">end</span></pre></div>
  • Loading branch information
RomeoV authored Jul 16, 2024
2 parents 80de0e8 + 82f61c1 commit c181090
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/KSVD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,11 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;
maxtime::Union{Nothing, <:Real}=nothing,# : The maximum time for solving the nonlinear system of equations. Defaults to nothing which means no time limit. Note that setting a time limit does have a small overhead.
abstol::Number=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5), #: The absolute tolerance. Defaults to real(oneunit(T)) * (eps(real(one(T))))^(1 // 2).
reltol::Number=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5), #: The relative tolerance. Defaults to real(oneunit(T)) * (eps(real(one(T))))^(1 // 2).
nnz_per_col_target::Int=0,
nnz_per_col_target::Float64=0.0,
# tracing options
show_trace::Bool=false,
# store_trace::Bool,
callback_fn::Union{Nothing, Function}=nothing,
) where T
timer = TimerOutput()
emb_dim, n_samples = size(Y)
Expand All @@ -113,12 +114,13 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;

norm_results, nnz_per_col_results = Float64[], Float64[];
# if store_trace || show_trace
trace_channel = Channel{Tuple{Matrix{T}, SparseMatrixCSC{T, Int64}}}(; spawn=true) do ch
for (D, X) in ch
trace_channel = Channel{Tuple{Int, Matrix{T}, SparseMatrixCSC{T, Int64}}}(maxiters; spawn=true) do ch
for (iter, D, X) in ch
norm_val = norm(Y - D*X)
nnz_per_col_val = nnz(X) / size(X, 2)
show_trace && @info norm_val, nnz_per_col_val
show_trace && @info (iter, norm_val, nnz_per_col_val)
(push!(norm_results, norm_val); push!(nnz_per_col_results, nnz_per_col_val))
!isnothing(callback_fn) && callback_fn((; iter, Y, D, X, norm_val, nnz_per_col_val))
end
end

Expand All @@ -131,7 +133,7 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;
D, X = ksvd_update(ksvd_update_method, Y, D, X; timer)

# put a task to compute the trace / termination conditions.
push!(trace_channel, (copy(D), copy(X)))
push!(trace_channel, (iter, copy(D), copy(X)))

# Check termination conditions.
# Notice that this is typically not using the most recent results yet. So we might only later realize that we
Expand Down

0 comments on commit c181090

Please sign in to comment.