From e96139e10f5a236dd21fc5b5901f6629e1e71a1d Mon Sep 17 00:00:00 2001 From: Dustin Tran Date: Sat, 7 Oct 2017 13:40:53 -0700 Subject: [PATCH] initial commit --- edward/inferences/conjugacy/conjugacy.py | 1 + edward/models/random_variables.py | 16 -------- edward/util/random_variables.py | 48 ++++++++++++++---------- 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/edward/inferences/conjugacy/conjugacy.py b/edward/inferences/conjugacy/conjugacy.py index 329f61f02..af7cbd584 100644 --- a/edward/inferences/conjugacy/conjugacy.py +++ b/edward/inferences/conjugacy/conjugacy.py @@ -27,6 +27,7 @@ def normal_from_natural_params(p1, p2): return {'loc': loc, 'scale': tf.sqrt(sigmasq)} +# TODO _suff_stat_to_dist = defaultdict(dict) _suff_stat_to_dist['binary'][(('#x',),)] = ( Bernoulli, lambda p1: {'logits': p1}) diff --git a/edward/models/random_variables.py b/edward/models/random_variables.py index be5cf8058..80c3e6382 100644 --- a/edward/models/random_variables.py +++ b/edward/models/random_variables.py @@ -26,22 +26,6 @@ def __init__(self, *args, **kwargs): del _candidate -# Add supports; these are used, e.g., in conjugacy. -Bernoulli.support = 'binary' -Beta.support = '01' -Binomial.support = 'onehot' -Categorical.support = 'categorical' -Chi2.support = 'nonnegative' -Dirichlet.support = 'simplex' -Exponential.support = 'nonnegative' -Gamma.support = 'nonnegative' -InverseGamma.support = 'nonnegative' -Laplace.support = 'real' -Multinomial.support = 'onehot' -MultivariateNormalDiag.support = 'multivariate_real' -Normal.support = 'real' -Poisson.support = 'countable' - del absolute_import del division del print_function diff --git a/edward/util/random_variables.py b/edward/util/random_variables.py index 5b5f3d137..19bcce819 100644 --- a/edward/util/random_variables.py +++ b/edward/util/random_variables.py @@ -753,28 +753,38 @@ def transform(x, *args, **kwargs): if len(args) != 0 or kwargs.get('bijector', None) is not None: return TransformedDistribution(x, *args, **kwargs) - try: - support = x.support - except AttributeError as e: - msg = """'{}' object has no 'support' - so cannot be transformed.""".format(type(x).__name__) - raise AttributeError(msg) - - if support == '01': - bij = bijectors.Invert(bijectors.Sigmoid()) - new_support = 'real' - elif support == 'nonnegative': - bij = bijectors.Invert(bijectors.Softplus()) - new_support = 'real' - elif support == 'simplex': - bij = bijectors.Invert(bijectors.SoftmaxCentered(event_ndims=1)) - new_support = 'multivariate_real' - elif support in ('real', 'multivariate_real'): + if isinstance(x, (Normal, MultivariateNormal..., ?)): + # TODO would like to do this one statically; hard-coded by + # distributions return x - else: + + if x.support is None or len(x.support) > 1: msg = "'transform' does not handle supports of type '{}'".format(support) raise ValueError(msg) + interval, measure = x.support[0] + if measure is not 'real': + raise + # TODO simplex + + # TODO get event_shape + # TODO compatible dtypes + # TODO tf.fill_like + is_real = tf.logical_and(tf.is_equal(interval[0], tf.constant(-np.inf)), + tf.is_equal(interval[1], tf.constant(np.inf))) + is_01 = tf.logical_and(tf.is_equal(interval[0], tf.constant(0)), + tf.is_equal(interval[1], tf.constant(1))) + is_nonnegative = tf.logical_and(tf.is_equal(interval[0], tf.constant(0)), + tf.is_equal(interval[1], tf.constant(np.inf))) + # TODO + tf.where(is_real, x, tf.where()...) + elif interval == '01': + bij = bijectors.Invert(bijectors.Sigmoid()) + elif interval == 'nonnegative': + bij = bijectors.Invert(bijectors.Softplus()) + elif interval == 'simplex': + bij = bijectors.Invert(bijectors.SoftmaxCentered(event_ndims=1)) + new_x = TransformedDistribution(x, bij, *args, **kwargs) - new_x.support = new_support + new_x.support = [([], 'real')] return new_x