Skip to content

Commit

Permalink
Test DICG with Probability Simplex and MOI LMO.
Browse files Browse the repository at this point in the history
  • Loading branch information
Hendrych committed Nov 5, 2024
1 parent 8125a32 commit d5f1d65
Showing 1 changed file with 51 additions and 5 deletions.
56 changes: 51 additions & 5 deletions examples/optimal_experiment_design.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ using Random
using Distributions
using LinearAlgebra
using Statistics
using SCIP
using MathOptInterface
const MOI = MathOptInterface
using SparseArrays
using Test

# The Optimal Experiment Design Problem consists of choosing a subset of experiments
Expand Down Expand Up @@ -43,6 +47,29 @@ function build_data(m)
return A
end

"""
Build MOI version of the lmo.
"""
function build_moi_lmo(m)
o = SCIP.Optimizer()
MOI.empty!(o)
MOI.set(o, MOI.Silent(), true)

x = MOI.add_variables(o, m)

for xi in x
# each var has to be non-negative
MOI.add_constraint(o, xi, MOI.GreaterThan(0.0))
end

# sum of all variables has to be less than 1.0
MOI.add_constraint(o, sum(x, init=0.0), MOI.LessThan(1.0))

lmo = FrankWolfe.MathOptLMO(o)

return lmo
end

"""
Check if given point is in the domain of f, i.e. X = transpose(A) * diagm(x) * A
positive definite.
Expand Down Expand Up @@ -87,12 +114,11 @@ function build_start_point(A)
V = Vector{Float64}[]

for i in S
v = zeros(m)
v[i] = 1.0
v = FrankWolfe.ScaledHotVector(1.0, i, m)
push!(V, v)
end

x = sum(V .* 1/n)
x = SparseArrays.SparseVector(sum(V .* 1/n))
active_set= FrankWolfe.ActiveSet(fill(1/n, n), V, x)

return x, active_set, S
Expand Down Expand Up @@ -232,10 +258,20 @@ m = 300
f, grad! = build_a_criterion(A, build_safe=false)
x0, active_set = build_start_point(A)
domain_oracle = build_domain_oracle(A)
x, _, primal, dual_gap, traj_data = decomposition_invariant_conditional_gradient(f, grad!, lmo, x0, verbose=true,line_search=FrankWolfe.Secant(domain_oracle=domain_oracle), trajectory=true)
x_d, _, primal, dual_gap, traj_data_d = FrankWolfe.decomposition_invariant_conditional_gradient(f, grad!, lmo, x0, verbose=true,line_search=FrankWolfe.Secant(domain_oracle=domain_oracle), trajectory=true)

lmo = build_moi_lmo(m)
f, grad! = build_a_criterion(A, build_safe=false)
x0, active_set = build_start_point(A)
domain_oracle = build_domain_oracle(A)
x_d_m, _, primal, dual_gap, traj_data_d_m = FrankWolfe.decomposition_invariant_conditional_gradient(f, grad!, lmo, x0, verbose=true,line_search=FrankWolfe.Secant(domain_oracle=domain_oracle), trajectory=true)

@test traj_data_s[end][1] < traj_data[end][1]
@test traj_data_d[end][1] <= traj_data_s[end][1]
@test traj_data_d_m[end][1] <= traj_data_s[end][1]
@test isapprox(f(x_s), f(x))
@test isapprox(f(x_s), f(x_d))
@test isapprox(f(x_s), f(x_d_m))
end

@testset "D-Optimal Design" begin
Expand All @@ -256,10 +292,20 @@ m = 300
f, grad! = build_d_criterion(A, build_safe=false)
x0, active_set = build_start_point(A)
domain_oracle = build_domain_oracle(A)
x, _, primal, dual_gap, traj_data = decomposition_invariant_conditional_gradient(f, grad!, lmo, x0, verbose=true,line_search=FrankWolfe.Secant(domain_oracle=domain_oracle), trajectory=true)
x_d, _, primal, dual_gap, traj_data_d = FrankWolfe.decomposition_invariant_conditional_gradient(f, grad!, lmo, x0, verbose=true,line_search=FrankWolfe.Secant(domain_oracle=domain_oracle), trajectory=true)

lmo = build_moi_lmo(m)
f, grad! = build_d_criterion(A, build_safe=false)
x0, active_set = build_start_point(A)
domain_oracle = build_domain_oracle(A)
x_d_m, _, primal, dual_gap, traj_data_d_m = FrankWolfe.decomposition_invariant_conditional_gradient(f, grad!, lmo, x0, verbose=true,line_search=FrankWolfe.Secant(domain_oracle=domain_oracle), trajectory=true)

@test traj_data_s[end][1] < traj_data[end][1]
@test traj_data_d[end][1] <= traj_data_s[end][1]
@test traj_data_d_m[end][1] <= traj_data_s[end][1]
@test isapprox(f(x_s), f(x))
@test isapprox(f(x_s), f(x_d))
@test isapprox(f(x_s), f(x_d_m))
end
end

0 comments on commit d5f1d65

Please sign in to comment.