Commit 8e27422
doc: initialize_q_batch_topk -> initialize_q_batch_topn
CompRhys committed Nov 25, 2024
1 parent e75239d commit 8e27422
Showing 3 changed files with 15 additions and 15 deletions.
4 changes: 2 additions & 2 deletions botorch/optim/__init__.py
@@ -25,7 +25,7 @@
 from botorch.optim.initializers import (
     initialize_q_batch,
     initialize_q_batch_nonneg,
-    initialize_q_batch_topk,
+    initialize_q_batch_topn,
 )
 from botorch.optim.optimize import (
     gen_batch_initial_conditions,
@@ -47,7 +47,7 @@
     "gen_batch_initial_conditions",
     "initialize_q_batch",
     "initialize_q_batch_nonneg",
-    "initialize_q_batch_topk",
+    "initialize_q_batch_topn",
     "OptimizationResult",
     "OptimizationStatus",
     "optimize_acqf",
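Note: after this commit, only the new name is re-exported from `botorch.optim`. A minimal sketch of the updated import, assuming a checkout that includes this change:

    # Updated import path after the rename; the old name
    # `initialize_q_batch_topk` is no longer exported.
    from botorch.optim import initialize_q_batch_topn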
8 changes: 4 additions & 4 deletions botorch/optim/initializers.py
@@ -329,8 +329,8 @@ def gen_batch_initial_conditions(
     device = bounds.device
     bounds_cpu = bounds.cpu()

-    if options.get("topk"):
-        init_func = initialize_q_batch_topk
+    if options.get("topn"):
+        init_func = initialize_q_batch_topn
         init_func_opts = ["sorted", "largest"]
[Codecov annotation: added lines botorch/optim/initializers.py#L333-L334 were not covered by tests]
     elif options.get("nonnegative") or is_nonnegative(acq_function):
         init_func = initialize_q_batch_nonneg
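Note: besides the function rename, this hunk renames the dispatch key in `options` from `"topk"` to `"topn"`. A hedged sketch of how a caller would select the top-n initializer through `gen_batch_initial_conditions` after this commit; the toy model and acquisition setup below are illustrative assumptions, not part of the commit:

    import torch
    from botorch.acquisition import qUpperConfidenceBound
    from botorch.models import SingleTaskGP
    from botorch.optim import gen_batch_initial_conditions

    # Toy d=6 problem; model/acquisition construction is illustrative only.
    train_X = torch.rand(8, 6, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)
    acq_func = qUpperConfidenceBound(model, beta=0.1)

    bounds = torch.stack([torch.zeros(6), torch.ones(6)]).to(torch.double)
    ics = gen_batch_initial_conditions(
        acq_function=acq_func,
        bounds=bounds,
        q=3,
        num_restarts=10,
        raw_samples=512,
        # "topn" (renamed from "topk") selects initialize_q_batch_topn;
        # "sorted" and "largest" are forwarded per init_func_opts above.
        options={"topn": True, "sorted": True, "largest": True},
    )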
@@ -1079,7 +1079,7 @@ def initialize_q_batch_nonneg(
     return X[idcs], acq_vals[idcs]


-def initialize_q_batch_topk(
+def initialize_q_batch_topn(
     X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True
 ) -> tuple[Tensor, Tensor]:
     r"""Take the top `n` initial conditions for candidate generation.
r"""Take the top `n` initial conditions for candidate generation.
Expand All @@ -1100,7 +1100,7 @@ def initialize_q_batch_topk(
>>> # for model with `d=6`:
>>> qUCB = qUpperConfidenceBound(model, beta=0.1)
>>> X_rnd = torch.rand(500, 3, 6)
>>> X_init, acq_init = initialize_q_batch_topk(
>>> X_init, acq_init = initialize_q_batch_topn(
... X=X_rnd, acq_vals=qUCB(X_rnd), n=10
... )
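For context on the renamed function itself: per the docstring above, it keeps the `n` best raw samples by acquisition value. A minimal self-contained sketch of that selection rule (an illustrative stand-in, not the BoTorch source, which additionally emits `BadInitialCandidatesWarning` on degenerate values, as the test changes below show):

    import torch
    from torch import Tensor

    def topn_sketch(
        X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True
    ) -> tuple[Tensor, Tensor]:
        """Keep the `n` candidates with the best acquisition values.

        Illustrative stand-in for initialize_q_batch_topn's core selection;
        the real function also warns when acquisition values are degenerate.
        """
        if n > X.shape[0]:
            raise RuntimeError(
                f"n ({n}) cannot be larger than the number of samples ({X.shape[0]})"
            )
        top_vals, idcs = acq_vals.topk(n, largest=largest, sorted=sorted)
        return X[idcs], top_vals

    # Usage mirroring the docstring example: keep the 10 best of 500 raw draws.
    X_rnd = torch.rand(500, 3, 6)
    acq_vals = torch.rand(500)  # stand-in for qUCB(X_rnd)
    X_init, acq_init = topn_sketch(X=X_rnd, acq_vals=acq_vals, n=10)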
18 changes: 9 additions & 9 deletions test/optim/test_initializers.py
@@ -31,13 +31,13 @@
 from botorch.models import SingleTaskGP
 from botorch.models.model_list_gp_regression import ModelListGP
 from botorch.optim.initializers import (
-    initialize_q_batch,
-    initialize_q_batch_nonneg,
-    initialize_q_batch_topk,
     gen_batch_initial_conditions,
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
     gen_value_function_initial_conditions,
+    initialize_q_batch,
+    initialize_q_batch_nonneg,
+    initialize_q_batch_topn,
     sample_perturbed_subset_dims,
     sample_points_around_best,
     sample_q_batches_from_polytope,
@@ -157,37 +157,37 @@ def test_initialize_q_batch(self):
         with self.assertRaises(RuntimeError):
             initialize_q_batch(X=X, acq_vals=acq_vals, n=10)

-    def test_initialize_q_batch_topk(self):
+    def test_initialize_q_batch_topn(self):
         for dtype in (torch.float, torch.double):
             # basic test
             X = torch.rand(5, 3, 4, device=self.device, dtype=dtype)
             acq_vals = torch.rand(5, device=self.device, dtype=dtype)
-            ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2)
+            ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
             self.assertEqual(ics_X.shape, torch.Size([2, 3, 4]))
             self.assertEqual(ics_X.device, X.device)
             self.assertEqual(ics_X.dtype, X.dtype)
             self.assertEqual(ics_acq_vals.shape, torch.Size([2]))
             self.assertEqual(ics_acq_vals.device, acq_vals.device)
             self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype)
             # ensure nothing happens if we want all samples
-            ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=5)
+            ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=5)
             self.assertTrue(torch.equal(X, ics_X))
             self.assertTrue(torch.equal(acq_vals, ics_acq_vals))
             # make sure things work with constant inputs
             acq_vals = torch.ones(5, device=self.device, dtype=dtype)
-            ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2)
+            ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
             self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
             self.assertEqual(ics.device, X.device)
             self.assertEqual(ics.dtype, X.dtype)
             # ensure raises correct warning
             acq_vals = torch.zeros(5, device=self.device, dtype=dtype)
             with warnings.catch_warnings(record=True) as w:
-                ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2)
+                ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
                 self.assertEqual(len(w), 1)
                 self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
             self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
             with self.assertRaises(RuntimeError):
-                initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=10)
+                initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=10)

     def test_initialize_q_batch_largeZ(self):
         for dtype in (torch.float, torch.double):
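The renamed test can be exercised in isolation; assuming a standard pytest setup from the repository root:

    pytest test/optim/test_initializers.py -k test_initialize_q_batch_topn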
