Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow access to different nutpie backends via pip-style syntax #7498

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@
"sample_numpyro_nuts",
)

JaxNutsSampler = Literal["numpyro", "blackjax"]


@jax_funcify.register(Assert)
@jax_funcify.register(CheckParameterValue)
Expand Down Expand Up @@ -486,7 +488,7 @@ def sample_jax_nuts(
postprocessing_chunks=None,
idata_kwargs: dict | None = None,
compute_convergence_checks: bool = True,
nuts_sampler: Literal["numpyro", "blackjax"],
nuts_sampler: JaxNutsSampler,
) -> az.InferenceData:
"""
Draw samples from the posterior using a jax NUTS method.
Expand Down
53 changes: 40 additions & 13 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,13 @@
import contextlib
import logging
import pickle
import re
import sys
import time
import warnings

from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import (
Any,
Literal,
TypeAlias,
overload,
)
from typing import Any, Literal, TypeAlias, cast, get_args, overload

import numpy as np
import pytensor.gradient as tg
Expand Down Expand Up @@ -86,6 +82,14 @@

Step: TypeAlias = BlockedStep | CompoundStep

ExternalNutsSampler = Literal["nutpie", "numpyro", "blackjax"]
NutsSampler = Literal["pymc"] | ExternalNutsSampler
NutpieBackend = Literal["numba", "jax"]


NUTPIE_BACKENDS = get_args(NutpieBackend)
NUTPIE_DEFAULT_BACKEND = cast(NutpieBackend, "numba")


class SamplingIteratorCallback(Protocol):
"""Signature of the callable that may be passed to `pm.sample(callable=...)`."""
Expand Down Expand Up @@ -262,7 +266,7 @@


def _sample_external_nuts(
sampler: Literal["nutpie", "numpyro", "blackjax"],
sampler: ExternalNutsSampler,
draws: int,
tune: int,
chains: int,
Expand All @@ -280,7 +284,7 @@
if nuts_sampler_kwargs is None:
nuts_sampler_kwargs = {}

if sampler == "nutpie":
if sampler.startswith("nutpie"):
try:
import nutpie
except ImportError as err:
Expand Down Expand Up @@ -313,6 +317,23 @@
model,
**compile_kwargs,
)

def extract_backend(string: str) -> NutpieBackend:
match = re.search(r"(?<=\[)[^\]]+(?=\])", string)
if match is None:
return NUTPIE_DEFAULT_BACKEND
Comment on lines +323 to +324
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also get a None match if the string is misformatted. For example, nutpie[jax would return a None match. I suggest that you test exact equality to set the default option, and if you get None then raise a ValueError.

Suggested change
if match is None:
return NUTPIE_DEFAULT_BACKEND
if string == "nutpie":
return NUTPIE_DEFAULT_BACKEND
elif match is None:
raise ValueError(f"Could not parse nutpie backend. Found {string!r}")

result = cast(NutpieBackend, match.group(0))
if result not in NUTPIE_BACKENDS:
last_option = f"{NUTPIE_BACKENDS[-1]}"
expected = (

Check warning on line 328 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L321-L328

Added lines #L321 - L328 were not covered by tests
", ".join([f'"{x}"' for x in NUTPIE_BACKENDS[:-1]]) + f' or "{last_option}"'
)
raise ValueError(f'Expected one of {expected}; found "{result}"')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise ValueError(f'Expected one of {expected}; found "{result}"')
raise ValueError(
'Could not parse nutpie backend. Expected one of {expected}; found "{result}"'
)

return result

Check warning on line 332 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L331-L332

Added lines #L331 - L332 were not covered by tests

backend = extract_backend(sampler)
compiled_model = nutpie.compile_pymc_model(model, backend=backend)

Check warning on line 335 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L334-L335

Added lines #L334 - L335 were not covered by tests

t_start = time.time()
idata = nutpie.sample(
compiled_model,
Expand Down Expand Up @@ -361,6 +382,10 @@
elif sampler in ("numpyro", "blackjax"):
import pymc.sampling.jax as pymc_jax

from pymc.sampling.jax import JaxNutsSampler

sampler = cast(JaxNutsSampler, sampler)

idata = pymc_jax.sample_jax_nuts(
draws=draws,
tune=tune,
Expand Down Expand Up @@ -396,7 +421,7 @@
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: NutsSampler = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand Down Expand Up @@ -427,7 +452,7 @@
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: NutsSampler = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand Down Expand Up @@ -458,7 +483,7 @@
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: NutsSampler = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand Down Expand Up @@ -517,8 +542,10 @@
method will be used, if appropriate to the model.
var_names : list of str, optional
Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
nuts_sampler : str
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
nuts_sampler : str, default "pymc"
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. In addition, the compilation
backend for the chosen sampler can be set using square brackets, if available. For example, "nutpie[jax]" will
use the JAX backend for the nutpie sampler. Currently, "nutpie[jax]" and "nutpie[numba]" are allowed.
This requires the chosen sampler to be installed.
All samplers, except "pymc", require the full model to be continuous.
blas_cores: int or "auto" or None, default = "auto"
Expand Down
117 changes: 84 additions & 33 deletions tests/sampling/test_mcmc_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,19 @@
import numpy.testing as npt
import pytest

from pymc import Data, Model, Normal, sample
from pymc import Data, Model, Normal, modelcontext, sample


@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
def test_external_nuts_sampler(recwarn, nuts_sampler):
if nuts_sampler != "pymc":
pytest.importorskip(nuts_sampler)

with Model():
x = Normal("x", 100, 5)
y = Data("y", [1, 2, 3, 4])
Data("z", [100, 190, 310, 405])

Normal("L", mu=x, sigma=0.1, observed=y)

kwargs = {
"nuts_sampler": nuts_sampler,
"random_seed": 123,
"chains": 2,
"tune": 500,
"draws": 500,
"progressbar": False,
"initvals": {"x": 0.0},
}

idata1 = sample(**kwargs)
idata2 = sample(**kwargs)
def check_external_sampler_output(warns, idata1, idata2, sample_kwargs):
nuts_sampler = sample_kwargs["nuts_sampler"]
reference_kwargs = sample_kwargs.copy()
reference_kwargs["nuts_sampler"] = "pymc"

reference_kwargs = kwargs.copy()
reference_kwargs["nuts_sampler"] = "pymc"
with modelcontext(None):
idata_reference = sample(**reference_kwargs)

warns = {
(warn.category, warn.message.args[0])
for warn in recwarn
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
}
expected = set()
if nuts_sampler == "nutpie":
if nuts_sampler.startswith("nutpie"):
expected.add(
(
UserWarning,
Expand All @@ -74,7 +48,84 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys()


@pytest.fixture
def pymc_model():
with Model() as m:
x = Normal("x", 100, 5)
y = Data("y", [1, 2, 3, 4])
Data("z", [100, 190, 310, 405])

Normal("L", mu=x, sigma=0.1, observed=y)

return m


@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
@pytest.mark.parametrize(
"nuts_sampler",
["pymc", "nutpie", "nutpie[jax]", "blackjax", "numpyro"],
)

def test_external_nuts_sampler(pymc_model, recwarn, nuts_sampler):
if nuts_sampler != "pymc":
pytest.importorskip(nuts_sampler)

sample_kwargs = dict(
nuts_sampler=nuts_sampler,
random_seed=123,
chains=2,
tune=500,
draws=500,
progressbar=False,
initvals={"x": 0.0},
)

with pymc_model:
idata1 = sample(**sample_kwargs)
idata2 = sample(**sample_kwargs)

warns = {
(warn.category, warn.message.args[0])
for warn in recwarn
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
}

check_external_sampler_output(warns, idata1, idata2, sample_kwargs)


@pytest.mark.parametrize("backend", ["numba", "jax"], ids=["numba", "jax"])
def test_numba_backend_options(pymc_model, recwarn, backend):
pytest.importorskip("nutpie")
pytest.importorskip(backend)

sample_kwargs = dict(
nuts_sampler=f"nutpie[{backend}]",
random_seed=123,
chains=2,
tune=500,
draws=500,
progressbar=False,
initvals={"x": 0.0},
)

with pymc_model:
idata1 = sample(**sample_kwargs)
idata2 = sample(**sample_kwargs)

warns = {
(warn.category, warn.message.args[0])
for warn in recwarn
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
}

check_external_sampler_output(warns, idata1, idata2, sample_kwargs)


def test_invalid_nutpie_backend_raises(pymc_model):
pytest.importorskip("nutpie")
with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'):
with pytest.raises(
ValueError,
match='Could not parse nutpie backend. Expected one of "numba" or "jax"; found "invalid"',
):

with pymc_model:
sample(nuts_sampler="nutpie[invalid]", random_seed=123, chains=2, tune=500, draws=500)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
with pytest.raises(ValueError, match="Could not parse nutpie backend. Found 'nutpie[bad'"):
with pymc_model:
sample(nuts_sampler="nutpie[bad", random_seed=123, chains=2, tune=500, draws=500)


def test_step_args():
pytest.importorskip("numpyro")

with Model() as model:
a = Normal("a")
idata = sample(
Expand Down
Loading