-
Notifications
You must be signed in to change notification settings - Fork 411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix PosteriorMeanModel
and add qPosteriorStandardDeviation
acquisition function
#2634
base: main
Are you sure you want to change the base?
Conversation
So
Adding these properties to |
Sorry, missed that, you're right. Will remove
|
PosteriorMeanModel
and add qPosteriorMean
/qPosteriorStandardDeviation
acquisition functionsPosteriorMeanModel
and add qPosteriorStandardDeviation
acquisition function
At this point the PR can support constrained AL (I want to minimise uncertainty in the feasible region). For a simple example, see the image below: the task is to minimise uncertainty in output 1, but only where output 2 is negative. The blue dashed line (PSTD) is analytical, green solid line is the MC equivalent (with no Bessel correction - will add this, thanks @Balandat), and orange solid line is the constrained PSTD. Not closed the loop and run an actual optimisation yet, interested to see how it will perform. Code used for testing is below: Code(Sorry the plotting code is messy!) import torch
from botorch import acquisition
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.sampling import SobolQMCNormalSampler
from botorch.acquisition.objective import ScalarizedPosteriorTransform, LinearMCObjective
from plotly.subplots import make_subplots
n_train = 10
device = torch.device("cuda:1")
torch.manual_seed(3)
train_x = torch.rand(n_train, 1, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, 2, dtype=torch.float64, device=device)
model = SingleTaskGP(
train_x,
train_y,
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
_ = fit_gpytorch_mll(mll)
sampler = SobolQMCNormalSampler(torch.Size([64]))
# Objective weights
w = torch.tensor([1, 0], dtype=torch.float64, device=device)
pstd = acquisition.PosteriorStandardDeviation(
model,
posterior_transform=ScalarizedPosteriorTransform(w),
)
qpstd = acquisition.qPosteriorStandardDeviation(
model,
sampler=sampler,
objective=LinearMCObjective(w),
)
qpstd_constr = acquisition.qPosteriorStandardDeviation(
model,
sampler=sampler,
objective=LinearMCObjective(w),
constraints=[lambda samples: samples[..., 1]],
)
x = torch.linspace(0, 1, 100, device=device)
std = pstd(x[:, None, None])
qstd = qpstd(x[:, None, None])
qstd_c = qpstd_constr(x[:, None, None])
post = model.posterior(x[:, None, None])
fig = make_subplots(rows=3, cols=1, shared_xaxes=True)
for i in [0, 1]:
fig.add_scatter(x=x.cpu().detach().numpy(), y=post.mean[:, 0, i].cpu().detach().numpy(), mode="lines", line_color="black", row=i+1, col=1, showlegend=False)
fig.add_scatter(x=x.cpu().detach().numpy(), y=(post.mean[:, 0, i] + post.variance[:, 0, i]**0.5).cpu().detach().numpy(), line_color="grey", mode="lines", showlegend=False, row=i+1, col=1)
fig.add_scatter(x=x.cpu().detach().numpy(), y=(post.mean[:, 0, i] - post.variance[:, 0, i]**0.5).cpu().detach().numpy(), line_color="grey", mode="lines", showlegend=False, fill="tonexty", row=i+1, col=1)
fig.add_scatter(x=train_x[:, 0].cpu().detach().numpy(), y=train_y[:, i].cpu().detach().numpy(), mode="markers", marker_color="red", row=i+1, col=1, showlegend=False)
fig.add_hline(y=0, row=2, line_dash="dash")
fig.add_scatter(x=x.cpu().detach().numpy(), y=qstd.cpu().detach().numpy(), row=3, col=1, name="qPSTD", line_color="orange")
fig.add_scatter(x=x.cpu().detach().numpy(), y=qstd_c.cpu().detach().numpy(), row=3, col=1, name="qPSTD (constrained y2<0)", line_color="green")
fig.add_scatter(x=x.cpu().detach().numpy(), y=std.cpu().detach().numpy(), row=3, col=1, name="PSTD", line_dash="dash", line_color="blue")
fig.update_yaxes(row=1, title_text="Output 1")
fig.update_yaxes(row=2, title_text="Output 2")
fig.update_yaxes(row=3, title_text="Acquisition value")
fig.update_xaxes(row=3, title_text="x") |
Added unit tests, copied from the rest of the MC acquisition function tests. This one seems incomplete as it relies on the posterior samples having some variance in order to return nonzero acquisition values, but |
Yeah it's quite possible that |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2634 +/- ##
==========================================
- Coverage 99.98% 99.97% -0.02%
==========================================
Files 196 196
Lines 17410 17423 +13
==========================================
+ Hits 17408 17419 +11
- Misses 2 4 +2 ☔ View full report in Codecov by Sentry. |
Thanks for the pointer! Unfortunately it wasn't quite as simple as modifying that one line, the The other issue is that the reparameterisation trick for qUCB (which I'm using for qPSTD) only really makes sense for Gaussian posteriors; if a set of samples is manually constructed then the standard deviation is inaccurate. This might merit further thought, because although the UCB family of acquisition functions are most naturally suited to Gaussian posteriors (to my knowledge), it might be valuable for PSTD to also work for non-Gaussian posteriors. I've sampled from I haven't edited |
Motivation
This is a small collection of changes for improving support for optimisation with deterministic (posterior mean) and pure exploration (posterior std) acquisition functions:
PosteriorMeanModel
withoptimize_acqf
is currently not supported asPosteriorMeanModel
does not implementnum_outputs
orbatch_shape
.PosteriorStandardDeviation
acquisition function has no MC equivalent.This PR addresses the points above, and consequentially adds support for the constrained PSTD acquisition function.
Have you read the Contributing Guidelines on pull requests?
Yes
Test Plan
TODO - just submitting draft for now for discussion.
Related PRs
(If this PR adds or changes functionality, please take some time to update the docs at https://github.com/pytorch/botorch, and link to your PR here.)