From 70b6edbe3864a85a11ca8c87f7e47f6ca728bd06 Mon Sep 17 00:00:00 2001 From: Dustin Tran Date: Sat, 7 Oct 2017 13:40:53 -0700 Subject: [PATCH 1/2] initial commit --- edward/inferences/conjugacy/conjugacy.py | 1 + edward/models/random_variables.py | 16 ------ edward/util/random_variables.py | 62 ++++++++++++++++-------- 3 files changed, 44 insertions(+), 35 deletions(-) diff --git a/edward/inferences/conjugacy/conjugacy.py b/edward/inferences/conjugacy/conjugacy.py index 329f61f02..af7cbd584 100644 --- a/edward/inferences/conjugacy/conjugacy.py +++ b/edward/inferences/conjugacy/conjugacy.py @@ -27,6 +27,7 @@ def normal_from_natural_params(p1, p2): return {'loc': loc, 'scale': tf.sqrt(sigmasq)} +# TODO _suff_stat_to_dist = defaultdict(dict) _suff_stat_to_dist['binary'][(('#x',),)] = ( Bernoulli, lambda p1: {'logits': p1}) diff --git a/edward/models/random_variables.py b/edward/models/random_variables.py index be5cf8058..80c3e6382 100644 --- a/edward/models/random_variables.py +++ b/edward/models/random_variables.py @@ -26,22 +26,6 @@ def __init__(self, *args, **kwargs): del _candidate -# Add supports; these are used, e.g., in conjugacy. -Bernoulli.support = 'binary' -Beta.support = '01' -Binomial.support = 'onehot' -Categorical.support = 'categorical' -Chi2.support = 'nonnegative' -Dirichlet.support = 'simplex' -Exponential.support = 'nonnegative' -Gamma.support = 'nonnegative' -InverseGamma.support = 'nonnegative' -Laplace.support = 'real' -Multinomial.support = 'onehot' -MultivariateNormalDiag.support = 'multivariate_real' -Normal.support = 'real' -Poisson.support = 'countable' - del absolute_import del division del print_function diff --git a/edward/util/random_variables.py b/edward/util/random_variables.py index 5b5f3d137..3ccb66b35 100644 --- a/edward/util/random_variables.py +++ b/edward/util/random_variables.py @@ -753,28 +753,52 @@ def transform(x, *args, **kwargs): if len(args) != 0 or kwargs.get('bijector', None) is not None: return TransformedDistribution(x, *args, **kwargs) - try: - support = x.support - except AttributeError as e: - msg = """'{}' object has no 'support' - so cannot be transformed.""".format(type(x).__name__) - raise AttributeError(msg) - - if support == '01': - bij = bijectors.Invert(bijectors.Sigmoid()) - new_support = 'real' - elif support == 'nonnegative': - bij = bijectors.Invert(bijectors.Softplus()) - new_support = 'real' - elif support == 'simplex': - bij = bijectors.Invert(bijectors.SoftmaxCentered(event_ndims=1)) - new_support = 'multivariate_real' - elif support in ('real', 'multivariate_real'): + real = (Gumbel, + Laplace, + Logistic, + Normal, + StudentT, + MultivariateNormalDiag, + MultivariateNormalFullCovariance, + MultivariateNormalTriL, + MultivariateNormalDiagPlusLowRank) + if isinstance(x, real): + # Determine if distribution has real support at construction time + # via hard-coded distributions. This prevents adding unnecessary + # ops via a transformation with identity bijector. return x - else: + + if x.support is None or len(x.support) > 1: msg = "'transform' does not handle supports of type '{}'".format(support) raise ValueError(msg) + interval, measure = x.support[0] + if measure == 'simplex': + # TODO + pass + elif measure != 'real': + raise + + # TODO get event_shape + # TODO compatible dtypes + # TODO tf.fill_like + is_real = tf.logical_and(tf.is_equal(interval[0], tf.constant(-np.inf)), + tf.is_equal(interval[1], tf.constant(np.inf))) + is_01 = tf.logical_and(tf.is_equal(interval[0], tf.constant(0)), + tf.is_equal(interval[1], tf.constant(1))) + is_nonnegative = tf.logical_and(tf.is_equal(interval[0], tf.constant(0)), + tf.is_equal(interval[1], tf.constant(np.inf))) + # TODO + tf.where(is_real, x, tf.where()...) + elif interval == '01': + bij = bijectors.Invert(bijectors.Sigmoid()) + elif interval == 'nonnegative': + bij = bijectors.Invert(bijectors.Softplus()) + elif interval == 'simplex': + bij = bijectors.Invert(bijectors.SoftmaxCentered(event_ndims=1)) + # TODO identity + new_x = TransformedDistribution(x, bij, *args, **kwargs) - new_x.support = new_support + # TODO + new_x.support = [([], 'real')] return new_x From 99b45b9eb0e499ccb51a491001a50844c19f4549 Mon Sep 17 00:00:00 2001 From: Dustin Tran Date: Sat, 7 Oct 2017 14:02:36 -0700 Subject: [PATCH 2/2] fix namespaces --- edward/util/random_variables.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/edward/util/random_variables.py b/edward/util/random_variables.py index 3ccb66b35..65fc027b0 100644 --- a/edward/util/random_variables.py +++ b/edward/util/random_variables.py @@ -7,8 +7,8 @@ import tensorflow as tf from copy import deepcopy +from edward import models as rvs from edward.models.random_variable import RandomVariable -from edward.models.random_variables import TransformedDistribution from edward.util.graphs import random_variables from tensorflow.contrib.distributions import bijectors from tensorflow.core.framework import attr_value_pb2 @@ -751,17 +751,17 @@ def transform(x, *args, **kwargs): ``` """ if len(args) != 0 or kwargs.get('bijector', None) is not None: - return TransformedDistribution(x, *args, **kwargs) - - real = (Gumbel, - Laplace, - Logistic, - Normal, - StudentT, - MultivariateNormalDiag, - MultivariateNormalFullCovariance, - MultivariateNormalTriL, - MultivariateNormalDiagPlusLowRank) + return rvs.TransformedDistribution(x, *args, **kwargs) + + real = (rvs.Gumbel, + rvs.Laplace, + rvs.Logistic, + rvs.Normal, + rvs.StudentT, + rvs.MultivariateNormalDiag, + rvs.MultivariateNormalFullCovariance, + rvs.MultivariateNormalTriL, + rvs.MultivariateNormalDiagPlusLowRank) if isinstance(x, real): # Determine if distribution has real support at construction time # via hard-coded distributions. This prevents adding unnecessary @@ -798,7 +798,7 @@ def transform(x, *args, **kwargs): bij = bijectors.Invert(bijectors.SoftmaxCentered(event_ndims=1)) # TODO identity - new_x = TransformedDistribution(x, bij, *args, **kwargs) + new_x = rvs.TransformedDistribution(x, bij, *args, **kwargs) # TODO new_x.support = [([], 'real')] return new_x