diff --git a/pyro/infer/importance.py b/pyro/infer/importance.py index bcb144b170..ca088645cb 100644 --- a/pyro/infer/importance.py +++ b/pyro/infer/importance.py @@ -3,6 +3,7 @@ import math import warnings +from typing import List, Union import torch @@ -15,7 +16,13 @@ from .util import plate_log_prob_sum -class WeightAnalytics: +class LogWeightsMixin: + """ + Mixin class to compute analytics from a ``.log_weights`` attribute. + """ + + log_weights: Union[List[Union[float, torch.Tensor]], torch.Tensor] + def get_log_normalizer(self): """ Estimator of the normalizing constant of the target distribution. @@ -67,7 +74,7 @@ def get_ESS(self): return ess -class Importance(TracePosterior, WeightAnalytics): +class Importance(TracePosterior, LogWeightsMixin): """ :param model: probabilistic model defined as a function :param guide: guide used for sampling defined as a function diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py index 3d020f0fdb..ea89aff5e5 100644 --- a/pyro/infer/predictive.py +++ b/pyro/infer/predictive.py @@ -2,14 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import warnings +from dataclasses import dataclass from functools import reduce -from typing import List, NamedTuple, Union +from typing import List, Union import torch import pyro import pyro.poutine as poutine -from pyro.infer.importance import WeightAnalytics +from pyro.infer.importance import LogWeightsMixin from pyro.infer.util import plate_log_prob_sum from pyro.poutine.trace_struct import Trace from pyro.poutine.util import prune_subsample_sites @@ -35,7 +36,8 @@ def _guess_max_plate_nesting(model, args, kwargs): return max_plate_nesting -class _predictiveResults(NamedTuple): +@dataclass(frozen=True, eq=False) +class _predictiveResults: """ Return value of call to ``_predictive`` and ``_predictive_sequential``. """ @@ -317,19 +319,16 @@ def get_vectorized_trace(self, *args, **kwargs): ).trace -class WeighedPredictiveResults(NamedTuple): - samples: Union[dict, tuple] - log_weights: torch.Tensor - guide_log_prob: torch.Tensor - model_log_prob: torch.Tensor - - -class WeighedPredictiveResults(WeighedPredictiveResults, WeightAnalytics): +@dataclass(frozen=True, eq=False) +class WeighedPredictiveResults(LogWeightsMixin): """ Return value of call to instance of :class:`WeighedPredictive`. """ - pass + samples: Union[dict, tuple] + log_weights: torch.Tensor + guide_log_prob: torch.Tensor + model_log_prob: torch.Tensor class WeighedPredictive(Predictive):