Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Refactor random variables' support to use tf.contrib.distributions' #781

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions edward/inferences/conjugacy/conjugacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def normal_from_natural_params(p1, p2):
return {'loc': loc, 'scale': tf.sqrt(sigmasq)}


# TODO
_suff_stat_to_dist = defaultdict(dict)
_suff_stat_to_dist['binary'][(('#x',),)] = (
Bernoulli, lambda p1: {'logits': p1})
Expand Down
16 changes: 0 additions & 16 deletions edward/models/random_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,6 @@ def __init__(self, *args, **kwargs):

del _candidate

# Add supports; these are used, e.g., in conjugacy.
Bernoulli.support = 'binary'
Beta.support = '01'
Binomial.support = 'onehot'
Categorical.support = 'categorical'
Chi2.support = 'nonnegative'
Dirichlet.support = 'simplex'
Exponential.support = 'nonnegative'
Gamma.support = 'nonnegative'
InverseGamma.support = 'nonnegative'
Laplace.support = 'real'
Multinomial.support = 'onehot'
MultivariateNormalDiag.support = 'multivariate_real'
Normal.support = 'real'
Poisson.support = 'countable'

del absolute_import
del division
del print_function
66 changes: 45 additions & 21 deletions edward/util/random_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import tensorflow as tf

from copy import deepcopy
from edward import models as rvs
from edward.models.random_variable import RandomVariable
from edward.models.random_variables import TransformedDistribution
from edward.util.graphs import random_variables
from tensorflow.contrib.distributions import bijectors
from tensorflow.core.framework import attr_value_pb2
Expand Down Expand Up @@ -751,30 +751,54 @@ def transform(x, *args, **kwargs):
```
"""
if len(args) != 0 or kwargs.get('bijector', None) is not None:
return TransformedDistribution(x, *args, **kwargs)
return rvs.TransformedDistribution(x, *args, **kwargs)

real = (rvs.Gumbel,
rvs.Laplace,
rvs.Logistic,
rvs.Normal,
rvs.StudentT,
rvs.MultivariateNormalDiag,
rvs.MultivariateNormalFullCovariance,
rvs.MultivariateNormalTriL,
rvs.MultivariateNormalDiagPlusLowRank)
if isinstance(x, real):
# Determine if distribution has real support at construction time
# via hard-coded distributions. This prevents adding unnecessary
# ops via a transformation with identity bijector.
return x

try:
support = x.support
except AttributeError as e:
msg = """'{}' object has no 'support'
so cannot be transformed.""".format(type(x).__name__)
raise AttributeError(msg)
if x.support is None or len(x.support) > 1:
msg = "'transform' does not handle supports of type '{}'".format(support)
raise ValueError(msg)

if support == '01':
interval, measure = x.support[0]
if measure == 'simplex':
# TODO
pass
elif measure != 'real':
raise

# TODO get event_shape
# TODO compatible dtypes
# TODO tf.fill_like
is_real = tf.logical_and(tf.is_equal(interval[0], tf.constant(-np.inf)),
tf.is_equal(interval[1], tf.constant(np.inf)))
is_01 = tf.logical_and(tf.is_equal(interval[0], tf.constant(0)),
tf.is_equal(interval[1], tf.constant(1)))
is_nonnegative = tf.logical_and(tf.is_equal(interval[0], tf.constant(0)),
tf.is_equal(interval[1], tf.constant(np.inf)))
# TODO
tf.where(is_real, x, tf.where()...)
elif interval == '01':
bij = bijectors.Invert(bijectors.Sigmoid())
new_support = 'real'
elif support == 'nonnegative':
elif interval == 'nonnegative':
bij = bijectors.Invert(bijectors.Softplus())
new_support = 'real'
elif support == 'simplex':
elif interval == 'simplex':
bij = bijectors.Invert(bijectors.SoftmaxCentered(event_ndims=1))
new_support = 'multivariate_real'
elif support in ('real', 'multivariate_real'):
return x
else:
msg = "'transform' does not handle supports of type '{}'".format(support)
raise ValueError(msg)
# TODO identity

new_x = TransformedDistribution(x, bij, *args, **kwargs)
new_x.support = new_support
new_x = rvs.TransformedDistribution(x, bij, *args, **kwargs)
# TODO
new_x.support = [([], 'real')]
return new_x