Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Nested rhat MCMC diagnostic #752

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

gil2rok
Copy link
Contributor

@gil2rok gil2rok commented Oct 30, 2024

Implement nested R-hat for Markov chain Monte Carlo (MCMC) diagnostic.

The potential scale reduction factor, also known as R-hat, is a popular MCMC diagnostic from Gelman and Rubin.
R-hat detects convergence of MCMC chains by comparing within chain variance to between chain variance.

Nested r-hat from Margossian et al. better predicts convergence when running thousands of short chains on modern hardware. Nested r-hat uses superchains, collections of MCMC chains, and compares within and between chain and superchain variance.

I am seeking feedback on the code style + API design. The code is somewhat complicated by requiring input_array to have 4 dimensions -- num_superchains, num_chains, num_samples, and num_params -- where most users may expect only 3 (or 2). Tests are also still needed, as well as a brief doc explanation of the math.

Quick nit: why does the existing R-hat function return the potential scale factor after flattening along the sample and chain dimensions? I followed this convention for my implementation.

Addresses issue #278 .

@gil2rok
Copy link
Contributor Author

gil2rok commented Oct 30, 2024

@charlesm93: you may be interested in this.

(For everyone else: Charles is the author of the nested R-hat paper.)

@gil2rok
Copy link
Contributor Author

gil2rok commented Oct 30, 2024

Code style tests are failing because Flake8 is finding extra spaces around operators on line 43 of smc/resampling:

Screenshot 2024-10-30 at 4 22 19 PM

However I do not see what the problem is. Here's the line:

If anyone sees how to fix this issue, please let me know.

Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Great start. Let me know when you add some test.

NDArray of the resulting statistics (r-hat), with the chain and sample dimensions squeezed.

"""
assert input_array.ndim == 4, "The input array must have 4 dimensions."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should relax the ndim, as our input could have multiple dimensions of event shape (ie the random variable is non-scaler).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use keepdims=True and it should works.

@junpenglao
Copy link
Member

Quick nit: why does the existing R-hat function return the potential scale factor after flattening along the sample and chain dimensions? I followed this convention for my implementation.

The current r-hat function does not flatten the sample, but rather squeeze it, so if you have a random variable with shape=(2, 5), the output result could be
shape=(1, 1, 2, 5) or (1, 2, 5, 1)
doing a squzze makes it return rhat the same shape as the random variable.

@gil2rok
Copy link
Contributor Author

gil2rok commented Oct 31, 2024

Thanks so much for the quick feedback! I'll continue working on this next week.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants