-
Notifications
You must be signed in to change notification settings - Fork 106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Nested rhat MCMC diagnostic #752
base: main
Are you sure you want to change the base?
Conversation
@charlesm93: you may be interested in this. (For everyone else: Charles is the author of the nested R-hat paper.) |
Code style tests are failing because Flake8 is finding extra spaces around operators on line 43 of smc/resampling: However I do not see what the problem is. Here's the line: blackjax/blackjax/smc/resampling.py Line 43 in 65ae00e
If anyone sees how to fix this issue, please let me know. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Great start. Let me know when you add some test.
NDArray of the resulting statistics (r-hat), with the chain and sample dimensions squeezed. | ||
|
||
""" | ||
assert input_array.ndim == 4, "The input array must have 4 dimensions." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should relax the ndim, as our input could have multiple dimensions of event shape (ie the random variable is non-scaler).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use keepdims=True
and it should works.
The current r-hat function does not flatten the sample, but rather squeeze it, so if you have a random variable with shape=(2, 5), the output result could be |
Thanks so much for the quick feedback! I'll continue working on this next week. |
Implement nested R-hat for Markov chain Monte Carlo (MCMC) diagnostic.
The potential scale reduction factor, also known as R-hat, is a popular MCMC diagnostic from Gelman and Rubin.
R-hat detects convergence of MCMC chains by comparing within chain variance to between chain variance.
Nested r-hat from Margossian et al. better predicts convergence when running thousands of short chains on modern hardware. Nested r-hat uses superchains, collections of MCMC chains, and compares within and between chain and superchain variance.
I am seeking feedback on the code style + API design. The code is somewhat complicated by requiring
input_array
to have 4 dimensions --num_superchains
,num_chains
,num_samples
, andnum_params
-- where most users may expect only 3 (or 2). Tests are also still needed, as well as a brief doc explanation of the math.Quick nit: why does the existing R-hat function return the potential scale factor after flattening along the sample and chain dimensions? I followed this convention for my implementation.
Addresses issue #278 .