add optimize_acqf_mixed
jduerholt committed Nov 27, 2024
1 parent 2e143c9 commit 1e5325e
Showing 2 changed files with 83 additions and 18 deletions.
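
The new `fixed_features_list` argument lets `optimize_acqf_homotopy` route each inner optimization through `optimize_acqf_mixed` instead of `optimize_acqf`. A minimal usage sketch, mirroring the setup of the added test (the toy model, homotopy schedule values, and bounds here are illustrative, not part of this commit):

import torch
from botorch.acquisition import PosteriorMean
from botorch.models.deterministic import GenericDeterministicModel
from botorch.optim.homotopy import FixedHomotopySchedule, Homotopy, HomotopyParameter
from botorch.optim.optimize_homotopy import optimize_acqf_homotopy

# Toy deterministic model whose optimum depends on the homotopy parameter `p`,
# which is annealed over the fixed schedule below.
p = torch.nn.Parameter(-2 * torch.ones(1, dtype=torch.double))
hp = HomotopyParameter(parameter=p, schedule=FixedHomotopySchedule(values=[2, 1, 0]))
model = GenericDeterministicModel(f=lambda x: 5 - (x - p).sum(dim=-1, keepdims=True) ** 2)
acqf = PosteriorMean(model=model)

candidate, acq_value = optimize_acqf_homotopy(
    q=1,
    acq_function=acqf,
    bounds=torch.tensor([[-10.0, -10.0], [5.0, 5.0]], dtype=torch.double),
    homotopy=Homotopy(homotopy_parameters=[hp]),
    num_restarts=2,
    raw_samples=16,
    # One inner optimization per dict; here feature 0 is fixed to 1.0,
    # so the returned candidate satisfies candidate[0, 0] == 1.0.
    fixed_features_list=[{0: 1.0}],
)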
69 changes: 51 additions & 18 deletions botorch/optim/optimize_homotopy.py
@@ -15,7 +15,7 @@
from botorch.generation.gen import TGenCandidates
from botorch.optim.homotopy import Homotopy
from botorch.optim.initializers import TGenInitialConditions
from botorch.optim.optimize import optimize_acqf
from botorch.optim.optimize import optimize_acqf, optimize_acqf_mixed
from torch import Tensor


@@ -67,6 +67,7 @@ def optimize_acqf_homotopy(
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
fixed_features: dict[int, float] | None = None,
fixed_features_list: list[dict[int, float]] | None = None,
post_processing_func: Callable[[Tensor], Tensor] | None = None,
batch_initial_conditions: Tensor | None = None,
gen_candidates: TGenCandidates | None = None,
@@ -129,6 +130,10 @@ def optimize_acqf_homotopy(
`options`.
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
fixed_features_list: A list of maps `{feature_index: value}`. The i-th
item represents the fixed features used in the i-th optimization. If
`fixed_features_list` is provided, `optimize_acqf_mixed` is invoked.
All indices (`feature_index`) should be non-negative.
post_processing_func: A function that post-processes an optimization
result appropriately (i.e., according to `round-trip`
transformations).
@@ -155,13 +160,17 @@
ic_gen_kwargs: Additional keyword arguments passed to function specified by
`ic_generator`
"""
if fixed_features and fixed_features_list:
raise ValueError(
"Èither `fixed_feature` or `fixed_features_list` can be provided, not both."
)

shared_optimize_acqf_kwargs = {
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"inequality_constraints": inequality_constraints,
"equality_constraints": equality_constraints,
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
"fixed_features": fixed_features,
"return_best_only": False, # False to make n_restarts persist through homotopy.
"gen_candidates": gen_candidates,
"sequential": sequential,
@@ -181,14 +190,26 @@
homotopy.restart()

while not homotopy.should_stop:
candidates, acq_values = optimize_acqf(
acq_function=acq_function,
bounds=bounds,
q=1,
options=options,
batch_initial_conditions=candidates,
**shared_optimize_acqf_kwargs,
)
if fixed_features_list:
candidates, acq_values = optimize_acqf_mixed(
acq_function=acq_function,
bounds=bounds,
q=1,
options=options,
batch_initial_conditions=candidates,
fixed_features_list=fixed_features_list,
**shared_optimize_acqf_kwargs,
)
else:
candidates, acq_values = optimize_acqf(
acq_function=acq_function,
bounds=bounds,
q=1,
options=options,
batch_initial_conditions=candidates,
fixed_features=fixed_features,
**shared_optimize_acqf_kwargs,
)
homotopy.step()

# Prune candidates
@@ -199,14 +220,26 @@
).unsqueeze(1)

# Optimize one more time with the final options
candidates, acq_values = optimize_acqf(
acq_function=acq_function,
bounds=bounds,
q=1,
options=final_options,
batch_initial_conditions=candidates,
**shared_optimize_acqf_kwargs,
)
if fixed_features_list:
candidates, acq_values = optimize_acqf_mixed(
acq_function=acq_function,
bounds=bounds,
q=1,
options=final_options,
batch_initial_conditions=candidates,
fixed_features_list=fixed_features_list,
**shared_optimize_acqf_kwargs,
)
else:
candidates, acq_values = optimize_acqf(
acq_function=acq_function,
bounds=bounds,
q=1,
options=final_options,
batch_initial_conditions=candidates,
fixed_features=fixed_features,
**shared_optimize_acqf_kwargs,
)

# Post-process the candidates and grab the best candidate
if post_processing_func is not None:
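For context on the branch added above: `optimize_acqf_mixed` handles a discrete set of fixed-feature configurations by optimizing the acquisition function over the continuous features once per entry of `fixed_features_list` and keeping the best result. A simplified sketch of that idea for q=1 (a rough illustration only, not the actual BoTorch implementation, and the helper name is hypothetical):

from botorch.optim.optimize import optimize_acqf


def optimize_over_fixed_feature_configs(acq_function, bounds, fixed_features_list, **kwargs):
    # Run one continuous optimization per fixed-feature dict and keep the
    # candidate with the highest acquisition value.
    best_candidate, best_value = None, None
    for fixed_features in fixed_features_list:
        candidate, value = optimize_acqf(
            acq_function=acq_function,
            bounds=bounds,
            q=1,
            fixed_features=fixed_features,
            **kwargs,  # e.g. num_restarts=..., raw_samples=...
        )
        if best_value is None or value > best_value:
            best_candidate, best_value = candidate, value
    return best_candidate, best_value
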
32 changes: 32 additions & 0 deletions test/optim/test_homotopy.py
@@ -143,6 +143,38 @@ def test_optimize_acqf_homotopy(self):
)
self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))

# test fixed feature list
fixed_features_list = [{0: 1.0}]
model = GenericDeterministicModel(
f=lambda x: 5 - (x - p).sum(dim=-1, keepdims=True) ** 2
)
acqf = PosteriorMean(model=model)
# test raise error when fixed_features and fixed_features_list are both provided
with self.assertRaisesRegex(
ValueError,
"Èither `fixed_feature` or `fixed_features_list` can be provided, not both.",
):
optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
fixed_features_list=fixed_features_list,
fixed_features=fixed_features,
)
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
fixed_features_list=fixed_features_list,
)
self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))

# With q > 1.
acqf = qExpectedImprovement(model=model, best_f=0.0)
candidate, acqf_val = optimize_acqf_homotopy(
