Skip to content

Commit

Permalink
some fix in triangular map
Browse files Browse the repository at this point in the history
  • Loading branch information
benjione committed Oct 15, 2024
1 parent 5ba0454 commit 25eff3e
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 29 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
2 changes: 2 additions & 0 deletions ext/InvertibleNetworks/InvertibleNetworksExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ struct InvertibleNetworksMapping{d, dC, T, NT} <: SMT.ConditionalMapping{d, dC,
end
end

SMT.Sampler(network, forward, inverse, d, T; dC=0) = InvertibleNetworksMapping{d, dC, T}(network, forward, inverse)

function SMT.pushforward(m::InvertibleNetworksMapping{<:Any, <:Any, T}, x::SMT.PSDdata{T}) where {T <: Number}
y = m.forward(reshape(x, 1, 1, length(x), 1))[1]
return reshape(y, length(y))
Expand Down
98 changes: 73 additions & 25 deletions src/Samplers/triangular_maps/ATM.jl
Original file line number Diff line number Diff line change
@@ -1,54 +1,102 @@


struct ATM{d, dC, T<:Number} <: AbstractTriangularMap{d, dC, T}
# abstract type ATM{d, dC, T} <: AbstractTriangularMap{d, dC, T} end

struct PolynomialATM{d, dC, T<:Number} <: AbstractTriangularMap{d, dC, T}
f::Vector{<:FMTensorPolynomial{<:Any, T}}
coeff::Vector{<:Vector{T}}
g::Function
variable_ordering::Vector{Int}
function ATM(f::Vector{FMTensorPolynomial{<:Any, T}}, g::Function, variable_ordering::Vector{Int}, dC::Int) where {T<:Number}
function PolynomialATM(f::Vector{FMTensorPolynomial{<:Any, T}}, g::Function, variable_ordering::Vector{Int}, dC::Int) where {T<:Number}
d = length(f)
coeff = Vector{Vector{T}}(undef, d)
for k=1:d
coeff[k] = randn(T, length(f[k](rand(k))))
end
new{d, dC, T}(f, coeff, g, variable_ordering)
end
function ATM(f::Vector{FMTensorPolynomial}, g::Function, variable_ordering::Vector{Int})
ATM(f, g, variable_ordering, 0)
function PolynomialATM(f::Vector{FMTensorPolynomial}, g::Function, variable_ordering::Vector{Int})
PolynomialATM(f, g, variable_ordering, 0)
end
end


@inline MonotoneMap(sampler::ATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = MonotoneMap(sampler, x, k, sampler.coeff[k])
function MonotoneMap(sampler::ATM{d, <:Any, T}, x::PSDdata{T}, k::Int, coeff) where {d, T<:Number}
f_part(z) = sampler.f[k]([x[1:k-1]; z])
int_x, int_w = gausslegendre(50)
int_x .= int_x * 0.5 .+ 0.5
int_w .= int_w * 0.5
@inline MonotoneMap(sampler::PolynomialATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = MonotoneMap(sampler, x, k, sampler.coeff[k])
function MonotoneMap(sampler::PolynomialATM{d, <:Any, T}, x::PSDdata{T}, k::Int, coeff) where {d, T<:Number}
f_part(z) = begin
sampler.f[k]([x[1:k-1]; z])
end
f_partial(z) = FD.derivative(f_part, z)
int_f(z) = sampler.g(dot(coeff, f_partial(z)))
int_x, int_w = gausslegendre(100)
int_x .= int_x * 0.5 .+ 0.5
int_x .= int_x * x[k]
int_w .= int_w * 0.5
int_w .*= x[k]

_int_x = copy(int_x)
_int_x .= _int_x * x[k]
_int_w = int_w * x[k]


int_part = sum(int_w .* int_f.(int_x))
int_part = sum(_int_w .* int_f.(_int_x))
return dot(coeff, sampler.f[k]([x[1:k-1]; 0])) + int_part
end

@inline ∂k_MonotoneMap(sampler::ATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = ∂k_MonotoneMap(sampler, x, k, sampler.coeff[k])
function ∂k_MonotoneMap(sampler::ATM{d, <:Any, T}, x::PSDdata{T}, k::Int, coeff) where {d, T<:Number}
@inline ∂k_MonotoneMap(sampler::PolynomialATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = ∂k_MonotoneMap(sampler, x, k, sampler.coeff[k])
function ∂k_MonotoneMap(sampler::PolynomialATM{d, <:Any, T}, x::PSDdata{T}, k::Int, coeff) where {d, T<:Number}
f_part(z) = sampler.f[k]([x[1:k-1]; z])
f_partial(z) = FD.derivative(f_part, z)
int_f(z) = sampler.g(dot(coeff, f_partial(z)))
return int_f(x[k])
end

"""
Map of type
f1(x_{1:k-1}) + g(f2(x_{1:k-1})) * x_k
"""
struct PolynomialCouplingATM{d, dC, T} <: AbstractTriangularMap{d, dC, T}
f1::AbstractVector{<:FMTensorPolynomial{<:Any, T}}
f2::AbstractVector{<:FMTensorPolynomial{<:Any, T}}
coeff::AbstractVector{<:AbstractMatrix{T}}
g::Function
# poly_measure::Function
variable_ordering::Vector{Int}
end

function PolynomialCouplingATM(f1, f2, g, variable_ordering)
d = length(f1)
coeff = Vector{Matrix{Float64}}(undef, d)
coeff = [k==1 ? randn(2, 1) : randn(Float64, 2, length(f1[k](rand(k-1)))) for k=1:d]
# for k=1:d
# coeff[k] = randn(Float64, 2, length(f1[k](rand(k))))
# end
PolynomialCouplingATM{d, 0, Float64}(f1, f2, coeff, g, variable_ordering)
end

@inline MonotoneMap(sampler::PolynomialCouplingATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = MonotoneMap(sampler, x, k, sampler.coeff[k])
function MonotoneMap(sampler::PolynomialCouplingATM{d, <:Any, T}, x::PSDdata{T},
k::Int, coeff::AbstractMatrix{T2}) where {d, T<:Number, T2<:Number}
if k==1
return coeff[1, 1] + x[1]
end
# print("here ", exp(-norm(x[1:k-1])^2/2), exp(-norm(x[1:k-1])^2/2) * dot(coeff[2, :], sampler.f2[k](x[1:k-1])))
return dot(coeff[1, :], sampler.f1[k](x[1:k-1])) +
(sampler.g(dot(coeff[2, :], sampler.f2[k](x[1:k-1]))) + T(1e-5)) * x[k]
end

@inline ∂k_MonotoneMap(sampler::PolynomialCouplingATM{d, <:Any, T}, x::PSDdata{T}, k::Int) where {d, T<:Number} = ∂k_MonotoneMap(sampler, x, k, sampler.coeff[k])
function ∂k_MonotoneMap(sampler::PolynomialCouplingATM{d, <:Any, T}, x::PSDdata{T}, k::Int, coeff) where {d, T<:Number}
if k==1
return 1.0
end
return (sampler.g(dot(coeff[2, :], sampler.f2[k](x[1:k-1]))) + T(1e-5))
end


"""
Defined on [0, 1]^d
of type
\\int_{0}^{x_k} Φ(x_{1:k-1}, z)' A Φ(x_{1:k-1}, z) dz with A ⪰ 0, tr(A) = 1
"""
struct SoSATM{d, dC, T} <: TriangularMap{d, dC, T}

# function ML_fit!(sampler::ATM{d, <:Any, T}, X::PSDDataVector{T}) where {d, T<:Number}
# for k=1:d
# coeff_0 = sampler.coeff[k]
# min_func(coeff::Vector{T}) = begin
# (1/length(x)) * mapreduce(x->(0.5*MonotoneMap(sampler, x, k, coeff))^2 - log(∂k_MonotoneMap(sampler, x, k, coeff)), +, X)
# end
# sampler.coeff[k] = optimize(min_func, coeff_0, BFGS())
# end
# end
end
8 changes: 6 additions & 2 deletions src/Samplers/triangular_maps/TriangularMap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,14 @@ end
function _pullback_first_n(sampler::AbstractTriangularMap{d, <:Any, T},
u::PSDdata{T},
n::Int) where {d, T<:Number}
x = zeros(T, n)
x = Vector{T}(undef, n)
u = @view u[sampler.variable_ordering[1:n]]
for k=1:n
x[k] = find_zero(z->MonotoneMap(sampler, [z; x[1:k-1]], k) - u[k], zero(T))
func(z) = begin
@inbounds x[k] = z
return MonotoneMap(sampler, x[1:k], k) - @inbounds u[k]
end
@inbounds x[k] = find_zero(func, zero(T))
end
return invpermute!(x, sampler.variable_ordering[1:n])
end
Expand Down
2 changes: 1 addition & 1 deletion src/SequentialMeasureTransport.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module SequentialMeasureTransport

using LinearAlgebra, SparseArrays
using LinearAlgebra, SparseArrays, StaticArrays
using KernelFunctions: Kernel, kernelmatrix
using DomainSets
using FastGaussQuadrature: gausslegendre
Expand Down
16 changes: 15 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,18 @@ unslice_matrix(A::Vector{Vector{T}}) where {T<:Number} = reduce(hcat, A)

## norm utilitys

nuclearnorm(A::AbstractMatrix) = tr(A)
nuclearnorm(A::AbstractMatrix) = tr(A)


### macro
macro _StaticArrayAppend(A, a::Int, b::Int, z)
str = "SA[ "

for i = a:1:b
str = "$str $A[$i], "
end
str = "$str $z]"
# return str
expr = Meta.parse(str)
return esc(expr)
end

0 comments on commit 25eff3e

Please sign in to comment.