Skip to content

Commit

Permalink
Merge pull request #21 from benjione/OptimalTransport
Browse files Browse the repository at this point in the history
Implementation for solving entropic optimal transport
  • Loading branch information
benjione authored Oct 13, 2024
2 parents 33b1d1c + 9eca472 commit 37709e5
Show file tree
Hide file tree
Showing 7 changed files with 276 additions and 185 deletions.
142 changes: 142 additions & 0 deletions src/OptimalTransport.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
module OptimalTransport

using ..SequentialMeasureTransport
import ..SequentialMeasureTransport as SMT
using ..SequentialMeasureTransport: PSDDataVector
using Distributions
using FastGaussQuadrature: gausslegendre


"""
    entropic_OT!(model, cost, p, q, ϵ, XY; X=nothing, Y=nothing, preconditioner=nothing,
                 reference=nothing, use_putinar=true, use_preconditioner_cost=false,
                 λ_marg=nothing, kwargs...)

Prepare the sample-based inputs of an entropic optimal transport problem for a
`d2 = 2d` dimensional `model` and forward them to `SMT._OT_JuMP!`, whose result is
returned.

# Arguments
- `model`: PSD model on the joint space. `d2` must be even; the first `d` coordinates
  are treated as the `p` side and the last `d` coordinates as the `q` side.
- `cost`: cost function, called with a full `d2`-dimensional point.
- `p`, `q`: marginal densities, called with the first / last `d` coordinates of a point.
- `ϵ`: forwarded to `SMT._OT_JuMP!` and used as the weight of the log-density term in
  the automatic `λ_marg` estimate (presumably the entropic regularization strength).
- `XY`: joint samples at which the cost and the reference density are evaluated.

# Keywords
- `X`, `Y`: marginal samples. Each defaults to the first / last `d` coordinates of the
  elements of `XY`.
- `preconditioner`: when given, the cost is evaluated at
  `SMT.pushforward(preconditioner, x)`, the reference density is pulled back through it
  (unless `use_preconditioner_cost` is set), and the model used for the marginals is
  `SMT._add_mapping(model, preconditioner)`.
- `reference`: when given, the densities, the cost and the marginal samples are pushed
  forward through it.
- `use_putinar`: when `true` and `model` is a `SMT.PSDModelPolynomial`, the output of
  `SMT.get_semialgebraic_domain_constraints(model)` is forwarded as `mat_list` and
  `coef_list`.
- `use_preconditioner_cost`: when `true`, the reference density is the plain product
  `p(x[1:d]) * q(x[d+1:end])`, without any reference or preconditioner transformation.
- `λ_marg`: forwarded as `λ_marg_reg`. When `nothing`, it is set to ten times the sample
  average of `cost - ϵ * log(reference density)` over `XY`, and the value is logged.
- `kwargs...`: forwarded unchanged to `SMT._OT_JuMP!`.
"""
function entropic_OT!(model::SMT.PSDModelOrthonormal{d2, T},
                      cost::Function,
                      p::Function,
                      q::Function,
                      ϵ::T,
                      XY::PSDDataVector{T};
                      X=nothing, Y=nothing,
                      preconditioner::Union{<:SMT.ConditionalMapping{d2, 0, T}, Nothing}=nothing,
                      reference::Union{<:SMT.ReferenceMap{d2, 0, T}, Nothing}=nothing,
                      use_putinar=true,
                      use_preconditioner_cost=false,
                      λ_marg=nothing,
                      kwargs...) where {d2, T<:Number}
    @assert d2 % 2 == 0
    d = d2 ÷ 2

    ## reference density p ⊗ q on the joint space, optionally transported
    reverse_KL_cost = let p=p, q=q
        x -> p(x[1:d]) * q(x[d+1:end])
    end
    if !use_preconditioner_cost
        if !isnothing(reference)
            reverse_KL_cost = SMT.pushforward(reference, reverse_KL_cost)
        end
        if !isnothing(preconditioner)
            reverse_KL_cost = SMT.pullback(preconditioner, reverse_KL_cost)
        end
    end

    ## cost: first through the reference map, then composed with the preconditioner
    plain_cost = let cost=cost
        x -> cost(x)
    end
    reference_cost = isnothing(reference) ? plain_cost :
                     SMT.pushforward(reference, plain_cost)
    cost_pb = if isnothing(preconditioner)
        reference_cost
    else
        let inner=reference_cost
            x -> inner(SMT.pushforward(preconditioner, x))
        end
    end

    ξ = map(reverse_KL_cost, XY)
    ξ2 = map(cost_pb, XY)
    if isnothing(λ_marg)
        ## estimate the order of the reverse KL cost to find an acceptable λ_marg
        ## to do that, we calculate KL(U||reverse_KL_cost) where U is the distribution of XY
        _order_rev_KL = (sum(ξ2) - ϵ * sum(log.(ξ))) / length(ξ)
        λ_marg = 10.0*_order_rev_KL
        # NOTE: the literal below holds an escaped "\n" followed by a real line break;
        # its second line stays flush-left so no indentation leaks into the log text.
        @info "Estimated order of the reverse KL cost: $_order_rev_KL \n
Setting λ_marg to $λ_marg"
    end

    model_for_marg = isnothing(preconditioner) ? model :
                     SMT._add_mapping(model, preconditioner)

    ## default marginal samples: split every joint sample into its two halves
    if isnothing(X)
        X = [x[1:d] for x in XY]
    end
    if isnothing(Y)
        Y = [x[d+1:end] for x in XY]
    end

    _p, _q = p, q
    if !isnothing(reference)
        ## marginal densities through the matching sub-blocks of the reference map
        _p = SMT.pushforward(reference[1:d], p)
        _q = SMT.pushforward(reference[d+1:end], q)

        ## pushforward the samples
        joint_samples = [vcat(x, y) for (x, y) in zip(X, Y)]
        _XY_marg = SMT.pushforward.(Ref(reference), joint_samples)
        X = [x[1:d] for x in _XY_marg]
        Y = [x[d+1:end] for x in _XY_marg]
    end

    ## evaluate the marginals on the (possibly pushed-forward) samples
    p_X = map(_p, X)
    q_Y = map(_q, Y)
    e_X = collect(1:d)
    e_Y = collect(d+1:d2)

    if use_putinar && model isa SMT.PSDModelPolynomial
        D, C = SMT.get_semialgebraic_domain_constraints(model)
        return SMT._OT_JuMP!(model, cost_pb, ϵ, XY, ξ; mat_list=D, coef_list=C,
                             model_for_marginals=model_for_marg,
                             marg_regularization = [(e_X, X, p_X), (e_Y, Y, q_Y)],
                             λ_marg_reg=λ_marg,
                             kwargs...)
    end
    return SMT._OT_JuMP!(model, cost_pb, ϵ, XY, ξ;
                         model_for_marginals=model_for_marg,
                         marg_regularization = [(e_X, X, p_X), (e_Y, Y, q_Y)],
                         λ_marg_reg=λ_marg,
                         kwargs...)
end

"""
    Wasserstein_Barycenter(model, measures, weights, ϵ, XY; kwargs...)

Placeholder for a Wasserstein barycenter solver. The signature is fixed, but no
computation is implemented yet.

# Throws
- `ErrorException("Not implemented yet")` on every call.
"""
function Wasserstein_Barycenter(model::SMT.PSDModelOrthonormal{d2, T},
                                measures::AbstractVector{<:Function},
                                weights::AbstractVector{T},
                                ϵ::T,
                                XY::PSDDataVector{T};
                                X=nothing, Y=nothing,
                                preconditioner::Union{<:SMT.ConditionalMapping{d2, 0, T}, Nothing}=nothing,
                                reference::Union{<:SMT.ReferenceMap{d2, 0, T}, Nothing}=nothing,
                                use_putinar=true,
                                use_preconditioner_cost=false,
                                λ_marg=nothing,
                                kwargs...
                                ) where {d2, T<:Number}
    # `error` raises the `ErrorException` itself, so no surrounding `throw` is needed.
    error("Not implemented yet")
end


end
2 changes: 2 additions & 0 deletions src/PSDModels/feature_map/polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ end
"""
    _remove_mapping(a::PSDModelPolynomial)

Build a new `PSDModelPolynomial` from `a.B` and `a.Φ` only, i.e. without passing on the
`ConditionalMapping` that appears in the type of `a`.
"""
@inline function _remove_mapping(
        a::PSDModelPolynomial{d, T, <:ConditionalMapping{d, <:Any, T}}
    ) where {d, T<:Number}
    return PSDModelPolynomial(a.B, a.Φ)
end

"""
    _add_mapping(a::PSDModelPolynomial, mapping::ConditionalMapping)

Build a new `PSDModelPolynomial` from `a.B`, `a.Φ` and `mapping`. Only defined for a
model `a` whose mapping type parameter is `Nothing`.
"""
@inline function _add_mapping(
        a::PSDModelPolynomial{d, T, Nothing},
        mapping::ConditionalMapping{d, <:Any, T}
    ) where {d, T<:Number}
    return PSDModelPolynomial(a.B, a.Φ, mapping)
end

## Pretty printing
function Base.show(io::IO, a::PSDModelPolynomial{d, T, S}) where {d, T, S}
Expand Down
11 changes: 11 additions & 0 deletions src/Samplers/reference_maps/algebraic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ struct AlgebraicReference{d, dC, T} <: ReferenceMap{d, dC, T}
end
end

"""
    getindex(m::AlgebraicReference, I)

Return an `AlgebraicReference` over the coordinates selected by `I`, or `nothing` when
`I` selects no coordinate.

The new map has one dimension per entry of `I`; its number of conditional dimensions is
the number of distinct entries of `I` that fall in the trailing block `d-dC+1:d` of `m`.

NOTE(review): the entries of `I` are not checked against `1:d`.
"""
function Base.getindex(m::AlgebraicReference{d, dC, T}, I) where {d, dC, T}
    selected = collect(I)
    isempty(selected) && return nothing
    conditional_block = d-dC+1:d
    n_conditional = length(intersect(conditional_block, selected))
    return AlgebraicReference{length(selected), n_conditional, T}()
end

"""
    _pushforward(m::AlgebraicReference, x)

Apply the algebraic reference map to `x` element-wise: each entry `xᵢ` is sent to
`(xᵢ / sqrt(1 + xᵢ^2) + 1) / 2`, which lies strictly between 0 and 1 for finite `xᵢ`.
"""
function _pushforward(m::AlgebraicReference{<:Any, <:Any, T}, x::PSDdata{T}) where {T<:Number}
    # x / sqrt(1 + x^2) lies in (-1, 1); shifting by 1 and halving moves it to (0, 1)
    squashed = x ./ sqrt.(1 .+ x.^2)
    return (squashed .+ 1.0) / 2.0
end
Expand Down
12 changes: 12 additions & 0 deletions src/Samplers/reference_maps/reference_maps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ module ReferenceMaps
Reference maps are diagonal maps used to work on domains of choice, while keeping the
density estimation in [0, 1]. The reference map is defined as a map from the uniform
distribution on [0, 1]^d to a domain of choice.
R_♯ ρ = U
"""

import ..SequentialMeasureTransport as SMT
Expand Down Expand Up @@ -34,6 +36,16 @@ Attention!
Using any of the functions with dim(x) < d will take a marginal distribution and pdf.
"""


"""
Get a reference map restricted to a subset of the dimensions, e.g. `m[3:4]`.
"""
function Base.getindex(m::ReferenceMap{d, dC, T}, I) where {d, dC, T}
    # Generic fallback: always raises ErrorException("Not implemented"). A reference map
    # type that supports slicing has to define its own, more specific method.
    # `error` throws by itself, so it is not wrapped in an additional `throw`.
    error("Not implemented")
end

"""
    lastindex(m::ReferenceMap)

Return the dimension `d` of the reference map, so that `end` can be used when indexing
a map, as in `m[k:end]`.
"""
function Base.lastindex(m::ReferenceMap{d}) where {d}
    return d
end

### Interface for ReferenceMaps
@inline function Distributions.pdf(Rmap::ReferenceMap{<:Any, <:Any, T},
x::PSDdata{T}
Expand Down
4 changes: 4 additions & 0 deletions src/SequentialMeasureTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,8 @@ include("Samplers/sampler.jl")
include("statistics.jl")
using .Statistics

# methods to create Optimal Transport plans and more
include("OptimalTransport.jl")
using .OptimalTransport

end # module SequentialMeasureTransport
Loading

0 comments on commit 37709e5

Please sign in to comment.