Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import and extend PosteriorStats #431

Open
wants to merge 62 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
43a5630
Add PosteriorStats as dependency
sethaxen Aug 20, 2023
2994a56
Import and reexport PosteriorStats functions
sethaxen Aug 20, 2023
9ffb60b
Forward to PosteriorStats.summarize
sethaxen Aug 20, 2023
a6d16e5
Update docstring
sethaxen Aug 20, 2023
4e99fe8
Update docstring
sethaxen Aug 20, 2023
38dbdac
Forward summarystats to summarize
sethaxen Aug 20, 2023
395a3d8
Simplify mean implementation
sethaxen Aug 20, 2023
4fa204e
Simplify quantile implementation
sethaxen Aug 20, 2023
5c0f35d
Replace hpd with hdi
sethaxen Aug 20, 2023
3660d0e
Deprecate hpd
sethaxen Aug 20, 2023
e3d2d16
Simplify autocor implementation
sethaxen Aug 20, 2023
49af2d9
Remove unused keyword `etype`
sethaxen Aug 20, 2023
bdde660
Explicitly build list of stats
sethaxen Aug 20, 2023
d851307
Simultaneously compute all quantiles
sethaxen Aug 20, 2023
1717483
Print an extra newline
sethaxen Aug 20, 2023
bf06653
Use and export SummaryStats
sethaxen Aug 20, 2023
147b56b
Use SummaryStats in place of ChainsDataFrame
sethaxen Aug 20, 2023
706f29c
Update and repair changerate
sethaxen Aug 20, 2023
2c07794
Remove ChainDataFrame
sethaxen Aug 20, 2023
58ae52a
Update docs
sethaxen Aug 20, 2023
340a694
Increment major version
sethaxen Aug 20, 2023
a67ac11
Increment MCMCChains compat for docs
sethaxen Aug 20, 2023
1af83c8
Refer to processed chains
sethaxen Aug 20, 2023
4aed409
Fix doctest
sethaxen Aug 20, 2023
eef9393
Add back append_chains keyword
sethaxen Aug 20, 2023
6a1b482
Compute all lags simultaneously
sethaxen Aug 20, 2023
c7d2021
Vectorize before autocor
sethaxen Aug 20, 2023
bfe8deb
Correctly insert chain id into name
sethaxen Aug 20, 2023
646e108
Update diagnostic tests
sethaxen Aug 20, 2023
8066b46
Update ess_rhat_tests.jl
sethaxen Aug 20, 2023
34b1da2
Update mcse_tests.jl
sethaxen Aug 20, 2023
5b0db14
Increment MCMCChains compat for tests
sethaxen Aug 20, 2023
fb04e61
Remove ChainDataFrames to tables tests
sethaxen Aug 20, 2023
faf667b
Update summarize_tests.jl
sethaxen Aug 20, 2023
fba84a2
Use crossreference
sethaxen Aug 20, 2023
d51dd65
Update plotting functions
sethaxen Aug 20, 2023
5d13530
Update to use hdi and hdi_prob
sethaxen Aug 20, 2023
57a2921
Fix missing tests
sethaxen Aug 20, 2023
fc96b86
Remove references not defined here.
sethaxen Aug 20, 2023
cbd06d2
Improve vertical spacing
sethaxen Aug 20, 2023
f2bdc68
Bump PosteriorStats compat
sethaxen Aug 20, 2023
313a786
Merge branch 'master' into posteriorstats
sethaxen Oct 24, 2023
93f6ae9
Merge branch 'master' into posteriorstats
sethaxen Dec 24, 2023
c19b134
Bump PosteriorStats compat
sethaxen Dec 24, 2023
02be9c7
Make stack available for older Julia versions
sethaxen Dec 24, 2023
ea2548e
Update SummaryStats constructor user
sethaxen Dec 24, 2023
929c654
Use dict backing for cor summary
sethaxen Dec 24, 2023
5f7aa7c
Remove unused kwargs
sethaxen Dec 24, 2023
49fe1c8
Refactor autocor to avoid large namedtuple
sethaxen Dec 24, 2023
1cef3ba
Use stack for autocor to avoid large compile times
sethaxen Dec 24, 2023
17e08ed
Improve type inference of OrderedDict
sethaxen Dec 24, 2023
1e3e96a
Make doctest reproducible
sethaxen Dec 24, 2023
d001894
Add StableRNGs as test dependency
sethaxen Dec 24, 2023
d6bb4ee
Merge branch 'master' into posteriorstats
sethaxen Feb 6, 2024
b3a3143
Apply suggestions from code review
sethaxen Feb 10, 2024
67f7eaa
Merge branch 'posteriorstats' of https://github.com/TuringLang/MCMCCh…
sethaxen Feb 10, 2024
1fbe8ad
Merge branch 'master' into posteriorstats
sethaxen Feb 11, 2024
0fad7f9
Avoid splatting
sethaxen Feb 11, 2024
3579c55
Avoid recomputing all medians for every parameter
sethaxen Feb 11, 2024
5332916
Add test for show method
sethaxen Feb 11, 2024
d61c535
Merge branch 'master' into posteriorstats
sethaxen Nov 17, 2024
d488220
Update ess_rhat tests
sethaxen Nov 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
version = "6.0.6"
version = "7.0.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
Expand All @@ -17,6 +18,7 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PosteriorStats = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Expand All @@ -29,6 +31,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
[compat]
AbstractMCMC = "0.4, 0.5, 1.0, 2.0, 3.0, 4, 5"
AxisArrays = "0.4.4"
Compat = "4.2.0"
Dates = "<0.0.1, 1"
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
IteratorInterfaceExtensions = "0.1.1, 1"
Expand All @@ -38,6 +41,7 @@ MCMCDiagnosticTools = "0.3"
MLJModelInterface = "0.3.5, 0.4, 1.0"
NaturalSort = "1"
OrderedCollections = "1.4"
PosteriorStats = "0.2"
PrettyTables = "0.9, 0.10, 0.11, 0.12, 1, 2"
Random = "<0.0.1, 1"
RecipesBase = "0.7, 0.8, 1.0"
Expand Down
4 changes: 3 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"

[compat]
Expand All @@ -15,8 +16,9 @@ CategoricalArrays = "0.8, 0.9, 0.10"
DataFrames = "0.22, 1"
Documenter = "0.26, 0.27, 1"
Gadfly = "1.3.4"
MCMCChains = "6"
MCMCChains = "7"
MLJBase = "0.19, 0.20, 0.21, 1"
MLJDecisionTreeInterface = "0.3, 0.4"
StableRNGs = "1"
StatsPlots = "0.14, 0.15"
julia = "1.7"
2 changes: 1 addition & 1 deletion docs/src/stats.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ describe
mean
summarystats
quantile
hpd
hdi
```
5 changes: 2 additions & 3 deletions docs/src/summarize.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

The methods listed below are defined in `src/summarize.jl`.

```@autodocs
Modules = [MCMCChains]
Pages = ["summarize.jl"]
```@docs
summarize
```
9 changes: 6 additions & 3 deletions src/MCMCChains.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module MCMCChains

using Compat: stack
using AxisArrays
const axes = Base.axes
import AbstractMCMC
Expand All @@ -15,6 +16,7 @@ import MCMCDiagnosticTools
import MLJModelInterface
import NaturalSort
import OrderedCollections
import PosteriorStats
import PrettyTables
import StatsFuns
import Tables
Expand All @@ -30,8 +32,6 @@ export setrange, resetrange
export set_section, get_params, sections, sort_sections, setinfo
export replacenames, namesingroup, group
export autocor, describe, sample, summarystats, AbstractWeights, mean, quantile
export ChainDataFrame
export summarize

# Reexport diagnostics functions
using MCMCDiagnosticTools: discretediag, ess, ess_rhat, AutocovMethod, FFTAutocovMethod,
Expand All @@ -46,7 +46,10 @@ export mcse
export rafterydiag
export rstar

export hpd
# Reexport stats functions
using PosteriorStats: SummaryStats, default_diagnostics, default_stats,
default_summary_stats, hdi, summarize
export SummaryStats, hdi, summarize

"""
Chains
Expand Down
7 changes: 5 additions & 2 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,14 @@ function Base.show(io::IO, chains::Chains)
end

function Base.show(io::IO, mime::MIME"text/plain", chains::Chains)
print(io, "Chains ", chains, ":\n\n", header(chains))
println(io, "Chains ", chains, ":\n\n", header(chains))
sethaxen marked this conversation as resolved.
Show resolved Hide resolved

# Show summary stats.
summaries = describe(chains)
for summary in summaries
summary, others = Iterators.peel(summaries)
show(io, mime, summary)
for summary in others
println(io)
println(io)
show(io, mime, summary)
end
Expand Down
16 changes: 8 additions & 8 deletions src/discretediag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ function MCMCDiagnosticTools.discretediag(
_permutedims_diagnostics(_chains.value.data); kwargs...
)

# Create dataframes
parameters = (parameters = names(_chains),)
between_chain_df = ChainDataFrame(
"Chisq diagnostic - Between chains", merge(parameters, between_chain_vals),
# Create SummaryStats
param_names = names(_chains)
between_chain_stats = SummaryStats(
"Chisq diagnostic - Between chains", between_chain_vals, param_names,
)
within_chain_dfs = map(1:size(_chains, 3)) do i
within_chain_stats = map(1:size(_chains, 3)) do i
vals = map(val -> val[:, i], within_chain_vals)
return ChainDataFrame("Chisq diagnostic - Chain $i", merge(parameters, vals))
return SummaryStats("Chisq diagnostic - Chain $i", vals, param_names)
end
dfs = vcat(between_chain_df, within_chain_dfs)
stats = vcat([between_chain_stats], within_chain_stats)

return dfs
return stats
end
12 changes: 6 additions & 6 deletions src/ess_rhat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ function MCMCDiagnosticTools.ess(

# Convert to NamedTuple
ess_per_sec = ess ./ dur
nt = merge((parameters = names(_chains),), (; ess, ess_per_sec))
nt = (; ess, ess_per_sec)

return ChainDataFrame("ESS", nt)
return SummaryStats("ESS", nt, names(_chains))
end

"""
Expand All @@ -48,9 +48,9 @@ function MCMCDiagnosticTools.rhat(
)

# Convert to NamedTuple
nt = merge((parameters = names(_chains),), (; rhat))
nt = (; rhat)

return ChainDataFrame("R-hat", nt)
return SummaryStats("R-hat", nt, names(_chains))
end

"""
Expand Down Expand Up @@ -79,7 +79,7 @@ function MCMCDiagnosticTools.ess_rhat(

# Convert to NamedTuple
ess_per_sec = ess_rhat.ess ./ dur
nt = merge((parameters = names(_chains),), ess_rhat, (; ess_per_sec))
nt = merge(ess_rhat, (; ess_per_sec))

return ChainDataFrame("ESS/R-hat", nt)
return SummaryStats("ESS/R-hat", nt, names(_chains))
end
16 changes: 9 additions & 7 deletions src/gelmandiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ function MCMCDiagnosticTools.gelmandiag(
results = MCMCDiagnosticTools.gelmandiag(_permutedims_diagnostics(psi); kwargs...)

# Create a data frame with the results.
df = ChainDataFrame(
stats = SummaryStats(
"Gelman, Rubin, and Brooks diagnostic",
merge((parameters = names(_chains),), results),
results,
names(_chains),
)

return df
return stats
end

function MCMCDiagnosticTools.gelmandiag_multivariate(
Expand All @@ -36,11 +37,12 @@ function MCMCDiagnosticTools.gelmandiag_multivariate(
kwargs...,
)

# Create a data frame with the results.
df = ChainDataFrame(
# Create SummaryStats with the results.
stats = SummaryStats(
"Gelman, Rubin, and Brooks diagnostic",
(parameters = names(_chains), psrf = results.psrf, psrfci = results.psrfci),
(psrf = results.psrf, psrfci = results.psrfci),
names(_chains),
)

return df, results.psrfmultivariate
return stats, results.psrfmultivariate
end
9 changes: 4 additions & 5 deletions src/gewekediag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ function MCMCDiagnosticTools.gewekediag(
return namedtuple_of_vecs
end

# Create data frames.
parameters = (parameters = names(_chains),)
dfs = [
ChainDataFrame("Geweke diagnostic - Chain $i", merge(parameters, result))
# Create SummaryStats.
stats = [
SummaryStats("Geweke diagnostic - Chain $i", result, names(_chains))
for (i, result) in enumerate(results)
]

return dfs
return stats
end
11 changes: 5 additions & 6 deletions src/heideldiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@ function MCMCDiagnosticTools.heideldiag(
return namedtuple_of_vecs
end

# Create data frames.
parameters = (parameters = names(_chains),)
dfs = [
ChainDataFrame(
"Heidelberger and Welch diagnostic - Chain $i", merge(parameters, result)
# Create SummaryStats.
stats = [
SummaryStats(
"Heidelberger and Welch diagnostic - Chain $i", result, names(_chains),
)
for (i, result) in enumerate(results)
]

return dfs
return stats
end
4 changes: 2 additions & 2 deletions src/mcse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function MCMCDiagnosticTools.mcse(
kwargs...,
)

nt = merge((parameters = names(_chains),), (; mcse))
nt = (; mcse)

return ChainDataFrame("MCSE", nt)
return SummaryStats("MCSE", nt, names(_chains))
end
Loading
Loading