
Commit

Merge pull request #22 from benjione/InvertibleNetworksExtension
Work on triangular transport maps
benjione authored Oct 17, 2024
2 parents 37709e5 + ce9b8d3 commit a5428d6
Showing 17 changed files with 388 additions and 193 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -25,6 +25,7 @@ Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
1 change: 1 addition & 0 deletions src/PSDModels/models.jl
@@ -110,6 +110,7 @@ function (a::PSDModel{T})(x::PSDdata{T}, B::AbstractMatrix) where {T<:Number}
return dot(v, B, v)
end


function set_coefficients!(a::PSDModel{T}, B::Hermitian{T}) where {T<:Number}
a.B .= B
end
17 changes: 10 additions & 7 deletions src/Samplers/reference_maps/gaussian.jl
@@ -34,26 +34,29 @@ end

function SMT.pushforward(
m::GaussianReference{d, <:Any, T},
x::PSDdata{T}
) where {d, T<:Number}
x::PSDdata{T2}
) where {d, T<:Number, T2<:Number}
@assert length(x) == d
return 0.5 * (1 .+ erf.(x ./ (m.σ * sqrt(2))))
res = copy(x)
res ./= (m.σ * sqrt(2))
map!(z->0.5 * (1 + erf(z)), res, res)
return res
end


function SMT.pullback(
m::GaussianReference{d, <:Any, T},
u::PSDdata{T}
) where {d, T<:Number}
u::PSDdata{T2}
) where {d, T<:Number, T2<:Number}
@assert length(u) == d
return sqrt(2) * m.σ * erfcinv.(2.0 .- 2*u)
end


function SMT.Jacobian(
m::GaussianReference{d, <:Any, T},
x::PSDdata{T}
) where {d, T<:Number}
x::PSDdata{T2}
) where {d, T<:Number, T2<:Number}
@assert length(x) == d
return mapreduce(xi->Distributions.pdf(Distributions.Normal(0, m.σ), xi), *, x)
end
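
The new pushforward is still the componentwise Gaussian CDF with scale m.σ (now computed in place and accepting a different element type T2), and pullback remains its inverse. A minimal standalone check that the two closed forms invert each other (σ is an arbitrary stand-in for m.σ):

using SpecialFunctions: erf, erfcinv

σ = 1.3                                        # stand-in for m.σ
fwd(x) = 0.5 * (1 + erf(x / (σ * sqrt(2))))    # componentwise CDF, as in pushforward
bwd(u) = sqrt(2) * σ * erfcinv(2 - 2u)         # inverse CDF, as in pullback

x = [-1.2, 0.0, 0.7]
@assert all(abs.(bwd.(fwd.(x)) .- x) .< 1e-10) # round trip recovers x
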
12 changes: 8 additions & 4 deletions src/Samplers/samplers/Sampler.jl
@@ -9,15 +9,19 @@ ATTENTION:
where T_i is the i-th map in the sampler.
"""
struct CondSampler{d, dC, T, R1, R2} <: AbstractCondSampler{d, dC, T, R1, R2}
samplers::Vector{<:ConditionalMapping{d, dC, T}} # defined on [0, 1]^d
R1_map::R1 # reference map from reference distribution to uniform on [0, 1]^d
R2_map::R2 # distribution from domain of pi to [0, 1]^d
samplers::Vector{<:ConditionalMapping{d, dC, T}} # defined on internal domain
internal_domain::ProductDomain{<:AbstractVector{T}}
internal_measure::Distributions.Product{Distributions.Continuous}
R1_map::R1 # reference map from reference (domain, distribution) to internal (domain, distribution)
R2_map::R2 # distribution from target (domain, distribution) to internal (domain, distribution)
function CondSampler(
samplers::Vector{<:ConditionalMapping{d, dC, T}},
R1_map::Union{<:ReferenceMap{d, dC, T}, Nothing},
R2_map::Union{<:ReferenceMap{d, dC, T}, Nothing}
) where {d, T<:Number, dC}
new{d, dC, T, typeof(R1_map), typeof(R2_map)}(samplers, R1_map, R2_map)
internal_domain = ProductDomain([UnitInterval() for _=1:d])
internal_measure = Distributions.product_distribution([Distributions.Uniform(0.0, 1.0) for _=1:d])
new{d, dC, T, typeof(R1_map), typeof(R2_map)}(samplers, internal_domain, internal_measure, R1_map, R2_map)
end
function CondSampler(
samplers::Vector{<:ConditionalMapping},
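The two new fields pin the sampler's internal (domain, measure) pair to the unit cube with the uniform product measure; R1_map and R2_map then mediate between that internal pair and the reference and target spaces. A small sketch of what the constructor stores, using the same DomainSets/Distributions calls (d = 3 chosen arbitrarily):

using DomainSets: ProductDomain, UnitInterval
import Distributions

d = 3
internal_domain  = ProductDomain([UnitInterval() for _ = 1:d])          # [0, 1]^3
internal_measure = Distributions.product_distribution([Distributions.Uniform(0.0, 1.0) for _ = 1:d])
Distributions.pdf(internal_measure, fill(0.5, d))                       # 1.0 inside the cube
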
202 changes: 172 additions & 30 deletions src/Samplers/triangular_maps/ATM.jl
@@ -1,54 +1,196 @@


struct ATM{d, dC, T<:Number} <: AbstractTriangularMap{d, dC, T}
f::Vector{<:FMTensorPolynomial{<:Any, T}}
# abstract type ATM{d, dC, T} <: AbstractTriangularMap{d, dC, T} end

struct PolynomialATM{d, dC, T<:Number} <: AbstractTriangularMap{d, dC, T}
f::Vector{<:TensorFunction{<:Any, T}}
coeff::Vector{<:Vector{T}}
g::Function
variable_ordering::Vector{Int}
function ATM(f::Vector{FMTensorPolynomial{<:Any, T}}, g::Function, variable_ordering::Vector{Int}, dC::Int) where {T<:Number}
function PolynomialATM(f::Vector{<:TensorFunction{<:Any, T}}, g::Function, variable_ordering::Vector{Int}, dC::Int) where {T<:Number}
d = length(f)
coeff = Vector{Vector{T}}(undef, d)
for k=1:d
coeff[k] = randn(T, length(f[k](rand(k))))
end
new{d, dC, T}(f, coeff, g, variable_ordering)
end
function ATM(f::Vector{FMTensorPolynomial}, g::Function, variable_ordering::Vector{Int})
ATM(f, g, variable_ordering, 0)
function PolynomialATM(f::Vector{<:TensorFunction}, g::Function, variable_ordering::Vector{Int})
PolynomialATM(f, g, variable_ordering, 0)
end
end

int_x, int_w = gausslegendre(50)
int_x .= int_x * 0.5 .+ 0.5
int_w .= int_w * 0.5
@inline MonotoneMap(sampler::PolynomialATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = MonotoneMap(sampler, x, k, sampler.coeff[k])
function MonotoneMap(sampler::PolynomialATM{d, <:Any, T}, x::PSDdata{T}, k::Int,
coeff::AbstractVector{T}) where {d, T<:Number}
f_part(z) = begin
sampler.f[k]([x[1:k-1]; z])
end
f_partial(z::T) = FD.derivative(f_part, z)
int_f(z::T) = sampler.g(dot(coeff, f_partial(z)))

_int_x = int_x * x[k]
_int_w = int_w * x[k]
res = dot(coeff, sampler.f[k]([x[1:k-1]; 0]))
for i=1:length(int_x)
@inbounds res += _int_w[i] * int_f(_int_x[i])
end
return res
end
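
The module-level gausslegendre(50) rule lives on [-1, 1]; it is shifted once to [0, 1] (nodes 0.5ξ + 0.5, weights 0.5w) and then stretched to [0, x_k] inside MonotoneMap by multiplying nodes and weights by x[k]. A quick standalone check of that affine rescaling on a toy integrand (a = 0.7 and h are arbitrary):

using FastGaussQuadrature: gausslegendre

ξ, w = gausslegendre(50)
nodes   = ξ .* 0.5 .+ 0.5              # [-1, 1] -> [0, 1]
weights = w .* 0.5

a = 0.7
h(z) = exp(z) * cos(z)
approx = sum((weights .* a) .* h.(nodes .* a))     # ≈ ∫_0^a h(z) dz
exact  = (exp(a) * (sin(a) + cos(a)) - 1) / 2
@assert abs(approx - exact) < 1e-12
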

@inline MonotoneMap(sampler::ATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = MonotoneMap(sampler, x, k, sampler.coeff[k])
function MonotoneMap(sampler::ATM{d, <:Any, T}, x::PSDdata{T}, k::Int, coeff) where {d, T<:Number}
f_part(z) = sampler.f[k]([x[1:k-1]; z])
f_partial(z) = FD.derivative(f_part, z)
int_f(z) = sampler.g(dot(coeff, f_partial(z)))
int_x, int_w = gausslegendre(100)
int_x .= int_x * 0.5 .+ 0.5
int_x .= int_x * x[k]
int_w .= int_w * 0.5
int_w .*= x[k]
function ∇MonotoneMap(sampler::PolynomialATM{d, <:Any, T}, x::PSDdata{T}, k::Int,
coeff::AbstractVector{T}) where {d, T<:Number}
grad = zeros(size(coeff))
∇MonotoneMap!(grad, sampler, x, k, coeff)
return grad
end

int_part = sum(int_w .* int_f.(int_x))
return dot(coeff, sampler.f[k]([x[1:k-1]; 0])) + int_part
function ∇MonotoneMap!(grad::AbstractVector{T}, sampler::PolynomialATM{d, <:Any, T},
x::PSDdata{T}, k::Int, coeff::AbstractVector{T};
weight::T=one(T)) where {d, T<:Number}
f_part(z) = begin
sampler.f[k]([x[1:k-1]; z])
end
f_partial(z::T) = FD.derivative(f_part, z)
int_f2!(tmp::AbstractVector{T}, z::T) = begin
tmp .= f_partial(z)
tmp .*= FD.derivative(sampler.g, dot(coeff, tmp))
return nothing
end
_int_x = int_x * x[k]
_int_w = int_w * x[k]
tmp = similar(coeff)
grad .+= weight * sampler.f[k]([x[1:k-1]; 0])
for i=1:length(int_x)
int_f2!(tmp, @inbounds _int_x[i])
@inbounds grad .+= weight * _int_w[i] * tmp
end
return nothing
end
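
For reference, the quantity accumulated above is the coefficient gradient of the k-th component map

    S_k(x; c) = ⟨c, f_k(x_{1:k-1}, 0)⟩ + ∫_0^{x_k} g(⟨c, ∂_z f_k(x_{1:k-1}, z)⟩) dz,

namely

    ∇_c S_k(x; c) = f_k(x_{1:k-1}, 0) + ∫_0^{x_k} g'(⟨c, ∂_z f_k(x_{1:k-1}, z)⟩) ∂_z f_k(x_{1:k-1}, z) dz,

evaluated with the same rescaled Gauss-Legendre rule as MonotoneMap; the weight keyword only scales the contribution so callers can accumulate weighted gradients into one buffer.
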

@inline ∂k_MonotoneMap(sampler::ATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = ∂k_MonotoneMap(sampler, x, k, sampler.coeff[k])
function ∂k_MonotoneMap(sampler::ATM{d, <:Any, T}, x::PSDdata{T}, k::Int, coeff) where {d, T<:Number}
@inline ∂k_MonotoneMap(sampler::PolynomialATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = ∂k_MonotoneMap(sampler, x, k, sampler.coeff[k])
function ∂k_MonotoneMap(sampler::PolynomialATM{d, <:Any, T}, x::PSDdata{T}, k::Int,
coeff::AbstractVector{T}) where {d, T<:Number}
f_part(z) = sampler.f[k]([x[1:k-1]; z])
f_partial(z) = FD.derivative(f_part, z)
int_f(z) = sampler.g(dot(coeff, f_partial(z)))
f_partial(z::T) = FD.derivative(f_part, z)
int_f(z::T) = sampler.g(dot(coeff, f_partial(z)))
return int_f(x[k])
end

function ∇∂k_MonotoneMap(sampler::PolynomialATM{d, <:Any, T}, x::PSDdata{T}, k::Int,
coeff::AbstractVector{T}) where {d, T<:Number}
grad = zeros(size(coeff))
∇∂k_MonotoneMap!(grad, sampler, x, k, coeff)
return grad
end

function ∇∂k_MonotoneMap!(grad::AbstractVector{T}, sampler::PolynomialATM{d, <:Any, T}, x::PSDdata{T}, k::Int,
coeff::AbstractVector{T}; weight::T=one(T)) where {d, T<:Number}
f_part(z) = sampler.f[k]([x[1:k-1]; z])
f_partial(z::T) = FD.derivative(f_part, z)
tmp = f_partial(x[k])
g_diff = FD.derivative(sampler.g, dot(coeff, tmp))
grad .+= weight * g_diff * tmp
return nothing
end

"""
Map of type
f1(x_{1:k-1}) + g(f2(x_{1:k-1})) * x_k
"""
struct PolynomialCouplingATM{d, dC, T} <: AbstractTriangularMap{d, dC, T}
f1::AbstractVector{<:TensorFunction{<:Any, T}}
f2::AbstractVector{<:TensorFunction{<:Any, T}}
coeff::AbstractVector{<:AbstractMatrix{T}}
g::Function
# poly_measure::Function
variable_ordering::Vector{Int}
end

function PolynomialCouplingATM(f1, f2, g, variable_ordering)
d = length(f1)
coeff = Vector{Matrix{Float64}}(undef, d)
coeff = [k==1 ? randn(2, 1) : randn(Float64, 2, length(f1[k](rand(k-1)))) for k=1:d]
# for k=1:d
# coeff[k] = randn(Float64, 2, length(f1[k](rand(k))))
# end
PolynomialCouplingATM{d, 0, Float64}(f1, f2, coeff, g, variable_ordering)
end

@inline MonotoneMap(sampler::PolynomialCouplingATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = MonotoneMap(sampler, x, k, sampler.coeff[k])
function MonotoneMap(sampler::PolynomialCouplingATM{d, <:Any, T}, x::PSDdata{T},
k::Int, coeff::AbstractMatrix{T2}) where {d, T<:Number, T2<:Number}
if k==1
return coeff[1, 1] + x[1]
end
# print("here ", exp(-norm(x[1:k-1])^2/2), exp(-norm(x[1:k-1])^2/2) * dot(coeff[2, :], sampler.f2[k](x[1:k-1])))
return dot(coeff[1, :], sampler.f1[k](x[1:k-1])) +
sampler.g(dot(coeff[2, :], sampler.f2[k](x[1:k-1]))) * x[k]
end
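
In the coupling form above, ∂_k of the k-th component is g(⟨coeff[2, :], f2_k(x_{1:k-1})⟩), so monotonicity in x_k hinges on g being positive-valued (softplus or exp, for instance). A tiny illustration with hypothetical stand-in features (f1k, f2k and c below are illustrative, not package objects):

using LinearAlgebra: dot

softplus(t) = log1p(exp(t))            # strictly positive

f1k(y) = [1.0, y[1]]                   # stand-in for sampler.f1[2]
f2k(y) = [1.0, y[1]]                   # stand-in for sampler.f2[2]
c = randn(2, 2)

S2(x)  = dot(c[1, :], f1k(x[1:1])) + softplus(dot(c[2, :], f2k(x[1:1]))) * x[2]
dS2(x) = softplus(dot(c[2, :], f2k(x[1:1])))   # > 0 for every x, so S2 is increasing in x[2]
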

function ∇MonotoneMap(sampler::PolynomialCouplingATM{d, <:Any, T}, x::PSDdata{T}, k::Int,
coeff::AbstractMatrix{T2}) where {d, T<:Number, T2<:Number}
if k==1
return hcat([1.0], [0.0])'
end
g_diff = FD.derivative(sampler.g, dot(coeff[2, :], sampler.f2[k](x[1:k-1])))
return hcat(sampler.f1[k](x[1:k-1]), g_diff * sampler.f2[k](x[1:k-1]) * x[k])'
end

@inline ∂k_MonotoneMap(sampler::PolynomialCouplingATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = ∂k_MonotoneMap(sampler, x, k, sampler.coeff[k])
function ∂k_MonotoneMap(sampler::PolynomialCouplingATM{d, <:Any, T}, x::PSDdata{T}, k::Int, coeff) where {d, T<:Number}
if k==1
return 1.0
end
return sampler.g(dot(coeff[2, :], sampler.f2[k](x[1:k-1])))
end

function ∇∂k_MonotoneMap(sampler::PolynomialCouplingATM{d, <:Any, T}, x::PSDdata{T}, k::Int,
coeff::AbstractMatrix{T2}) where {d, T<:Number, T2<:Number}
if k==1
return hcat([0.0], [0.0])'
end
g_diff = FD.derivative(sampler.g, dot(coeff[2, :], sampler.f2[k](x[1:k-1])))
return hcat(zeros(size(coeff, 2)), g_diff * sampler.f2[k](x[1:k-1]))'
end

"""
Defined on [0, 1]^d
of type
\\int_{0}^{x_k} Φ(x_{1:k-1}, z)' A Φ(x_{1:k-1}, z) dz with A ⪰ 0, tr(A) = 1
"""
struct SoSATM{d, dC, T} <: AbstractTriangularMap{d, dC, T}
f::Vector{<:PSDModel{T}}
f_int::Vector{<:TraceModel{T}}
A_vec::Vector{<:Hermitian{T}}
variable_ordering::Vector{Int}
function SoSATM(f::Vector{<:PSDModel{T}}, variable_ordering::Vector{Int}) where {T}
d = length(f)
f_int = [integral(f[k], k) for k=1:length(f)]
A_vec = [f[k].B for k=1:length(f)]
new{d, 0, T}(f, f_int, A_vec, variable_ordering)
end
end

@inline MonotoneMap(sampler::SoSATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = MonotoneMap(sampler, x, k, sampler.A_vec[k])
function MonotoneMap(sampler::SoSATM{d, <:Any, T}, x::PSDdata{T},
k::Int, coeff::AbstractMatrix{T2}) where {d, T<:Number, T2<:Number}
return sampler.f_int[k](x[1:k], coeff)
end
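
For intuition on the docstring's formula: with the one-dimensional feature map Φ(z) = [1, z] and A = I/2 (so A ⪰ 0 and tr(A) = 1), the integrand Φ(z)'AΦ(z) = (1 + z²)/2 is nonnegative, so the component x ↦ x/2 + x³/6 is monotone by construction. A hand-rolled check, independent of the PSDModel/TraceModel machinery:

Φ(z) = [1.0, z]
A = [0.5 0.0; 0.0 0.5]                 # PSD with unit trace

integrand(z) = Φ(z)' * A * Φ(z)        # (1 + z^2) / 2 ≥ 0
M1(x) = x / 2 + x^3 / 6                # closed form of ∫_0^x integrand dz

@assert integrand(2.0) ≈ 2.5
@assert M1(1.0) ≈ 0.5 + 1 / 6
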

function ∇MonotoneMap(sampler::SoSATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number}
return parameter_gradient(sampler.f_int[k], x[1:k])
end

@inline ∂k_MonotoneMap(sampler::SoSATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = ∂k_MonotoneMap(sampler, x, k, sampler.A_vec[k])
function ∂k_MonotoneMap(sampler::SoSATM{d, <:Any, T}, x::PSDdata{T}, k::Int, coeff) where {d, T<:Number}
return sampler.f[k](x[1:k], coeff)
end

# function ML_fit!(sampler::ATM{d, <:Any, T}, X::PSDDataVector{T}) where {d, T<:Number}
# for k=1:d
# coeff_0 = sampler.coeff[k]
# min_func(coeff::Vector{T}) = begin
# (1/length(x)) * mapreduce(x->(0.5*MonotoneMap(sampler, x, k, coeff))^2 - log(∂k_MonotoneMap(sampler, x, k, coeff)), +, X)
# end
# sampler.coeff[k] = optimize(min_func, coeff_0, BFGS())
# end
# end
function ∇∂k_MonotoneMap(sampler::SoSATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number}
return parameter_gradient(sampler.f[k], x[1:k])
end
8 changes: 6 additions & 2 deletions src/Samplers/triangular_maps/TriangularMap.jl
@@ -48,10 +48,14 @@ end
function _pullback_first_n(sampler::AbstractTriangularMap{d, <:Any, T},
u::PSDdata{T},
n::Int) where {d, T<:Number}
x = zeros(T, n)
x = Vector{T}(undef, n)
u = @view u[sampler.variable_ordering[1:n]]
for k=1:n
x[k] = find_zero(z->MonotoneMap(sampler, [z; x[1:k-1]], k) - u[k], zero(T))
func(z) = begin
@inbounds x[k] = z
return MonotoneMap(sampler, x[1:k], k) - @inbounds u[k]
end
@inbounds x[k] = find_zero(func, zero(T))
end
return invpermute!(x, sampler.variable_ordering[1:n])
end
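
The loop inverts the triangular map one coordinate at a time: with x_1, …, x_{k-1} already recovered, MonotoneMap(sampler, x[1:k], k) = u[k] is a scalar equation that is monotone in x[k], so one root find per coordinate suffices. The same pattern on a hand-written 2-D triangular map (T1, T2 below are illustrative, not from the package):

using Roots: find_zero

T1(x1)     = tanh(x1) + x1             # increasing in x1
T2(x1, x2) = x1^2 + x2^3 + x2          # increasing in x2

u = (0.8, 1.5)
x1 = find_zero(z -> T1(z) - u[1], 0.0)       # solve the first component
x2 = find_zero(z -> T2(x1, z) - u[2], 0.0)   # then the second, reusing x1
@assert T1(x1) ≈ u[1] && T2(x1, x2) ≈ u[2]
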
2 changes: 1 addition & 1 deletion src/SequentialMeasureTransport.jl
@@ -1,6 +1,6 @@
module SequentialMeasureTransport

using LinearAlgebra, SparseArrays
using LinearAlgebra, SparseArrays, StaticArrays
using KernelFunctions: Kernel, kernelmatrix
using DomainSets
using FastGaussQuadrature: gausslegendre
13 changes: 13 additions & 0 deletions src/TraceModels/models.jl
@@ -23,4 +23,17 @@ function (a::TraceModel{T})(x::PSDdata{T}, B::AbstractMatrix{T}) where {T<:Numbe
# return tr(B * M)
# tr(B * M) = tr(M * B) = dot(M', B) = dot(M, B) , but faster evaluation
return dot(M, B)
end

function (a::TraceModel{T})(x::PSDdata{T}, B::AbstractMatrix{T2}) where {T<:Number, T2<:Number}
M = ΦΦT(a, x)
# return tr(B * M)
# tr(B * M) = tr(M * B) = dot(M', B) = dot(M, B) , but faster evaluation
return dot(M, B)
end


function parameter_gradient(a::TraceModel{T}, x::PSDdata{T}) where {T<:Number}
M = ΦΦT(a, x)
return M
end
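Because the model is linear in B, with a(x, B) = tr(B * M) = dot(M, B) and M = ΦΦT(a, x), its gradient with respect to B is simply M, which is what parameter_gradient returns. A quick check of the trace/dot identity quoted in the comment, on arbitrary real Hermitian matrices:

using LinearAlgebra

M = Hermitian(rand(4, 4))
B = Hermitian(rand(4, 4))
@assert tr(B * M) ≈ dot(M, B)          # tr(BM) = tr(MB) = dot(M', B) = dot(M, B)
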
2 changes: 1 addition & 1 deletion src/functions/functions.jl
@@ -1,6 +1,6 @@

# Tensorized functions
include("TensorFunction.jl")
include("tensor_functions/TensorFunction.jl")

# extra special functions needed
include("squared_polynomial.jl")
15 changes: 15 additions & 0 deletions src/functions/tensor_functions/MappedTensorFunction.jl
@@ -0,0 +1,15 @@


struct MappedTensorFunction{d , T , M<:ConditionalMapping{d, <:Any, T}, S<:Tensorizer{d}} <: TensorFunction{d, T, S}
tf::TensorFunction{d, T, S}
mapping::M
end

function (p::MappedTensorFunction{<:Any, T, M})(x::PSDdata{T}) where {T<:Number, M}
return p.tf(pushforward(p.mapping, x)) * Jacobian(p.mapping, x)
end

function (p::MappedTensorFunction{<:Any, T1, M})(x::PSDdata{T2}) where {T1<:Number, T2<:Number, M}
return p.tf(pushforward(p.mapping, x)) * Jacobian(p.mapping, x)
end
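
The wrapper evaluates the underlying tensor function at the pushed-forward point and multiplies by the mapping's Jacobian, i.e. it behaves like the change-of-variables density tf(T(x)) * |det ∇T(x)|, so integrating the wrapped function over the original domain matches integrating tf over the internal one. A 1-D sanity check of that identity with a hand-written map (T, JT and the quadrature are illustrative):

using FastGaussQuadrature: gausslegendre

T(x)  = x^2                            # increasing map on [0, 1]
JT(x) = 2x                             # its Jacobian
tf(u) = 3u^2 + 1

ξ, w = gausslegendre(64)
nodes = ξ .* 0.5 .+ 0.5
weights = w .* 0.5                     # quadrature on [0, 1]

lhs = sum(weights .* (tf.(T.(nodes)) .* JT.(nodes)))   # ∫_0^1 tf(T(x)) JT(x) dx
rhs = sum(weights .* tf.(nodes))                       # ∫_0^1 tf(u) du
@assert abs(lhs - rhs) < 1e-12
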

@@ -3,4 +3,5 @@ include("tensorizers/Tensorizers.jl")

abstract type TensorFunction{d, T, S<:Tensorizer{d}} <: Function end

include("TensorPolynomial.jl")
include("TensorPolynomial.jl")
include("MappedTensorFunction.jl")