Skip to content

Commit

Permalink
added pretty printing of models and automatic timestep choosing in di…
Browse files Browse the repository at this point in the history
…ffusion bridge
  • Loading branch information
benjione committed Oct 13, 2023
1 parent d8add37 commit 0448246
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/PSDModels/feature_map/polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ end

@inline _tensorizer(a::PSDModelPolynomial) = a.Φ.ten

## Pretty printing
function Base.show(io::IO, a::PSDModelPolynomial{d, T, S}) where {d, T, S}
println(io, "PSDModelPolynomial{d=$d, T=$T, S=$S}")
println(io, " matrix size: ", size(a.B))
println(io, " Φ: ", a.Φ)
end

domain_interval(a::PSDModelPolynomial{d, T}, k::Int) where {d, T<:Number} = begin
@assert 1 k d
Expand Down
7 changes: 7 additions & 0 deletions src/Samplers/PSDModelSampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ end

Sampler(model::PSDModelOrthonormal{d}) where {d} = PSDModelSampler(model, collect(1:d))

## Pretty printing
function Base.show(io::IO, sampler::PSDModelSampler{d, T, S, R}) where {d, T, S, R}
println(io, "PSDModelSampler{d=$d, T=$T, S=$S, R=$R}")
println(io, " model: ", sampler.model)
println(io, " order of variables: ", sampler.variable_ordering)
end

function Distributions.pdf(
sar::PSDModelSampler{d, T},
x::PSDdata{T}
Expand Down
15 changes: 15 additions & 0 deletions src/Samplers/SelfReinforcedSampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ struct SelfReinforcedSampler{d, T, R} <: Sampler{d, T, R}
end
end

## Pretty printing
function Base.show(io::IO, sra::SelfReinforcedSampler{d, T}) where {d, T<:Number}
println(io, "SelfReinforcedSampler{d=$d, T=$T}")
println(io, " samplers:")
for (i, sampler) in enumerate(sra.samplers)
if i>3
println(io, "...")
break
end
println(io, " $i: $sampler")
end
println(io, " reference map: $(sra.R_map)")
end

## Overwrite pdf function from Distributions
function Distributions.pdf(
sar::SelfReinforcedSampler{d, T},
x::PSDdata{T}
Expand Down
21 changes: 21 additions & 0 deletions src/Samplers/bridging/diffusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,29 @@ struct DiffusionBrigdingDensity{d, T} <: BridgingDensity{d, T}
σ::T) where {d, T<:Number}
new{d, T}(target_density, t_vec, σ)
end
function DiffusionBrigdingDensity{d}(target_density::Function,
N::Int) where {d, T<:Number}
t_vec = choosing_timesteps(0.5, d, N)
new{d, T}(target_density, t_vec, 1.0)
end
end

function choosing_timesteps::T, d, N::Int) where { T<:Number}
@assert β > 1.0
## by Proposition 6
next_t(t_previous) = begin
return -0.5 * log(1.0 -
(1/β^(2/d)) * (1.0 - exp(-2.0 * t_previous)))
end
next_t() = return -0.5 * log(1.0 - (1/β^(2/d)))
t_vec = Vector{T}(undef, N)
t_vec[1] = next_t()
for i=2:(N-1)
t_vec[i] = next_t(t_vec[i-1])
end
t_vec[end] = 0.0
return t_vec
end

function evolve_samples(bridge::DiffusionBrigdingDensity{<:Any, T},
X::PSDDataVector{T},
Expand Down
3 changes: 3 additions & 0 deletions src/Samplers/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ while a mapping does not have any definition of a reference or target by itself.
abstract type Sampler{d, T, R} <: Mapping{d, T} end

Sampler(model::PSDModel) = @error "not implemented for this type of PSDModel"
function Base.show(io::IO, sampler::Sampler{d, T, R}) where {d, T, R}
println(io, "Sampler{d=$d, T=$T, R=$R}")
end
Distributions.pdf(sampler::Sampler, x::PSDdata) = @error "not implemented for this type of Sampler"


Expand Down
8 changes: 8 additions & 0 deletions src/functions/TensorPolynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ end
@inline σ(p::FMTensorPolynomial{<:Any, <:Any, S}, i) where {S<:Tensorizer} = σ(p.ten, i)
@inline σ_inv(p::FMTensorPolynomial{<:Any, <:Any, S}, i) where {S<:Tensorizer} = σ_inv(p.ten, i)

## Pretty printing
Base.show(io::IO, p::FMTensorPolynomial{d, T, S, tsp}) where {d, T, S<:Tensorizer, tsp<:TensorSpace} = begin
println(io, "FMTensorPolynomial{d=$d, T=$T, ...}")
println(io, " space: ", p.space)
println(io, " highest order: ", p.highest_order)
println(io, " N: ", p.N)
end

function add_index(p::FMTensorPolynomial{d, T}, index::Vector{Int}) where {d, T}
ten = deepcopy(p.ten)
add_index!(ten, index)
Expand Down

0 comments on commit 0448246

Please sign in to comment.