Feat: Implement Full-Rank VI #720
Conversation
@junpenglao I am trying to better grok JAX's PyTrees. I would love specific feedback on how to support PyTrees instead of just JAX arrays in my code. I believe this mostly affects my sampling code. For example, compare my sampling implementation:

```python
def _sample(rng_key, mu, rho, L, num_samples):
    cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(L))
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    return mu + eps @ cholesky.T
```

with the one in meanfield VI:

```python
def _sample(rng_key, mu, rho, num_samples):
    sigma = jax.tree.map(jnp.exp, rho)
    mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu)
    sigma_flat, _ = jax.flatten_util.ravel_pytree(sigma)
    flatten_sample = (
        jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape) * sigma_flat
        + mu_flatten
    )
    return jax.vmap(unravel_fn)(flatten_sample)
```

How and why would I change my code? Any resource recommendations would be greatly appreciated.
A PyTree with a covariance matrix is a headache to deal with. I suggest you take a look at how pathfinder deals with it internally: basically all the state parameters are represented as flattened arrays, and you only unflatten at the end. Considering that what you have right now already assumes everything is a flat array, you just need to add the flatten and unflatten steps.
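For readers unfamiliar with the idiom, here is a minimal sketch of the flatten/unflatten pattern being suggested (the example PyTree and values are purely illustrative, not from the PR):

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# A toy PyTree of parameters (illustrative only).
position = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# ravel_pytree flattens the whole tree into one 1-D array and returns a
# function that maps a flat array back to the original structure.
position_flat, unravel_fn = ravel_pytree(position)  # shape (9,)

# Do all the linear algebra in the flat space...
perturbed_flat = position_flat + 0.1 * jax.random.normal(
    jax.random.PRNGKey(0), position_flat.shape
)

# ...and only unflatten at the very end.
perturbed = unravel_fn(perturbed_flat)
```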
blackjax/vi/fullrank_vi.py (Outdated)

```python
def generate_fullrank_logdensity(mu, rho, L):
    cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(L))
    log_det = 2 * jnp.sum(rho)
    const = -0.5 * mu.shape[-1] * jnp.log(2 * jnp.pi)

    def fullrank_logdensity(position):
        y = jsp.linalg.solve_triangular(cholesky, position - mu, lower=True)
        mahalanobis_dist = jnp.sum(y ** 2, axis=-1)
        return const - 0.5 * log_det - 0.5 * mahalanobis_dist

    return fullrank_logdensity
```
Why not use `multivariate_normal.logpdf` from JAX?
I wanted to avoid computing the inverse and log determinant of the covariance matrix.

Does JAX's `multivariate_normal.logpdf` take Cholesky factors as input? I want to avoid needing to compute the covariance matrix.

From https://jax.readthedocs.io/en/latest/_autosummary/jax.random.multivariate_normal.html it appears the multivariate normal log density only accepts the covariance as a dense matrix!

See jax-ml/jax#11386. Thoughts on the tradeoff between readability (with JAX's multivariate normal) and speed (custom implementation)?
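To make the tradeoff concrete, here is a rough sketch of the two options (the 3-dimensional Gaussian and all values below are illustrative, not the PR's code):

```python
import jax.numpy as jnp
import jax.scipy as jsp

# Illustrative values: a 3-dimensional Gaussian with a lower-triangular factor.
mu = jnp.zeros(3)
cholesky = jnp.array([[1.0, 0.0, 0.0],
                      [0.5, 2.0, 0.0],
                      [0.1, 0.3, 1.5]])
x = jnp.ones(3)

# Readable: materialize the dense covariance and call the stock logpdf.
cov = cholesky @ cholesky.T
logp_dense = jsp.stats.multivariate_normal.logpdf(x, mu, cov)

# Faster: use the factor directly and never form Sigma.
y = jsp.linalg.solve_triangular(cholesky, x - mu, lower=True)
logp_chol = (
    -0.5 * mu.shape[-1] * jnp.log(2 * jnp.pi)
    - jnp.sum(jnp.log(jnp.diag(cholesky)))  # 0.5 * log|Sigma| = sum(log diag(C))
    - 0.5 * jnp.sum(y**2)
)
# logp_dense and logp_chol agree up to floating-point error.
```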
A couple notes about the Cholesky factor. Currently I initialize it as a full square matrix for each leaf:

```python
L = jax.tree.map(lambda x: jnp.zeros((*x.shape, *x.shape)), position)
```

During optimization, this means we may learn non-zero values for the diagonal and upper triangle of `L`. Instead, we could store only the strictly lower-triangular elements as a flat vector and unflatten them when building the factor:

```python
def unflatten_lower_triangular(tril_flat):
    n = tril_flat.size  # number of strictly lower-triangular elements
    d = int(1 + jnp.sqrt(1 + 8 * n)) // 2  # dimension of the original matrix
    lower_tri_matrix = jnp.zeros((d, d))
    indices = jnp.tril_indices(d, k=-1)  # indices for the lower triangle (excluding the diagonal)
    return lower_tri_matrix.at[indices].set(tril_flat)

C = jnp.diag(jnp.exp(rho)) + unflatten_lower_triangular(L)
```
Here is the pattern I see in pathfinder's `approximate`:

```python
def approximate(...):
    initial_position_flatten, unravel_fn = ravel_pytree(initial_position)
    ...
    unravel_fn_mapped = jax.vmap(unravel_fn)
    pathfinder_result = PathfinderState(
        elbo,
        unravel_fn_mapped(position),
        unravel_fn_mapped(grad_position),
        alpha,
        beta,
        gamma,
    )
    ...
```

How would I apply this to my `_sample`?

```python
def _sample(rng_key, mu, rho, L, num_samples):
    cholesky = jnp.diag(jnp.exp(rho)) + unflatten_lower_triangular(L)
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    return mu + eps @ cholesky.T
```
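One possible way to apply pathfinder's flatten/unflatten pattern to `_sample`, as a hedged sketch rather than the committed implementation. It assumes `rho` and `L` are already flat arrays defined over the flattened position, and it reuses the `unflatten_lower_triangular` helper from above:

```python
def _sample(rng_key, mu, rho, L, num_samples):
    # Flatten the PyTree mean; unravel_fn remembers the original structure.
    mu_flat, unravel_fn = jax.flatten_util.ravel_pytree(mu)
    cholesky = jnp.diag(jnp.exp(rho)) + unflatten_lower_triangular(L)
    eps = jax.random.normal(rng_key, (num_samples,) + mu_flat.shape)
    flat_samples = mu_flat + eps @ cholesky.T
    # Map each flat sample back into the user's PyTree structure.
    return jax.vmap(unravel_fn)(flat_samples)
```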
Define `chol_params` as a flattened Cholesky-factor vector that consists of the diagonal elements followed by the off-diagonal elements in row-major order, for n = dim * (dim + 1) / 2 elements in total. The diagonal entries (the first dim elements) are passed through a softplus function to ensure positivity, which is crucial for maintaining a valid covariance matrix. This parameterization allows for unconstrained optimization while ensuring the resulting covariance matrix Sigma = CC^T is symmetric and positive definite. The `chol_params` are then reshaped into a lower-triangular matrix `chol_factor` using the `jnp.tril` and `jnp.diag` functions.
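A minimal sketch of this parameterization (the function name and the use of `jnp.tril_indices` are illustrative choices, not the committed `_unflatten_cholesky`):

```python
import jax
import jax.numpy as jnp


def chol_params_to_factor(chol_params, dim):
    """Map the flattened parameters described above to a lower-triangular factor."""
    # First `dim` entries: diagonal, passed through softplus to stay positive.
    diag = jax.nn.softplus(chol_params[:dim])
    # Remaining dim * (dim - 1) / 2 entries: strictly lower triangle, row-major.
    chol_factor = jnp.zeros((dim, dim))
    chol_factor = chol_factor.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:])
    return chol_factor + jnp.diag(diag)


# Sigma = C @ C.T is symmetric positive definite for any real chol_params.
chol_factor = chol_params_to_factor(jnp.arange(6.0), dim=3)
```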
blackjax/vi/fullrank_vi.py (Outdated)

```python
def fullrank_logdensity(position):
    position_flatten = jax.flatten_util.ravel_pytree(position)[0]
    # TODO: inefficient because of redundant cholesky decomposition
    return jsp.stats.multivariate_normal.logpdf(position_flatten, mu_flatten, cov)
```
You have a good point about computational efficiency, let's keep the Cholesky version of the logpdf. Could you rewrite the API to be similar to `multivariate_normal.logpdf`, and use a function partial when you call it?
Can you specify the API you imagine my implementation would have and how a function partial would be applied to it? Feel free to reference the code I just committed.
I see, you are following the pattern in meanfield VI. I think there are a few places that need refactoring there. Let me send out a PR so you can see what I meant.
Once you finish refactoring meanfield VI, I'm happy to adapt the new style. Let me know when you're done!
Fix testing bug, add docstrings, and change softplus to exponential when converting `chol_params` to `chol_factor` in `_unflatten_cholesky`.
Refactor the `_unflatten_cholesky()` function to take a `dim` argument instead of inferring it (dynamically) from the `chol_params` input vector. This avoids JIT compilation issues. Also update docstrings.
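For illustration only, here is one way an explicit `dim` argument can be kept static under `jax.jit` (a sketch assuming an exponential diagonal per the commit above; the name and decorator are not the actual BlackJAX code):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="dim")
def unflatten_cholesky_sketch(chol_params, dim):
    # With `dim` marked static, (dim, dim) and the tril indices are concrete
    # Python values at trace time rather than quantities derived from a traced input.
    diag = jnp.exp(chol_params[:dim])
    tril = jnp.zeros((dim, dim)).at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:])
    return tril + jnp.diag(diag)
```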
Add assert statements that verify full-rank VI recovers the true, full-rank covariance matrix.
Update: need to figure out why the full covariance matrix is not being recovered. May need to come back to this in a few weeks because of some deadlines.
Thank you for your work, I am really looking forward to trying it ;)
Implement variational inference (VI) with full-rank Gaussian approximation
While mean-field VI learns a Gaussian approximation $N(\mu, \sigma I)$ with diagonal covariance $\sigma I$, full-rank VI learns a Gaussian approximation $N(\mu, \Sigma)$ with full-rank covariance $\Sigma$.
We use the Cholesky decomposition $\Sigma = C C^T$, parameterized by the log standard deviation $\rho$ and a strictly lower-triangular matrix $L$ such that $C = \operatorname{diag}(e^{\rho}) + L$. This (1) ensures $\Sigma$ remains symmetric and positive-definite during optimization[^1] and (2) admits better sampling and log-density computation (with improved time complexity, space complexity, and numerical stability)[^2]. Thus the full-rank Gaussian is parameterized by $(\mu, \rho, L)$.
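For concreteness, the quantities needed for sampling and for the log density follow directly from this factorization (using the conventions above):

$$
C = \operatorname{diag}(e^{\rho}) + L, \qquad
\Sigma = C C^{T}, \qquad
x = \mu + C \varepsilon \ \ \text{with}\ \varepsilon \sim N(0, I), \qquad
\log\lvert\Sigma\rvert = 2 \sum_i \rho_i .
$$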
To-Do:
Footnotes

[^1]: Automatic Differentiation Variational Inference, Section 2.4

[^2]: Lecture notes