Nonprojective dependency entropy #103

Open · wants to merge 9 commits into base: master
64 changes: 45 additions & 19 deletions torch_struct/distributions.py
@@ -10,9 +10,6 @@
from .semirings import (
LogSemiring,
MaxSemiring,
EntropySemiring,
CrossEntropySemiring,
KLDivergenceSemiring,
MultiSampledSemiring,
KMaxSemiring,
StdSemiring,
@@ -72,14 +69,25 @@ def log_prob(self, value):

@lazy_property
def entropy(self):
"""
Compute entropy for distribution :math:`H[z]`.
r"""
Compute entropy for distribution :math:`H[p]`.

Algorithm derivation:
.. math::
H[p] &= E_{p(z)}[-\log p(z)]\\
&= -E_{p(z)}\big[ \log \big( \frac{1}{Z} \prod\limits_{c \in \mathcal{C}} \exp\{\phi_c(z_c)\} \big) \big]\\
&= -E_{p(z)}\big[ \sum\limits_{c \in \mathcal{C}} \phi_{c}(z_c) - \log Z \big]\\
&= \log Z - E_{p(z)}\big[\sum\limits_{c \in \mathcal{C}} \phi_{c}(z_c)\big]\\
&= \log Z - \sum\limits_{c \in \mathcal{C}} \sum\limits_{z_c} p(z_c) \phi_{c}(z_c)

Returns:
entropy (*batch_shape*)
"""

return self._struct(EntropySemiring).sum(self.log_potentials, self.lengths)
logZ = self.partition
p = self.marginals
phi = self.log_potentials
Hp = logZ - (p * phi).reshape(p.shape[0], -1).sum(-1)
return Hp
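
As a quick, self-contained sanity check of the identity above (not part of this diff), the degenerate case of a single categorical factor can be verified with plain torch; `phi`, `logZ`, and `p` below are illustrative names:

```python
import torch
import torch.distributions as td

# H[p] = log Z - sum_c E_p[phi_c]; with one categorical factor the "parts"
# are just the K outcomes of a single variable.
phi = torch.randn(7)                     # arbitrary log-potentials
logZ = torch.logsumexp(phi, dim=-1)      # log partition function
p = torch.softmax(phi, dim=-1)           # exact marginals
H = logZ - (p * phi).sum(-1)             # entropy via the identity above
assert torch.allclose(H, td.Categorical(logits=phi).entropy(), atol=1e-5)
```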

def cross_entropy(self, other):
"""
@@ -91,10 +99,11 @@ def cross_entropy(self, other):
Returns:
cross entropy (*batch_shape*)
"""

return self._struct(CrossEntropySemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)
logZ = other.partition
p = self.marginals
phi_q = other.log_potentials
Hq = logZ - (p * phi_q).reshape(p.shape[0], -1).sum(-1)
return Hq
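
The same kind of single-factor check (again only a sketch, with illustrative names) covers the cross-entropy formula :math:`H(p, q) = \log Z_q - E_p[\phi_q]`:

```python
import torch

phi_p, phi_q = torch.randn(7), torch.randn(7)
p = torch.softmax(phi_p, dim=-1)
# cross entropy via the identity: log Z_q - E_p[phi_q]
Hpq = torch.logsumexp(phi_q, dim=-1) - (p * phi_q).sum(-1)
# reference: -E_p[log q]
reference = -(p * torch.log_softmax(phi_q, dim=-1)).sum(-1)
assert torch.allclose(Hpq, reference, atol=1e-5)
```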

def kl(self, other):
"""
@@ -104,11 +113,15 @@
other : Comparison distribution

Returns:
cross entropy (*batch_shape*)
kl divergence (*batch_shape*)
"""
return self._struct(KLDivergenceSemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)
logZp = self.partition
logZq = other.partition
p = self.marginals
phi_p = self.log_potentials
phi_q = other.log_potentials
KLpq = (p * (phi_p - phi_q)).reshape(p.shape[0], -1).sum(-1) - logZp + logZq
return KLpq
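
And for the KL term, a sketch under the same single-factor assumption, checked against `torch.distributions.kl_divergence`; it also matches `cross_entropy(p, q) - entropy(p)`:

```python
import torch
import torch.distributions as td

phi_p, phi_q = torch.randn(7), torch.randn(7)
p = torch.softmax(phi_p, dim=-1)
# KL(p || q) = E_p[phi_p - phi_q] - log Z_p + log Z_q
kl = (p * (phi_p - phi_q)).sum(-1) \
    - torch.logsumexp(phi_p, dim=-1) + torch.logsumexp(phi_q, dim=-1)
reference = td.kl_divergence(td.Categorical(logits=phi_p),
                             td.Categorical(logits=phi_q))
assert torch.allclose(kl, reference, atol=1e-5)
```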

@lazy_property
def max(self):
@@ -472,6 +485,23 @@ def __init__(self, log_potentials, lengths=None, args={}, multiroot=False):
super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args)
self.multiroot = multiroot

def log_prob(self, value):
"""
Compute log probability over values :math:`p(z)`.

Parameters:
value (tensor): One-hot events (*sample_shape x batch_shape x event_shape*)

Returns:
log_probs (*sample_shape x batch_shape*)
"""
s = value.shape
# assumes `value` has no 1s outside of the valid lengths
value_total_log_potentials = (
(value * self.log_potentials.expand(s)).reshape(*s[:-2], -1).sum(-1)
)
return value_total_log_potentials - self.partition
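
A usage sketch for the new `log_prob` (not part of the diff): `tree` stands in for a one-hot arc-indicator event with the same shape as the potentials, and its arcs must form a valid dependency tree under whatever head/dependent encoding the potentials use; the names and sizes here are illustrative only:

```python
import torch
from torch_struct import NonProjectiveDependencyCRF

batch, N = 2, 4
log_potentials = torch.randn(batch, N, N)   # arc scores
dist = NonProjectiveDependencyCRF(log_potentials)

tree = torch.zeros(batch, N, N)
# ... set tree[b, i, j] = 1 for each arc of a valid tree (e.g. from a parser
# or a sample); left unfilled here since the encoding is model-specific ...

log_p = dist.log_prob(tree)                 # (batch,)
```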

@lazy_property
def marginals(self):
"""
@@ -502,7 +532,3 @@ def argmax(self):
(Currently not implemented)
"""
pass

@lazy_property
def entropy(self):
pass