-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow access to different nutpie backends via pip-style syntax #7498
base: main
Are you sure you want to change the base?
Conversation
I created some global variables to track available backends in |
I'm not sure how the other backends (specially JAX) behave with multiprocessing tbh :O Otherwise the idea sounds cool. Perhaps @aseyboldt can weigh in as he has a better picture of how we handle multiprocessing in pm.sample |
Should we go work on a |
The PyMC codebase would still need to know about it and work around differently for JAX |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #7498 +/- ##
==========================================
- Coverage 92.85% 92.80% -0.06%
==========================================
Files 105 105
Lines 17591 17612 +21
==========================================
+ Hits 16335 16344 +9
- Misses 1256 1268 +12
|
With numba backend we could actually do it with threads (with nogil)? Maybe worth opening an issue to investigate different backends for pymc samplers. Also this shouldn't have to be nuts specific so for that a |
Should probably just be |
Not if we need to change how the samplers/threads are orchestrated |
But compile kwargs already works anyway |
What do you mean, |
It doesn't? Maybe I've played with global mode then |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jessegrabowski, this looks very nice. I left a few suggestions though
if match is None: | ||
return NUTPIE_DEFAULT_BACKEND |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could also get a None
match if the string is misformatted. For example, nutpie[jax
would return a None
match. I suggest that you test exact equality to set the default option, and if you get None
then raise a ValueError
.
if match is None: | |
return NUTPIE_DEFAULT_BACKEND | |
if string == "nutpie": | |
return NUTPIE_DEFAULT_BACKEND | |
elif match is None: | |
raise ValueError(f"Could not parse nutpie backend. Found {string!r}") |
expected = ( | ||
", ".join([f'"{x}"' for x in NUTPIE_BACKENDS[:-1]]) + f' or "{last_option}"' | ||
) | ||
raise ValueError(f'Expected one of {expected}; found "{result}"') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
raise ValueError(f'Expected one of {expected}; found "{result}"') | |
raise ValueError( | |
'Could not parse nutpie backend. Expected one of {expected}; found "{result}"' | |
) |
return m | ||
|
||
|
||
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) | |
@pytest.mark.parametrize( | |
"nuts_sampler", | |
["pymc", "nutpie", "nutpie[jax]", "blackjax", "numpyro"], | |
) |
|
||
def test_invalid_nutpie_backend_raises(pymc_model): | ||
pytest.importorskip("nutpie") | ||
with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'): | |
with pytest.raises( | |
ValueError, | |
match='Could not parse nutpie backend. Expected one of "numba" or "jax"; found "invalid"', | |
): |
with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'): | ||
with pymc_model: | ||
sample(nuts_sampler="nutpie[invalid]", random_seed=123, chains=2, tune=500, draws=500) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with pytest.raises(ValueError, match="Could not parse nutpie backend. Found 'nutpie[bad'"): | |
with pymc_model: | |
sample(nuts_sampler="nutpie[bad", random_seed=123, chains=2, tune=500, draws=500) |
I posted this PR with comments into GPT o1-mini:
---
SummaryThis pull request introduces a pip-style syntax for specifying different Nutpie backends within the Key Changes:
Detailed Review1. Code Changesa.
b.
c.
d. Documentation Updates
2. Testing Enhancements
3. GitHub Discussion & Iterative Improvements
ConclusionThis pull request effectively enhances PyMC's sampling functionality by introducing a flexible and user-friendly way to specify different Nutpie backends. The implementation is thoughtfully designed, with clear type definitions, robust error handling, and comprehensive testing to ensure reliability and maintainability. Recommendations Before Merging:
Overall, this PR represents a significant improvement to PyMC's flexibility and user experience, enabling more tailored and optimized sampling strategies. With the recommended refinements, it is well-positioned for successful integration into the main codebase. Approved with minor recommendations. |
Close this as stale now that #7535 is merged? Or is there interest in this syntax still. |
This syntax is a bit more ergonomic / discoverable |
763cd87
to
a5b3241
Compare
I like the syntax :-) |
Thoughts on the best way to ask for the gradient backend in this syntax? Or |
That will break people's existing code if they don't have jax installed |
This looks super helpful @jessegrabowski !! |
Description
Adds a pip-style syntax to the
nuts_sampler
argument that allows access to alternative compile backends, when relevant. This lets you get the nutpie jax backend by settingnuts_sampler='nutpie[jax]'
. For backwards compatibility,nuts_sampler='nutpie'
is equivalent tonuts_sampler='nutpie[numba]'
.The current PR only deals with nutpie, but we could easily extend this to include the default PyMC sampler, to compile to JAX, numba, or pytorch directly, without going through nutpie. I'm willing to do that extension in this PR if it is deemed worthwhile..
Related Issue
nutpie
compile backends throughpm.sample
#7497Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7498.org.readthedocs.build/en/7498/