Add TopK downselection for initial batch generation. #2636

Closed
wants to merge 22 commits
Changes from 8 commits
Commits (22)
8047aa1
wip: topk ic generation
CompRhys Nov 19, 2024
f5a8d64
tests: add tests
CompRhys Nov 20, 2024
96f9bef
Merge remote-tracking branch 'upstream/main' into topk-icgen
CompRhys Nov 24, 2024
a022462
fix: micro-optimization suggestion from review
CompRhys Nov 24, 2024
e75239d
fix: don't use unnormalize due to unexpected behaviour with constant …
CompRhys Nov 25, 2024
8e27422
doc: initialize_q_batch_topk -> initialize_q_batch_topn
CompRhys Nov 25, 2024
662caf1
tests: achieve full coverage
CompRhys Nov 26, 2024
75eea37
clean: remove debug snippet
CompRhys Nov 26, 2024
5e0fe59
Merge remote-tracking branch 'upstream/main' into topk-icgen
CompRhys Nov 27, 2024
88a2e5d
fea: use unnormalize in more places but add flag to turn off the cons…
CompRhys Dec 2, 2024
e0202e2
doc: add docstring for the new update_constant_bounds argument
CompRhys Dec 2, 2024
21bbc27
fix: assert warns rather than catch and check
CompRhys Dec 2, 2024
6e93eba
fix: nit limit scope of context managers
CompRhys Dec 3, 2024
5e706ea
Merge branch 'topk-icgen' of https://github.com/Radical-AI/botorch in…
CompRhys Dec 3, 2024
f364fe1
doc: update the gen_batch_initial_conditions docstring
CompRhys Dec 3, 2024
7d9f9eb
Merge remote-tracking branch 'upstream/main' into topk-icgen
CompRhys Dec 3, 2024
e054fe2
Merge branch 'main' into topk-icgen
CompRhys Dec 3, 2024
1e0828c
test: reduce the number of tests
CompRhys Dec 3, 2024
ec1c167
Merge branch 'topk-icgen' of https://github.com/Radical-AI/botorch in…
CompRhys Dec 3, 2024
1ddc929
revert: redo the changes to reduce context manager scope
CompRhys Dec 3, 2024
61f6ffb
nit: change to assertWarns
CompRhys Dec 3, 2024
975bb29
Update botorch/optim/initializers.py
CompRhys Dec 6, 2024
7 changes: 6 additions & 1 deletion botorch/optim/__init__.py
@@ -22,7 +22,11 @@
LinearHomotopySchedule,
LogLinearHomotopySchedule,
)
from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg
from botorch.optim.initializers import (
initialize_q_batch,
initialize_q_batch_nonneg,
initialize_q_batch_topn,
)
from botorch.optim.optimize import (
gen_batch_initial_conditions,
optimize_acqf,
@@ -43,6 +47,7 @@
"gen_batch_initial_conditions",
"initialize_q_batch",
"initialize_q_batch_nonneg",
"initialize_q_batch_topn",
"OptimizationResult",
"OptimizationStatus",
"optimize_acqf",
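With this change, initialize_q_batch_topn joins the existing initializers in the public botorch.optim namespace. A minimal sketch of the new import path (tensor shapes below are illustrative, not from the PR):

    import torch
    from botorch.optim import initialize_q_batch_topn

    # 100 raw q-batches (q=2, d=3), each with one acquisition value.
    X_rnd = torch.rand(100, 2, 3)
    acq_vals = torch.rand(100)

    # Keep the 5 q-batches with the largest acquisition values.
    X_init, acq_init = initialize_q_batch_topn(X=X_rnd, acq_vals=acq_vals, n=5)
    assert X_init.shape == (5, 2, 3)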
86 changes: 79 additions & 7 deletions botorch/optim/initializers.py
@@ -328,14 +328,24 @@ def gen_batch_initial_conditions(
init_kwargs = {}
device = bounds.device
bounds_cpu = bounds.cpu()
if "eta" in options:
init_kwargs["eta"] = options.get("eta")
if options.get("nonnegative") or is_nonnegative(acq_function):

if options.get("topn"):
init_func = initialize_q_batch_topn
init_func_opts = ["sorted", "largest"]
elif options.get("nonnegative") or is_nonnegative(acq_function):
init_func = initialize_q_batch_nonneg
if "alpha" in options:
init_kwargs["alpha"] = options.get("alpha")
init_func_opts = ["alpha", "eta"]
else:
init_func = initialize_q_batch
init_func_opts = ["eta"]

for opt in init_func_opts:
# default "largest" to acq_function.maximize, if that attribute exists
if opt == "largest" and hasattr(acq_function, "maximize"):
init_kwargs[opt] = acq_function.maximize

if opt in options:
init_kwargs[opt] = options.get(opt)

q = 1 if q is None else q
# the dimension the samples are drawn from
@@ -363,7 +373,7 @@
X_rnd_nlzd = torch.rand(
n, q, bounds_cpu.shape[-1], dtype=bounds.dtype
)
X_rnd = bounds_cpu[0] + (bounds_cpu[1] - bounds_cpu[0]) * X_rnd_nlzd
X_rnd = X_rnd_nlzd * (bounds_cpu[1] - bounds_cpu[0]) + bounds_cpu[0]
else:
X_rnd = sample_q_batches_from_polytope(
n=n,
@@ -375,7 +385,8 @@
equality_constraints=equality_constraints,
inequality_constraints=inequality_constraints,
)
# sample points around best

# sample additional points around best
if sample_around_best:
X_best_rnd = sample_points_around_best(
acq_function=acq_function,
@@ -395,6 +406,8 @@
)
# Keep X on CPU for consistency & to limit GPU memory usage.
X_rnd = fix_features(X_rnd, fixed_features=fixed_features).cpu()

# Append the fixed fantasies to the randomly generated points
if fixed_X_fantasies is not None:
if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]):
raise BotorchTensorDimensionError(
@@ -411,6 +424,9 @@
],
dim=-2,
)

# Evaluate the acquisition function on `X_rnd` using `batch_limit`
# sized chunks.
with torch.no_grad():
if batch_limit is None:
batch_limit = X_rnd.shape[0]
@@ -423,16 +439,22 @@
],
dim=0,
)

# Downselect the initial conditions based on the acquisition function values
batch_initial_conditions, _ = init_func(
X=X_rnd, acq_vals=acq_vals, n=num_restarts, **init_kwargs
)
batch_initial_conditions = batch_initial_conditions.to(device=device)

# Return the initial conditions if no warnings were raised
if not any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws):
return batch_initial_conditions

if factor < max_factor:
factor += 1
if seed is not None:
seed += 1 # make sure to sample different X_rnd

warnings.warn(
"Unable to find non-zero acquisition function values - initial conditions "
"are being selected randomly.",
@@ -1057,6 +1079,56 @@ def initialize_q_batch_nonneg(
return X[idcs], acq_vals[idcs]


def initialize_q_batch_topn(
X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True
) -> tuple[Tensor, Tensor]:
r"""Take the top `n` initial conditions for candidate generation.

Args:
X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim.
feature space. Typically, these are generated using qMC.
acq_vals: A tensor of `b` outcomes associated with the samples. Typically, this
is the value of the batch acquisition function to be maximized.
n: The number of initial conditions to be generated. Must not be larger than `b`.
largest: If True, select the `n` largest acquisition values; otherwise the `n` smallest.
sorted: If True, return the selected batches sorted by acquisition value.

Returns:
- An `n x q x d` tensor of `n` `q`-batch initial conditions.
- An `n` tensor of the corresponding acquisition values.

Example:
>>> # To get `n=10` starting points of q-batch size `q=3`
>>> # for model with `d=6`:
>>> qUCB = qUpperConfidenceBound(model, beta=0.1)
>>> X_rnd = torch.rand(500, 3, 6)
>>> X_init, acq_init = initialize_q_batch_topn(
... X=X_rnd, acq_vals=qUCB(X_rnd), n=10
... )

"""
n_samples = X.shape[0]
if n > n_samples:
raise RuntimeError(
f"n ({n}) cannot be larger than the number of "
f"provided samples ({n_samples})"
)
elif n == n_samples:
return X, acq_vals

Ystd = acq_vals.std(dim=0)
if torch.any(Ystd == 0):
warnings.warn(
"All acquisition values for raw samples points are the same for "
"at least one batch. Choosing initial conditions at random.",
BadInitialCandidatesWarning,
stacklevel=3,
)
idcs = torch.randperm(n=n_samples, device=X.device)[:n]
return X[idcs], acq_vals[idcs]

topk_out, topk_idcs = acq_vals.topk(n, largest=largest, sorted=sorted)
return X[topk_idcs], topk_out


def sample_points_around_best(
acq_function: AcquisitionFunction,
n_discrete_points: int,
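Taken together, these changes let callers opt into top-n downselection via the options dict of gen_batch_initial_conditions. A minimal sketch, assuming a toy SingleTaskGP and qExpectedImprovement (the model and data are placeholders, not from the PR):

    import torch
    from botorch.acquisition import qExpectedImprovement
    from botorch.models import SingleTaskGP
    from botorch.optim.initializers import gen_batch_initial_conditions

    # Placeholder model on 8 random points in [0, 1]^2.
    train_X = torch.rand(8, 2, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)
    acqf = qExpectedImprovement(model, best_f=train_Y.max())

    bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
    ics = gen_batch_initial_conditions(
        acq_function=acqf,
        bounds=bounds,
        q=2,
        num_restarts=4,
        raw_samples=64,
        # "topn" routes downselection to initialize_q_batch_topn;
        # "largest" and "sorted" are forwarded to torch.topk.
        options={"topn": True, "largest": True, "sorted": True},
    )
    assert ics.shape == (4, 2, 2)  # num_restarts x q x d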
8 changes: 3 additions & 5 deletions botorch/utils/sampling.py
@@ -98,14 +98,12 @@ def draw_sobol_samples(
batch_shape = batch_shape or torch.Size()
batch_size = int(torch.prod(torch.tensor(batch_shape)))
d = bounds.shape[-1]
lower = bounds[0]
rng = bounds[1] - bounds[0]
sobol_engine = SobolEngine(q * d, scramble=True, seed=seed)
samples_raw = sobol_engine.draw(batch_size * n, dtype=lower.dtype)
samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=lower.device)
samples_raw = sobol_engine.draw(batch_size * n, dtype=bounds.dtype)
samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device)
if batch_shape != torch.Size():
samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1)
return lower + rng * samples_raw
return bounds[0] + (bounds[1] - bounds[0]) * samples_raw


def draw_sobol_normal_samples(
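The refactor above drops the intermediate lower/rng variables so that dtype and device come directly from bounds; the math is unchanged. A standalone sketch of the affine rescaling from the unit cube into the bounds box (shapes and seed are illustrative):

    import torch
    from torch.quasirandom import SobolEngine

    # n=4 draws of q=2 candidates in d=3 dims, rescaled from
    # [0, 1]^(q*d) into the box given by `bounds`.
    n, q, d = 4, 2, 3
    bounds = torch.stack([torch.zeros(d), torch.full((d,), 2.0)])
    engine = SobolEngine(q * d, scramble=True, seed=0)
    samples_raw = engine.draw(n, dtype=bounds.dtype).view(n, q, d)
    samples = bounds[0] + (bounds[1] - bounds[0]) * samples_raw
    assert samples.shape == (n, q, d)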
116 changes: 115 additions & 1 deletion test/optim/test_initializers.py
@@ -30,12 +30,14 @@
from botorch.exceptions.warnings import BotorchWarning
from botorch.models import SingleTaskGP
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.optim import initialize_q_batch, initialize_q_batch_nonneg
from botorch.optim.initializers import (
gen_batch_initial_conditions,
gen_one_shot_hvkg_initial_conditions,
gen_one_shot_kg_initial_conditions,
gen_value_function_initial_conditions,
initialize_q_batch,
initialize_q_batch_nonneg,
initialize_q_batch_topn,
sample_perturbed_subset_dims,
sample_points_around_best,
sample_q_batches_from_polytope,
@@ -155,6 +157,38 @@ def test_initialize_q_batch(self):
with self.assertRaises(RuntimeError):
initialize_q_batch(X=X, acq_vals=acq_vals, n=10)

def test_initialize_q_batch_topn(self):
for dtype in (torch.float, torch.double):
# basic test
X = torch.rand(5, 3, 4, device=self.device, dtype=dtype)
acq_vals = torch.rand(5, device=self.device, dtype=dtype)
ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
self.assertEqual(ics_X.shape, torch.Size([2, 3, 4]))
self.assertEqual(ics_X.device, X.device)
self.assertEqual(ics_X.dtype, X.dtype)
self.assertEqual(ics_acq_vals.shape, torch.Size([2]))
self.assertEqual(ics_acq_vals.device, acq_vals.device)
self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype)
# ensure nothing happens if we want all samples
ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=5)
self.assertTrue(torch.equal(X, ics_X))
self.assertTrue(torch.equal(acq_vals, ics_acq_vals))
# make sure things work with constant inputs
acq_vals = torch.ones(5, device=self.device, dtype=dtype)
ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
self.assertEqual(ics.device, X.device)
self.assertEqual(ics.dtype, X.dtype)
# ensure raises correct warning
acq_vals = torch.zeros(5, device=self.device, dtype=dtype)
with warnings.catch_warnings(record=True) as w:
ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
self.assertEqual(len(w), 1)
self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
with self.assertRaises(RuntimeError):
initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=10)

def test_initialize_q_batch_largeZ(self):
for dtype in (torch.float, torch.double):
# testing large eta*Z
@@ -246,6 +280,86 @@ def test_gen_batch_initial_conditions(self):
torch.all(batch_initial_conditions[..., idx] == val)
)

def test_gen_batch_initial_conditions_topn(self):
bounds = torch.stack([torch.zeros(2), torch.ones(2)])
mock_acqf = MockAcquisitionFunction()
mock_acqf.objective = lambda y: y.squeeze(-1)
mock_acqf.maximize = True # Add maximize attribute
for dtype in (torch.float, torch.double):
bounds = bounds.to(device=self.device, dtype=dtype)
mock_acqf.X_baseline = bounds # for testing sample_around_best
mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))
[Review thread on the lines above]

Contributor: if you pull this into the product() below you can save one indent ...

Author (CompRhys): Happy to change this, but it matches the other test cases, perhaps to avoid reassigning lines 289-291 in every iteration of the product. If I change it here, I will change it in all the equivalent tests for consistency.

Contributor: Fair enough. I think we can go either way here. FWIW, the overhead of reassigning lines 289-291 in every iteration should be negligible.
for (
topn,
largest,
is_sorted,
seed,
init_batch_limit,
ffs,
sample_around_best,
) in product(
[True, False],
[True, False, None],
[True, False],
[None, 1234],
[None, 1],
[None, {0: 0.5}],
[True, False],
):
with mock.patch.object(
MockAcquisitionFunction,
"__call__",
wraps=mock_acqf.__call__,
) as mock_acqf_call, warnings.catch_warnings():
warnings.simplefilter(
"ignore", category=BadInitialCandidatesWarning
)
options = {
"topn": topn,
"sorted": is_sorted,
"seed": seed,
"init_batch_limit": init_batch_limit,
"sample_around_best": sample_around_best,
}
if largest is not None:
options["largest"] = largest
batch_initial_conditions = gen_batch_initial_conditions(
acq_function=mock_acqf,
bounds=bounds,
q=1,
num_restarts=2,
raw_samples=10,
fixed_features=ffs,
options=options,
)
expected_shape = torch.Size([2, 1, 2])
self.assertEqual(batch_initial_conditions.shape, expected_shape)
self.assertEqual(batch_initial_conditions.device, bounds.device)
self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
self.assertLess(
_get_max_violation_of_bounds(batch_initial_conditions, bounds),
1e-6,
)
raw_samps = mock_acqf_call.call_args[0][0]
batch_shape = (
torch.Size([20 if sample_around_best else 10])
if init_batch_limit is None
else torch.Size([init_batch_limit])
)
expected_raw_samps_shape = batch_shape + torch.Size([1, 2])
self.assertEqual(raw_samps.shape, expected_raw_samps_shape)

if ffs is not None:
for idx, val in ffs.items():
self.assertTrue(
torch.all(batch_initial_conditions[..., idx] == val)
)

def test_gen_batch_initial_conditions_highdim(self):
d = 2200 # 2200 * 10 (q) > 21201 (sobol max dim)
bounds = torch.stack([torch.zeros(d), torch.ones(d)])