Skip to content

Commit

Permalink
Add Backtracking as fall back for the Secant linesearch (#504)
Browse files Browse the repository at this point in the history
* Show lowest dual gap in last iteration.

* Add backtracking as a fall back linesearch in Secant.

* Let the user decide the fallback line search.

* Adjust syntax.

* Type stability.

---------

Co-authored-by: Hendrych <[email protected]>
  • Loading branch information
dhendryc and Hendrych authored Sep 25, 2024
1 parent a67fdba commit 7813f20
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
4 changes: 2 additions & 2 deletions examples/optimal_experiment_design.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ m = 300
f, grad! = build_a_criterion(A, build_safe=false)
x0, active_set = build_start_point(A)
domain_oracle = build_domain_oracle(A)
x_s, _, primal, dual_gap, traj_data_s, _ = FrankWolfe.blended_pairwise_conditional_gradient(f, grad!, lmo, active_set, verbose=true, line_search=FrankWolfe.Secant(40, 1e-8, domain_oracle), trajectory=true)
x_s, _, primal, dual_gap, traj_data_s, _ = FrankWolfe.blended_pairwise_conditional_gradient(f, grad!, lmo, active_set, verbose=true, line_search=FrankWolfe.Secant(domain_oracle=domain_oracle), trajectory=true)

@test traj_data_s[end][1] < traj_data[end][1]
@test isapprox(f(x_s), f(x))
Expand All @@ -244,7 +244,7 @@ m = 300
f, grad! = build_d_criterion(A, build_safe=false)
x0, active_set = build_start_point(A)
domain_oracle = build_domain_oracle(A)
x_s, _, primal, dual_gap, traj_data_s, _ = FrankWolfe.blended_pairwise_conditional_gradient(f, grad!, lmo, active_set, verbose=true, line_search=FrankWolfe.Secant(40, 1e-8, domain_oracle), trajectory=true)
x_s, _, primal, dual_gap, traj_data_s, _ = FrankWolfe.blended_pairwise_conditional_gradient(f, grad!, lmo, active_set, verbose=true, line_search=FrankWolfe.Secant(domain_oracle=domain_oracle), trajectory=true)

@test traj_data_s[end][1] < traj_data[end][1]
@test isapprox(f(x_s), f(x))
Expand Down
37 changes: 30 additions & 7 deletions src/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,28 +369,31 @@ Convergence is not guaranteed in general.
# References
- [Secant Method](https://en.wikipedia.org/wiki/Secant_method)
"""
struct Secant{F} <: LineSearchMethod
struct Secant{F,LSM<:LineSearchMethod} <: LineSearchMethod
inner_ls::LSM
limit_num_steps::Int
tol::Float64
domain_oracle::F
end

function Secant(limit_num_steps, tol)
return Secant(limit_num_steps, tol, x -> true)
return Secant(Backtracking(), limit_num_steps, tol, x -> true)
end

function Secant(; limit_num_steps=40, tol=1e-8)
return Secant(limit_num_steps, tol)
function Secant(;inner_ls=Backtracking(), limit_num_steps=40, tol=1e-8, domain_oracle=(x -> true))
return Secant(inner_ls, limit_num_steps, tol, domain_oracle)
end

mutable struct SecantWorkspace{XT,GT}
mutable struct SecantWorkspace{XT,GT, IWS}
inner_ws::IWS
x::XT
gradient::GT
last_gamma::Float64
end

function build_linesearch_workspace(::Secant, x, gradient)
return SecantWorkspace(similar(x), similar(gradient), 1.0) # Initialize last_gamma to 1.0
function build_linesearch_workspace(ls::Secant, x, gradient)
inner_ws = build_linesearch_workspace(ls.inner_ls, x, gradient)
return SecantWorkspace(inner_ws, similar(x), similar(gradient), 1.0) # Initialize last_gamma to 1.0
end

function perform_line_search(
Expand Down Expand Up @@ -448,6 +451,26 @@ function perform_line_search(
dot_gdir = dot_gdir_new
i += 1
end
if abs(dot_gdir) > line_search.tol
gamma = perform_line_search(
line_search.inner_ls,
0,
f,
grad!,
gradient,
x,
d,
gamma_max,
workspace.inner_ws,
memory_mode,
)

storage = muladd_memory_mode(memory_mode, storage, x, gamma, d)
new_val = f(storage)

@assert new_val <= best_val
best_gamma = gamma
end
workspace.last_gamma = best_gamma # Update last_gamma before returning
return best_gamma
end
Expand Down

0 comments on commit 7813f20

Please sign in to comment.