-
Notifications
You must be signed in to change notification settings - Fork 106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Migrate progress bar from fastprogress to tqdm #655
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #655 +/- ##
==========================================
- Coverage 98.87% 98.80% -0.07%
==========================================
Files 59 59
Lines 2745 2752 +7
==========================================
+ Hits 2714 2719 +5
- Misses 31 33 +2 ☔ View full report in Codecov by Sentry. |
@@ -14,70 +14,83 @@ | |||
"""Progress bar decorators for use with step functions. | |||
Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`. | |||
""" | |||
from fastprogress.fastprogress import progress_bar | |||
import jax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from jax.debug import callback
?
def progress_bar_scan(num_samples, print_rate=None): | ||
"Progress bar for a JAX scan" | ||
progress_bars = {} | ||
def progress_bar_scan(num_samples, num_chains=1, print_rate=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC in the usage we need to specify the num_chains
for pmap to work properly. Could you explain a bit more how you are planning to change the API for downstream application so that part works?
Line 198 in 7cf4f9d
one_step = progress_bar_scan(num_steps)(_one_step) |
one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not fully committed to this API, but I was thinking something where along with passing an array of iteration numbers, you also pass in the chain you are currently in. I think this is better than the numpyro design where you are using regexes on device objects to guess what chain to put the computation on.
def inference_loop(rng_key, kernel, initial_state, chain, num_samples, num_chains):
def _one_step(state, xs):
_, _, rng_key = xs
state, _ = kernel(rng_key, state)
return state, state
one_step = jax.jit(progress_bar_factory(num_samples, num_chains)(_one_step))
keys = jax.random.split(rng_key, num_samples)
_, states = jax.lax.scan(
one_step,
initial_state,
(np.arange(num_samples), chain * np.ones(num_samples), keys),
)
return states
inference_loop_multiple_chains = jax.pmap(
inference_loop,
in_axes=(0, None, 0, 0, None, None),
static_broadcasted_argnums=(1, 4, 5),
devices=jax.devices(),
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For downstream applications that don't use multiple chains, I have included logic to maintain backward compatibility. Though I'm not sure how actual code is implementing progress bars for multiple chains today.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, could you share a small jupyter notebook how it looks like?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a self-contained example
https://gist.github.com/zaxtax/5fd7c881c6ac83a7ca2798d0a7e230b7
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, that's very helpful. Let me think about it a bit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfectly happy to rework the API. This is was an attempt to make something simple and backwards compatible.
Do we still want to migrate to tqdm?
…On Fri, 27 Sept 2024, 09:55 Junpeng Lao, ***@***.***> wrote:
@zaxtax <https://github.com/zaxtax> should we update this after #712
<#712> and get it merge?
—
Reply to this email directly, view it on GitHub
<#655 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAACCUOYVZRJOXX7BDEYNFDZYUFNPAVCNFSM6AAAAABO6TKKOCVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDGNZYGY2DIMRZHA>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
I forgot the reason why we were doing this beside the pmap bug (which is now fixed) |
Moves to use tqdm along with adding support for multiple progress bars
Makes blackjax suitable for running multiple chains in parallel.