-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow access to different nutpie backends via pip-style syntax #7498
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -17,17 +17,13 @@ | |||||||||
import contextlib | ||||||||||
import logging | ||||||||||
import pickle | ||||||||||
import re | ||||||||||
import sys | ||||||||||
import time | ||||||||||
import warnings | ||||||||||
|
||||||||||
from collections.abc import Callable, Iterator, Mapping, Sequence | ||||||||||
from typing import ( | ||||||||||
Any, | ||||||||||
Literal, | ||||||||||
TypeAlias, | ||||||||||
overload, | ||||||||||
) | ||||||||||
from typing import Any, Literal, TypeAlias, cast, get_args, overload | ||||||||||
|
||||||||||
import numpy as np | ||||||||||
import pytensor.gradient as tg | ||||||||||
|
@@ -86,6 +82,14 @@ | |||||||||
|
||||||||||
Step: TypeAlias = BlockedStep | CompoundStep | ||||||||||
|
||||||||||
ExternalNutsSampler = Literal["nutpie", "numpyro", "blackjax"] | ||||||||||
NutsSampler = Literal["pymc"] | ExternalNutsSampler | ||||||||||
NutpieBackend = Literal["numba", "jax"] | ||||||||||
|
||||||||||
|
||||||||||
NUTPIE_BACKENDS = get_args(NutpieBackend) | ||||||||||
NUTPIE_DEFAULT_BACKEND = cast(NutpieBackend, "numba") | ||||||||||
|
||||||||||
|
||||||||||
class SamplingIteratorCallback(Protocol): | ||||||||||
"""Signature of the callable that may be passed to `pm.sample(callable=...)`.""" | ||||||||||
|
@@ -262,7 +266,7 @@ | |||||||||
|
||||||||||
|
||||||||||
def _sample_external_nuts( | ||||||||||
sampler: Literal["nutpie", "numpyro", "blackjax"], | ||||||||||
sampler: ExternalNutsSampler, | ||||||||||
draws: int, | ||||||||||
tune: int, | ||||||||||
chains: int, | ||||||||||
|
@@ -280,7 +284,7 @@ | |||||||||
if nuts_sampler_kwargs is None: | ||||||||||
nuts_sampler_kwargs = {} | ||||||||||
|
||||||||||
if sampler == "nutpie": | ||||||||||
if sampler.startswith("nutpie"): | ||||||||||
try: | ||||||||||
import nutpie | ||||||||||
except ImportError as err: | ||||||||||
|
@@ -313,6 +317,23 @@ | |||||||||
model, | ||||||||||
**compile_kwargs, | ||||||||||
) | ||||||||||
|
||||||||||
def extract_backend(string: str) -> NutpieBackend: | ||||||||||
match = re.search(r"(?<=\[)[^\]]+(?=\])", string) | ||||||||||
if match is None: | ||||||||||
return NUTPIE_DEFAULT_BACKEND | ||||||||||
result = cast(NutpieBackend, match.group(0)) | ||||||||||
if result not in NUTPIE_BACKENDS: | ||||||||||
last_option = f"{NUTPIE_BACKENDS[-1]}" | ||||||||||
expected = ( | ||||||||||
", ".join([f'"{x}"' for x in NUTPIE_BACKENDS[:-1]]) + f' or "{last_option}"' | ||||||||||
) | ||||||||||
raise ValueError(f'Expected one of {expected}; found "{result}"') | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
return result | ||||||||||
|
||||||||||
backend = extract_backend(sampler) | ||||||||||
compiled_model = nutpie.compile_pymc_model(model, backend=backend) | ||||||||||
|
||||||||||
t_start = time.time() | ||||||||||
idata = nutpie.sample( | ||||||||||
compiled_model, | ||||||||||
|
@@ -361,6 +382,10 @@ | |||||||||
elif sampler in ("numpyro", "blackjax"): | ||||||||||
import pymc.sampling.jax as pymc_jax | ||||||||||
|
||||||||||
from pymc.sampling.jax import JaxNutsSampler | ||||||||||
|
||||||||||
sampler = cast(JaxNutsSampler, sampler) | ||||||||||
|
||||||||||
idata = pymc_jax.sample_jax_nuts( | ||||||||||
draws=draws, | ||||||||||
tune=tune, | ||||||||||
|
@@ -396,7 +421,7 @@ | |||||||||
progressbar_theme: Theme | None = default_progress_theme, | ||||||||||
step=None, | ||||||||||
var_names: Sequence[str] | None = None, | ||||||||||
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | ||||||||||
nuts_sampler: NutsSampler = "pymc", | ||||||||||
initvals: StartDict | Sequence[StartDict | None] | None = None, | ||||||||||
init: str = "auto", | ||||||||||
jitter_max_retries: int = 10, | ||||||||||
|
@@ -427,7 +452,7 @@ | |||||||||
progressbar_theme: Theme | None = default_progress_theme, | ||||||||||
step=None, | ||||||||||
var_names: Sequence[str] | None = None, | ||||||||||
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | ||||||||||
nuts_sampler: NutsSampler = "pymc", | ||||||||||
initvals: StartDict | Sequence[StartDict | None] | None = None, | ||||||||||
init: str = "auto", | ||||||||||
jitter_max_retries: int = 10, | ||||||||||
|
@@ -458,7 +483,7 @@ | |||||||||
progressbar_theme: Theme | None = default_progress_theme, | ||||||||||
step=None, | ||||||||||
var_names: Sequence[str] | None = None, | ||||||||||
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | ||||||||||
nuts_sampler: NutsSampler = "pymc", | ||||||||||
initvals: StartDict | Sequence[StartDict | None] | None = None, | ||||||||||
init: str = "auto", | ||||||||||
jitter_max_retries: int = 10, | ||||||||||
|
@@ -517,8 +542,10 @@ | |||||||||
method will be used, if appropriate to the model. | ||||||||||
var_names : list of str, optional | ||||||||||
Names of variables to be stored in the trace. Defaults to all free variables and deterministics. | ||||||||||
nuts_sampler : str | ||||||||||
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. | ||||||||||
nuts_sampler : str, default "pymc" | ||||||||||
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. In addition, the compilation | ||||||||||
backend for the chosen sampler can be set using square brackets, if available. For example, "nutpie[jax]" will | ||||||||||
use the JAX backend for the nutpie sampler. Currently, "nutpie[jax]" and "nutpie[numba]" are allowed. | ||||||||||
This requires the chosen sampler to be installed. | ||||||||||
All samplers, except "pymc", require the full model to be continuous. | ||||||||||
blas_cores: int or "auto" or None, default = "auto" | ||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -16,45 +16,19 @@ | |||||||||||
import numpy.testing as npt | ||||||||||||
import pytest | ||||||||||||
|
||||||||||||
from pymc import Data, Model, Normal, sample | ||||||||||||
from pymc import Data, Model, Normal, modelcontext, sample | ||||||||||||
|
||||||||||||
|
||||||||||||
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) | ||||||||||||
def test_external_nuts_sampler(recwarn, nuts_sampler): | ||||||||||||
if nuts_sampler != "pymc": | ||||||||||||
pytest.importorskip(nuts_sampler) | ||||||||||||
|
||||||||||||
with Model(): | ||||||||||||
x = Normal("x", 100, 5) | ||||||||||||
y = Data("y", [1, 2, 3, 4]) | ||||||||||||
Data("z", [100, 190, 310, 405]) | ||||||||||||
|
||||||||||||
Normal("L", mu=x, sigma=0.1, observed=y) | ||||||||||||
|
||||||||||||
kwargs = { | ||||||||||||
"nuts_sampler": nuts_sampler, | ||||||||||||
"random_seed": 123, | ||||||||||||
"chains": 2, | ||||||||||||
"tune": 500, | ||||||||||||
"draws": 500, | ||||||||||||
"progressbar": False, | ||||||||||||
"initvals": {"x": 0.0}, | ||||||||||||
} | ||||||||||||
|
||||||||||||
idata1 = sample(**kwargs) | ||||||||||||
idata2 = sample(**kwargs) | ||||||||||||
def check_external_sampler_output(warns, idata1, idata2, sample_kwargs): | ||||||||||||
nuts_sampler = sample_kwargs["nuts_sampler"] | ||||||||||||
reference_kwargs = sample_kwargs.copy() | ||||||||||||
reference_kwargs["nuts_sampler"] = "pymc" | ||||||||||||
|
||||||||||||
reference_kwargs = kwargs.copy() | ||||||||||||
reference_kwargs["nuts_sampler"] = "pymc" | ||||||||||||
with modelcontext(None): | ||||||||||||
idata_reference = sample(**reference_kwargs) | ||||||||||||
|
||||||||||||
warns = { | ||||||||||||
(warn.category, warn.message.args[0]) | ||||||||||||
for warn in recwarn | ||||||||||||
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning) | ||||||||||||
} | ||||||||||||
expected = set() | ||||||||||||
if nuts_sampler == "nutpie": | ||||||||||||
if nuts_sampler.startswith("nutpie"): | ||||||||||||
expected.add( | ||||||||||||
( | ||||||||||||
UserWarning, | ||||||||||||
|
@@ -74,7 +48,84 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): | |||||||||||
assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys() | ||||||||||||
|
||||||||||||
|
||||||||||||
@pytest.fixture | ||||||||||||
def pymc_model(): | ||||||||||||
with Model() as m: | ||||||||||||
x = Normal("x", 100, 5) | ||||||||||||
y = Data("y", [1, 2, 3, 4]) | ||||||||||||
Data("z", [100, 190, 310, 405]) | ||||||||||||
|
||||||||||||
Normal("L", mu=x, sigma=0.1, observed=y) | ||||||||||||
|
||||||||||||
return m | ||||||||||||
|
||||||||||||
|
||||||||||||
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
def test_external_nuts_sampler(pymc_model, recwarn, nuts_sampler): | ||||||||||||
if nuts_sampler != "pymc": | ||||||||||||
pytest.importorskip(nuts_sampler) | ||||||||||||
|
||||||||||||
sample_kwargs = dict( | ||||||||||||
nuts_sampler=nuts_sampler, | ||||||||||||
random_seed=123, | ||||||||||||
chains=2, | ||||||||||||
tune=500, | ||||||||||||
draws=500, | ||||||||||||
progressbar=False, | ||||||||||||
initvals={"x": 0.0}, | ||||||||||||
) | ||||||||||||
|
||||||||||||
with pymc_model: | ||||||||||||
idata1 = sample(**sample_kwargs) | ||||||||||||
idata2 = sample(**sample_kwargs) | ||||||||||||
|
||||||||||||
warns = { | ||||||||||||
(warn.category, warn.message.args[0]) | ||||||||||||
for warn in recwarn | ||||||||||||
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning) | ||||||||||||
} | ||||||||||||
|
||||||||||||
check_external_sampler_output(warns, idata1, idata2, sample_kwargs) | ||||||||||||
|
||||||||||||
|
||||||||||||
@pytest.mark.parametrize("backend", ["numba", "jax"], ids=["numba", "jax"]) | ||||||||||||
def test_numba_backend_options(pymc_model, recwarn, backend): | ||||||||||||
pytest.importorskip("nutpie") | ||||||||||||
pytest.importorskip(backend) | ||||||||||||
|
||||||||||||
sample_kwargs = dict( | ||||||||||||
nuts_sampler=f"nutpie[{backend}]", | ||||||||||||
random_seed=123, | ||||||||||||
chains=2, | ||||||||||||
tune=500, | ||||||||||||
draws=500, | ||||||||||||
progressbar=False, | ||||||||||||
initvals={"x": 0.0}, | ||||||||||||
) | ||||||||||||
|
||||||||||||
with pymc_model: | ||||||||||||
idata1 = sample(**sample_kwargs) | ||||||||||||
idata2 = sample(**sample_kwargs) | ||||||||||||
|
||||||||||||
warns = { | ||||||||||||
(warn.category, warn.message.args[0]) | ||||||||||||
for warn in recwarn | ||||||||||||
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning) | ||||||||||||
} | ||||||||||||
|
||||||||||||
check_external_sampler_output(warns, idata1, idata2, sample_kwargs) | ||||||||||||
|
||||||||||||
|
||||||||||||
def test_invalid_nutpie_backend_raises(pymc_model): | ||||||||||||
pytest.importorskip("nutpie") | ||||||||||||
with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'): | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
with pymc_model: | ||||||||||||
sample(nuts_sampler="nutpie[invalid]", random_seed=123, chains=2, tune=500, draws=500) | ||||||||||||
|
||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
||||||||||||
def test_step_args(): | ||||||||||||
pytest.importorskip("numpyro") | ||||||||||||
|
||||||||||||
with Model() as model: | ||||||||||||
a = Normal("a") | ||||||||||||
idata = sample( | ||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could also get a
None
match if the string is misformatted. For example,nutpie[jax
would return aNone
match. I suggest that you test exact equality to set the default option, and if you getNone
then raise aValueError
.