Skip to content

Commit

Permalink
Clarify Mixin usage and convert namedtuple to dataclass.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Zickel committed Mar 28, 2024
1 parent e3d5e25 commit 06ecfda
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
11 changes: 9 additions & 2 deletions pyro/infer/importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import math
import warnings
from typing import List, Union

import torch

Expand All @@ -15,7 +16,13 @@
from .util import plate_log_prob_sum


class WeightAnalytics:
class LogWeightsMixin:
"""
Mixin class to compute analytics from a ``.log_weights`` attribute.
"""

log_weights: Union[List[Union[float, torch.Tensor]], torch.Tensor]

def get_log_normalizer(self):
"""
Estimator of the normalizing constant of the target distribution.
Expand Down Expand Up @@ -67,7 +74,7 @@ def get_ESS(self):
return ess


class Importance(TracePosterior, WeightAnalytics):
class Importance(TracePosterior, LogWeightsMixin):
"""
:param model: probabilistic model defined as a function
:param guide: guide used for sampling defined as a function
Expand Down
23 changes: 11 additions & 12 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
from dataclasses import dataclass
from functools import reduce
from typing import List, NamedTuple, Union
from typing import List, Union

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer.importance import WeightAnalytics
from pyro.infer.importance import LogWeightsMixin
from pyro.infer.util import plate_log_prob_sum
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites
Expand All @@ -35,7 +36,8 @@ def _guess_max_plate_nesting(model, args, kwargs):
return max_plate_nesting


class _predictiveResults(NamedTuple):
@dataclass(frozen=True, eq=False)
class _predictiveResults:
"""
Return value of call to ``_predictive`` and ``_predictive_sequential``.
"""
Expand Down Expand Up @@ -317,19 +319,16 @@ def get_vectorized_trace(self, *args, **kwargs):
).trace


class WeighedPredictiveResults(NamedTuple):
samples: Union[dict, tuple]
log_weights: torch.Tensor
guide_log_prob: torch.Tensor
model_log_prob: torch.Tensor


class WeighedPredictiveResults(WeighedPredictiveResults, WeightAnalytics):
@dataclass(frozen=True, eq=False)
class WeighedPredictiveResults(LogWeightsMixin):
"""
Return value of call to instance of :class:`WeighedPredictive`.
"""

pass
samples: Union[dict, tuple]
log_weights: torch.Tensor
guide_log_prob: torch.Tensor
model_log_prob: torch.Tensor


class WeighedPredictive(Predictive):
Expand Down

0 comments on commit 06ecfda

Please sign in to comment.