From 8047aa185b4ba8b13112bdec63418a46781988ac Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 19 Nov 2024 16:22:36 -0500 Subject: [PATCH 01/16] wip: topk ic generation --- botorch/optim/initializers.py | 83 ++++++++++++++++++++++++++++++++--- botorch/utils/sampling.py | 8 ++-- 2 files changed, 79 insertions(+), 12 deletions(-) diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index af0f918f4a..ad074a4e85 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -328,14 +328,24 @@ def gen_batch_initial_conditions( init_kwargs = {} device = bounds.device bounds_cpu = bounds.cpu() - if "eta" in options: - init_kwargs["eta"] = options.get("eta") - if options.get("nonnegative") or is_nonnegative(acq_function): + + if options.get("topk"): + init_func = initialize_q_batch_topk + init_func_opts = ["sorted", "largest"] + elif options.get("nonnegative") or is_nonnegative(acq_function): init_func = initialize_q_batch_nonneg - if "alpha" in options: - init_kwargs["alpha"] = options.get("alpha") + init_func_opts = ["alpha", "eta"] else: init_func = initialize_q_batch + init_func_opts = ["eta"] + + for opt in init_func_opts: + # default value of "largest" to "acq_function.maximize" if it exists + if opt == "largest" and hasattr(acq_function, "maximize"): + init_kwargs[opt] = acq_function.maximize + + if opt in options: + init_kwargs[opt] = options.get(opt) q = 1 if q is None else q # the dimension the samples are drawn from @@ -363,7 +373,7 @@ def gen_batch_initial_conditions( X_rnd_nlzd = torch.rand( n, q, bounds_cpu.shape[-1], dtype=bounds.dtype ) - X_rnd = bounds_cpu[0] + (bounds_cpu[1] - bounds_cpu[0]) * X_rnd_nlzd + X_rnd = unnormalize(X_rnd_nlzd, bounds_cpu) else: X_rnd = sample_q_batches_from_polytope( n=n, @@ -375,7 +385,8 @@ def gen_batch_initial_conditions( equality_constraints=equality_constraints, inequality_constraints=inequality_constraints, ) - # sample points around best + + # sample additional points around best if sample_around_best: X_best_rnd = sample_points_around_best( acq_function=acq_function, @@ -395,6 +406,8 @@ def gen_batch_initial_conditions( ) # Keep X on CPU for consistency & to limit GPU memory usage. X_rnd = fix_features(X_rnd, fixed_features=fixed_features).cpu() + + # Append the fixed fantasies to the randomly generated points if fixed_X_fantasies is not None: if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]): raise BotorchTensorDimensionError( @@ -411,6 +424,9 @@ def gen_batch_initial_conditions( ], dim=-2, ) + + # Evaluate the acquisition function on `X_rnd` using `batch_limit` + # sized chunks. 
     with torch.no_grad():
         if batch_limit is None:
             batch_limit = X_rnd.shape[0]
@@ -423,16 +439,22 @@ def gen_batch_initial_conditions(
                 ],
                 dim=0,
             )
+
+        # Downselect the initial conditions based on the acquisition function values
         batch_initial_conditions, _ = init_func(
             X=X_rnd, acq_vals=acq_vals, n=num_restarts, **init_kwargs
         )
         batch_initial_conditions = batch_initial_conditions.to(device=device)
+
+        # Return the initial conditions if no warnings were raised
         if not any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws):
             return batch_initial_conditions
+
         if factor < max_factor:
             factor += 1
             if seed is not None:
                 seed += 1  # make sure to sample different X_rnd
+
     warnings.warn(
         "Unable to find non-zero acquisition function values - initial conditions "
         "are being selected randomly.",
         BadInitialCandidatesWarning,
     )
@@ -1057,6 +1079,53 @@ def initialize_q_batch_nonneg(
     return X[idcs], acq_vals[idcs]


+def initialize_q_batch_topk(
+    X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True
+) -> tuple[Tensor, Tensor]:
+    r"""Take the top `n` initial conditions for candidate generation.
+
+    Args:
+        X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim.
+            feature space. Typically, these are generated using qMC.
+        acq_vals: A tensor of `b` outcomes associated with the samples. Typically, this
+            is the value of the batch acquisition function to be maximized.
+        n: The number of initial conditions to be generated. Must be less than `b`.
+
+    Returns:
+        - An `n x q x d` tensor of `n` `q`-batch initial conditions.
+        - An `n` tensor of the corresponding acquisition values.
+
+    Example:
+        >>> # To get `n=10` starting points of q-batch size `q=3`
+        >>> # for model with `d=6`:
+        >>> qUCB = qUpperConfidenceBound(model, beta=0.1)
+        >>> X_rnd = torch.rand(500, 3, 6)
+        >>> X_init, acq_init = initialize_q_batch_topk(X=X_rnd, acq_vals=qUCB(X_rnd), n=10)
+    """
+    n_samples = X.shape[0]
+    if n > n_samples:
+        raise RuntimeError(
+            f"n ({n}) cannot be larger than the number of "
+            f"provided samples ({n_samples})"
+        )
+    elif n == n_samples:
+        return X, acq_vals
+
+    Ystd = acq_vals.std(dim=0)
+    if torch.any(Ystd == 0):
+        warnings.warn(
+            "All acquisition values for raw sample points are the same for "
+            "at least one batch. 
Choosing initial conditions at random.", + BadInitialCandidatesWarning, + stacklevel=3, + ) + idcs = torch.randperm(n=n_samples, device=X.device)[:n] + return X[idcs], acq_vals[idcs] + + idcs = acq_vals.topk(n, largest=largest, sorted=sorted).indices + return X[idcs], acq_vals[idcs] + + def sample_points_around_best( acq_function: AcquisitionFunction, n_discrete_points: int, diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index 52fe54fbb2..a508320299 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -98,14 +98,12 @@ def draw_sobol_samples( batch_shape = batch_shape or torch.Size() batch_size = int(torch.prod(torch.tensor(batch_shape))) d = bounds.shape[-1] - lower = bounds[0] - rng = bounds[1] - bounds[0] sobol_engine = SobolEngine(q * d, scramble=True, seed=seed) - samples_raw = sobol_engine.draw(batch_size * n, dtype=lower.dtype) - samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=lower.device) + samples_raw = sobol_engine.draw(batch_size * n, dtype=bounds.dtype) + samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device) if batch_shape != torch.Size(): samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1) - return lower + rng * samples_raw + return unnormalize(samples_raw, bounds) def draw_sobol_normal_samples( From f5a8d64f279dc9e7e752c93ea2589cc93cf966f7 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 20 Nov 2024 11:41:29 -0500 Subject: [PATCH 02/16] tests: add tests --- botorch/optim/__init__.py | 7 ++++++- test/optim/test_initializers.py | 36 ++++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/botorch/optim/__init__.py b/botorch/optim/__init__.py index f4abe3fd87..6bb32b6658 100644 --- a/botorch/optim/__init__.py +++ b/botorch/optim/__init__.py @@ -22,7 +22,11 @@ LinearHomotopySchedule, LogLinearHomotopySchedule, ) -from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg +from botorch.optim.initializers import ( + initialize_q_batch, + initialize_q_batch_nonneg, + initialize_q_batch_topk, +) from botorch.optim.optimize import ( gen_batch_initial_conditions, optimize_acqf, @@ -43,6 +47,7 @@ "gen_batch_initial_conditions", "initialize_q_batch", "initialize_q_batch_nonneg", + "initialize_q_batch_topk", "OptimizationResult", "OptimizationStatus", "optimize_acqf", diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 09be6f2326..7cf7621eca 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -30,8 +30,10 @@ from botorch.exceptions.warnings import BotorchWarning from botorch.models import SingleTaskGP from botorch.models.model_list_gp_regression import ModelListGP -from botorch.optim import initialize_q_batch, initialize_q_batch_nonneg from botorch.optim.initializers import ( + initialize_q_batch, + initialize_q_batch_nonneg, + initialize_q_batch_topk, gen_batch_initial_conditions, gen_one_shot_hvkg_initial_conditions, gen_one_shot_kg_initial_conditions, @@ -155,6 +157,38 @@ def test_initialize_q_batch(self): with self.assertRaises(RuntimeError): initialize_q_batch(X=X, acq_vals=acq_vals, n=10) + def test_initialize_q_batch_topk(self): + for dtype in (torch.float, torch.double): + # basic test + X = torch.rand(5, 3, 4, device=self.device, dtype=dtype) + acq_vals = torch.rand(5, device=self.device, dtype=dtype) + ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics_X.shape, torch.Size([2, 3, 4])) + 
self.assertEqual(ics_X.device, X.device) + self.assertEqual(ics_X.dtype, X.dtype) + self.assertEqual(ics_acq_vals.shape, torch.Size([2])) + self.assertEqual(ics_acq_vals.device, acq_vals.device) + self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) + # ensure nothing happens if we want all samples + ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=5) + self.assertTrue(torch.equal(X, ics_X)) + self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) + # make sure things work with constant inputs + acq_vals = torch.ones(5, device=self.device, dtype=dtype) + ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics.shape, torch.Size([2, 3, 4])) + self.assertEqual(ics.device, X.device) + self.assertEqual(ics.dtype, X.dtype) + # ensure raises correct warning + acq_vals = torch.zeros(5, device=self.device, dtype=dtype) + with warnings.catch_warnings(record=True) as w: + ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) + self.assertEqual(ics.shape, torch.Size([2, 3, 4])) + with self.assertRaises(RuntimeError): + initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=10) + def test_initialize_q_batch_largeZ(self): for dtype in (torch.float, torch.double): # testing large eta*Z From a022462d84a8f63a977e539498c75467887cc98c Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 23 Nov 2024 19:13:15 -0500 Subject: [PATCH 03/16] fix: micro-optimization suggestion from review --- botorch/optim/initializers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index ad074a4e85..520818fdcf 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -1100,7 +1100,10 @@ def initialize_q_batch_topk( >>> # for model with `d=6`: >>> qUCB = qUpperConfidenceBound(model, beta=0.1) >>> X_rnd = torch.rand(500, 3, 6) - >>> X_init, acq_init = initialize_q_batch_topk(X=X_rnd, acq_vals=qUCB(X_rnd), n=10) + >>> X_init, acq_init = initialize_q_batch_topk( + ... X=X_rnd, acq_vals=qUCB(X_rnd), n=10 + ... 
) + """ n_samples = X.shape[0] if n > n_samples: @@ -1122,8 +1125,8 @@ def initialize_q_batch_topk( idcs = torch.randperm(n=n_samples, device=X.device)[:n] return X[idcs], acq_vals[idcs] - idcs = acq_vals.topk(n, largest=largest, sorted=sorted).indices - return X[idcs], acq_vals[idcs] + topk_out, topk_idcs = acq_vals.topk(n, largest=largest, sorted=sorted) + return X[topk_idcs], topk_out def sample_points_around_best( From e75239d939606f710fc3671a58385615601fa538 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 25 Nov 2024 10:33:48 -0500 Subject: [PATCH 04/16] fix: don't use unnormalize due to unexpected behaviour with constant bounds --- botorch/optim/initializers.py | 2 +- botorch/utils/sampling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index 520818fdcf..0908c78f39 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -373,7 +373,7 @@ def gen_batch_initial_conditions( X_rnd_nlzd = torch.rand( n, q, bounds_cpu.shape[-1], dtype=bounds.dtype ) - X_rnd = unnormalize(X_rnd_nlzd, bounds_cpu) + X_rnd = X_rnd_nlzd * (bounds_cpu[1] - bounds_cpu[0]) + bounds_cpu[0] else: X_rnd = sample_q_batches_from_polytope( n=n, diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index a508320299..9ca48a5668 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -103,7 +103,7 @@ def draw_sobol_samples( samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device) if batch_shape != torch.Size(): samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1) - return unnormalize(samples_raw, bounds) + return bounds[0] + (bounds[1] - bounds[0]) * samples_raw def draw_sobol_normal_samples( From 8e274227e51adec300fe1bbd9da7ecee5d47ccf9 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 25 Nov 2024 10:39:19 -0500 Subject: [PATCH 05/16] doc: initialize_q_batch_topk -> initialize_q_batch_topn --- botorch/optim/__init__.py | 4 ++-- botorch/optim/initializers.py | 8 ++++---- test/optim/test_initializers.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/botorch/optim/__init__.py b/botorch/optim/__init__.py index 6bb32b6658..5156bba684 100644 --- a/botorch/optim/__init__.py +++ b/botorch/optim/__init__.py @@ -25,7 +25,7 @@ from botorch.optim.initializers import ( initialize_q_batch, initialize_q_batch_nonneg, - initialize_q_batch_topk, + initialize_q_batch_topn, ) from botorch.optim.optimize import ( gen_batch_initial_conditions, @@ -47,7 +47,7 @@ "gen_batch_initial_conditions", "initialize_q_batch", "initialize_q_batch_nonneg", - "initialize_q_batch_topk", + "initialize_q_batch_topn", "OptimizationResult", "OptimizationStatus", "optimize_acqf", diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index 0908c78f39..0d91d08bea 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -329,8 +329,8 @@ def gen_batch_initial_conditions( device = bounds.device bounds_cpu = bounds.cpu() - if options.get("topk"): - init_func = initialize_q_batch_topk + if options.get("topn"): + init_func = initialize_q_batch_topn init_func_opts = ["sorted", "largest"] elif options.get("nonnegative") or is_nonnegative(acq_function): init_func = initialize_q_batch_nonneg @@ -1079,7 +1079,7 @@ def initialize_q_batch_nonneg( return X[idcs], acq_vals[idcs] -def initialize_q_batch_topk( +def initialize_q_batch_topn( X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, 
sorted: bool = True
 ) -> tuple[Tensor, Tensor]:
     r"""Take the top `n` initial conditions for candidate generation.
@@ -1100,7 +1100,7 @@ def initialize_q_batch_topk(
     >>> # for model with `d=6`:
     >>> qUCB = qUpperConfidenceBound(model, beta=0.1)
     >>> X_rnd = torch.rand(500, 3, 6)
-    >>> X_init, acq_init = initialize_q_batch_topk(
+    >>> X_init, acq_init = initialize_q_batch_topn(
     ...     X=X_rnd, acq_vals=qUCB(X_rnd), n=10
     ... )

diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py
index 7cf7621eca..155e333cf9 100644
--- a/test/optim/test_initializers.py
+++ b/test/optim/test_initializers.py
@@ -31,13 +31,13 @@
 from botorch.models import SingleTaskGP
 from botorch.models.model_list_gp_regression import ModelListGP
 from botorch.optim.initializers import (
-    initialize_q_batch,
-    initialize_q_batch_nonneg,
-    initialize_q_batch_topk,
     gen_batch_initial_conditions,
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
     gen_value_function_initial_conditions,
+    initialize_q_batch,
+    initialize_q_batch_nonneg,
+    initialize_q_batch_topn,
     sample_perturbed_subset_dims,
     sample_points_around_best,
     sample_q_batches_from_polytope,
@@ -157,12 +157,12 @@ def test_initialize_q_batch(self):
         with self.assertRaises(RuntimeError):
             initialize_q_batch(X=X, acq_vals=acq_vals, n=10)

-    def test_initialize_q_batch_topk(self):
+    def test_initialize_q_batch_topn(self):
         for dtype in (torch.float, torch.double):
             # basic test
             X = torch.rand(5, 3, 4, device=self.device, dtype=dtype)
             acq_vals = torch.rand(5, device=self.device, dtype=dtype)
-            ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2)
+            ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
             self.assertEqual(ics_X.shape, torch.Size([2, 3, 4]))
             self.assertEqual(ics_X.device, X.device)
             self.assertEqual(ics_X.dtype, X.dtype)
@@ -170,24 +170,24 @@ def test_initialize_q_batch_topk(self):
             self.assertEqual(ics_acq_vals.device, acq_vals.device)
             self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype)
             # ensure nothing happens if we want all samples
-            ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=5)
+            ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=5)
             self.assertTrue(torch.equal(X, ics_X))
             self.assertTrue(torch.equal(acq_vals, ics_acq_vals))
             # make sure things work with constant inputs
             acq_vals = torch.ones(5, device=self.device, dtype=dtype)
-            ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2)
+            ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
             self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
             self.assertEqual(ics.device, X.device)
             self.assertEqual(ics.dtype, X.dtype)
             # ensure raises correct warning
             acq_vals = torch.zeros(5, device=self.device, dtype=dtype)
             with warnings.catch_warnings(record=True) as w:
-                ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2)
+                ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
             self.assertEqual(len(w), 1)
             self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
             self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
             with self.assertRaises(RuntimeError):
-                initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=10)
+                initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=10)

     def test_initialize_q_batch_largeZ(self):
         for dtype in (torch.float, torch.double):

From 662caf132d7e24cd825e84b27dfb9d9ff8e1f9f8 Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Tue, 26 Nov 2024 11:38:20 -0500
Subject: [PATCH 06/16] tests: achieve full coverage

---
 test/optim/test_initializers.py | 86
+++++++++++++++++++++++++
 1 file changed, 86 insertions(+)

diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py
index 155e333cf9..d8b571ad91 100644
--- a/test/optim/test_initializers.py
+++ b/test/optim/test_initializers.py
@@ -280,6 +280,86 @@ def test_gen_batch_initial_conditions(self):
                         torch.all(batch_initial_conditions[..., idx] == val)
                     )

+    def test_gen_batch_initial_conditions_topn(self):
+        bounds = torch.stack([torch.zeros(2), torch.ones(2)])
+        mock_acqf = MockAcquisitionFunction()
+        mock_acqf.objective = lambda y: y.squeeze(-1)
+        mock_acqf.maximize = True  # Add maximize attribute
+        for dtype in (torch.float, torch.double):
+            bounds = bounds.to(device=self.device, dtype=dtype)
+            mock_acqf.X_baseline = bounds  # for testing sample_around_best
+            mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))
+            for (
+                topn,
+                largest,
+                is_sorted,
+                seed,
+                init_batch_limit,
+                ffs,
+                sample_around_best,
+            ) in product(
+                [True, False],
+                [True, False, None],
+                [True, False],
+                [None, 1234],
+                [None, 1],
+                [None, {0: 0.5}],
+                [True, False],
+            ):
+                with mock.patch.object(
+                    MockAcquisitionFunction,
+                    "__call__",
+                    wraps=mock_acqf.__call__,
+                ) as mock_acqf_call, warnings.catch_warnings():
+                    warnings.simplefilter(
+                        "ignore", category=BadInitialCandidatesWarning
+                    )
+                    options = {
+                        "topn": topn,
+                        "sorted": is_sorted,
+                        "seed": seed,
+                        "init_batch_limit": init_batch_limit,
+                        "sample_around_best": sample_around_best,
+                    }
+                    if largest is not None:
+                        options["largest"] = largest
+                    batch_initial_conditions = gen_batch_initial_conditions(
+                        acq_function=mock_acqf,
+                        bounds=bounds,
+                        q=1,
+                        num_restarts=2,
+                        raw_samples=10,
+                        fixed_features=ffs,
+                        options=options,
+                    )
+                    expected_shape = torch.Size([2, 1, 2])
+                    self.assertEqual(batch_initial_conditions.shape, expected_shape)
+                    self.assertEqual(batch_initial_conditions.device, bounds.device)
+                    self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
+                    self.assertLess(
+                        _get_max_violation_of_bounds(batch_initial_conditions, bounds),
+                        1e-6,
+                    )
+                    batch_shape = (
+                        torch.Size([])
+                        if init_batch_limit is None
+                        else torch.Size([init_batch_limit])
+                    )
+                    raw_samps = mock_acqf_call.call_args[0][0]
+                    batch_shape = (
+                        torch.Size([20 if sample_around_best else 10])
+                        if init_batch_limit is None
+                        else torch.Size([init_batch_limit])
+                    )
+                    expected_raw_samps_shape = batch_shape + torch.Size([1, 2])
+                    self.assertEqual(raw_samps.shape, expected_raw_samps_shape)
+
+                    if ffs is not None:
+                        for idx, val in ffs.items():
+                            self.assertTrue(
+                                torch.all(batch_initial_conditions[..., idx] == val)
+                            )
+
     def test_gen_batch_initial_conditions_highdim(self):
         d = 2200  # 2200 * 10 (q) > 21201 (sobol max dim)
         bounds = torch.stack([torch.zeros(d), torch.ones(d)])
@@ -1471,3 +1551,9 @@ def test_sample_points_around_best(self):
         self.assertTrue(
             ((X_rnd.unsqueeze(0) == X_train.unsqueeze(1)).all(dim=-1)).sum() == 0
         )
+
+
+if __name__ == "__main__":
+    import pytest
+
+    pytest.main([__file__])

From 75eea37bd75776294fc276cfc27f97fdb1d0af7a Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Tue, 26 Nov 2024 11:45:21 -0500
Subject: [PATCH 07/16] clean: remove debug snippet

---
 test/optim/test_initializers.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py
index d8b571ad91..65cde7e183 100644
--- a/test/optim/test_initializers.py
+++ b/test/optim/test_initializers.py
@@ -1551,9 +1551,3 @@ def test_sample_points_around_best(self):
         self.assertTrue(
             ((X_rnd.unsqueeze(0) ==
X_train.unsqueeze(1)).all(dim=-1)).sum() == 0 ) - - -if __name__ == "__main__": - import pytest - - pytest.main([__file__]) From 88a2e5d25b1da2f7e0834c1127d613104a6d1f2a Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 2 Dec 2024 09:51:01 -0500 Subject: [PATCH 08/16] fea: use unnormalize in more places but add flag to turn off the constant bound adjustment --- botorch/optim/initializers.py | 4 +- botorch/utils/feasible_volume.py | 5 +- botorch/utils/sampling.py | 2 +- botorch/utils/transforms.py | 14 +- test/optim/test_initializers.py | 262 ++++++++++++++++--------------- 5 files changed, 153 insertions(+), 134 deletions(-) diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index 0d91d08bea..fbf975cedc 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -373,7 +373,9 @@ def gen_batch_initial_conditions( X_rnd_nlzd = torch.rand( n, q, bounds_cpu.shape[-1], dtype=bounds.dtype ) - X_rnd = X_rnd_nlzd * (bounds_cpu[1] - bounds_cpu[0]) + bounds_cpu[0] + X_rnd = unnormalize( + X_rnd_nlzd, bounds, update_constant_bounds=False + ) else: X_rnd = sample_q_batches_from_polytope( n=n, diff --git a/botorch/utils/feasible_volume.py b/botorch/utils/feasible_volume.py index f3b8d2fb76..2608c03c2a 100644 --- a/botorch/utils/feasible_volume.py +++ b/botorch/utils/feasible_volume.py @@ -11,7 +11,7 @@ import botorch.models.model as model import torch from botorch.logging import _get_logger -from botorch.utils.sampling import manual_seed +from botorch.utils.sampling import manual_seed, unnormalize from torch import Tensor @@ -164,9 +164,10 @@ def estimate_feasible_volume( seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item() with manual_seed(seed=seed): - box_samples = bounds[0] + (bounds[1] - bounds[0]) * torch.rand( + samples_nlzd = torch.rand( (nsample_feature, bounds.size(1)), dtype=dtype, device=device ) + box_samples = unnormalize(samples_nlzd, bounds, update_constant_bounds=False) features, p_feature = get_feasible_samples( samples=box_samples, inequality_constraints=inequality_constraints diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index 9ca48a5668..f914dea24d 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -103,7 +103,7 @@ def draw_sobol_samples( samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device) if batch_shape != torch.Size(): samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1) - return bounds[0] + (bounds[1] - bounds[0]) * samples_raw + return unnormalize(samples_raw, bounds, update_constant_bounds=False) def draw_sobol_normal_samples( diff --git a/botorch/utils/transforms.py b/botorch/utils/transforms.py index 01f34c0da4..5b60ec4ff1 100644 --- a/botorch/utils/transforms.py +++ b/botorch/utils/transforms.py @@ -66,7 +66,7 @@ def _update_constant_bounds(bounds: Tensor) -> Tensor: return bounds -def normalize(X: Tensor, bounds: Tensor) -> Tensor: +def normalize(X: Tensor, bounds: Tensor, update_constant_bounds: bool = True) -> Tensor: r"""Min-max normalize X w.r.t. the provided bounds. 
NOTE: If the upper and lower bounds are identical for a dimension, that dimension @@ -89,11 +89,15 @@ def normalize(X: Tensor, bounds: Tensor) -> Tensor: >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)]) >>> X_normalized = normalize(X, bounds) """ - bounds = _update_constant_bounds(bounds=bounds) + bounds = ( + _update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds + ) return (X - bounds[0]) / (bounds[1] - bounds[0]) -def unnormalize(X: Tensor, bounds: Tensor) -> Tensor: +def unnormalize( + X: Tensor, bounds: Tensor, update_constant_bounds: bool = True +) -> Tensor: r"""Un-normalizes X w.r.t. the provided bounds. NOTE: If the upper and lower bounds are identical for a dimension, that dimension @@ -116,7 +120,9 @@ def unnormalize(X: Tensor, bounds: Tensor) -> Tensor: >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)]) >>> X = unnormalize(X_normalized, bounds) """ - bounds = _update_constant_bounds(bounds=bounds) + bounds = ( + _update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds + ) return X * (bounds[1] - bounds[0]) + bounds[0] diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 65cde7e183..e9145eb59f 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -47,7 +47,7 @@ transform_intra_point_constraint, ) from botorch.sampling.normal import IIDNormalSampler -from botorch.utils.sampling import draw_sobol_samples, manual_seed +from botorch.utils.sampling import draw_sobol_samples, manual_seed, unnormalize from botorch.utils.testing import ( _get_max_violation_of_bounds, _get_max_violation_of_constraints, @@ -221,144 +221,152 @@ def test_gen_batch_initial_conditions(self): bounds = torch.stack([torch.zeros(2), torch.ones(2)]) mock_acqf = MockAcquisitionFunction() mock_acqf.objective = lambda y: y.squeeze(-1) - for dtype in (torch.float, torch.double): + for ( + dtype, + nonnegative, + seed, + init_batch_limit, + ffs, + sample_around_best, + ) in product( + (torch.float, torch.double), + [True, False], + [None, 1234], + [None, 1], + [None, {0: 0.5}], + [True, False], + ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) - for nonnegative, seed, init_batch_limit, ffs, sample_around_best in product( - [True, False], [None, 1234], [None, 1], [None, {0: 0.5}], [True, False] - ): - with mock.patch.object( - MockAcquisitionFunction, - "__call__", - wraps=mock_acqf.__call__, - ) as mock_acqf_call, warnings.catch_warnings(): - warnings.simplefilter( - "ignore", category=BadInitialCandidatesWarning - ) - batch_initial_conditions = gen_batch_initial_conditions( - acq_function=mock_acqf, - bounds=bounds, - q=1, - num_restarts=2, - raw_samples=10, - fixed_features=ffs, - options={ - "nonnegative": nonnegative, - "eta": 0.01, - "alpha": 0.1, - "seed": seed, - "init_batch_limit": init_batch_limit, - "sample_around_best": sample_around_best, - }, - ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( 
- torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + with mock.patch.object( + MockAcquisitionFunction, + "__call__", + wraps=mock_acqf.__call__, + ) as mock_acqf_call, warnings.catch_warnings(): + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=mock_acqf, + bounds=bounds, + q=1, + num_restarts=2, + raw_samples=10, + fixed_features=ffs, + options={ + "nonnegative": nonnegative, + "eta": 0.01, + "alpha": 0.1, + "seed": seed, + "init_batch_limit": init_batch_limit, + "sample_around_best": sample_around_best, + }, + ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_topn(self): bounds = torch.stack([torch.zeros(2), torch.ones(2)]) mock_acqf = MockAcquisitionFunction() mock_acqf.objective = lambda y: y.squeeze(-1) mock_acqf.maximize = True # Add maximize attribute - for dtype in (torch.float, torch.double): + for ( + dtype, + topn, + largest, + is_sorted, + seed, + init_batch_limit, + ffs, + sample_around_best, + ) in product( + [torch.float, torch.double], + [True, False], + [True, False, None], + [True, False], + [None, 1234], + [None, 1], + [None, {0: 0.5}], + [True, False], + ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) - for ( - topn, - largest, - is_sorted, - seed, - init_batch_limit, - ffs, - sample_around_best, - ) in product( - [True, False], - [True, False, None], - [True, False], - [None, 1234], - [None, 1], - [None, {0: 0.5}], - [True, False], - ): - with mock.patch.object( - MockAcquisitionFunction, - "__call__", - wraps=mock_acqf.__call__, - ) as mock_acqf_call, warnings.catch_warnings(): - warnings.simplefilter( - "ignore", category=BadInitialCandidatesWarning - ) - options = { - "topn": topn, - "sorted": is_sorted, - "seed": seed, - "init_batch_limit": init_batch_limit, - "sample_around_best": sample_around_best, - } - if largest is not None: - options["largest"] = largest - batch_initial_conditions = gen_batch_initial_conditions( - acq_function=mock_acqf, - bounds=bounds, - q=1, - num_restarts=2, - raw_samples=10, - fixed_features=ffs, - options=options, - ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - 
self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + with mock.patch.object( + MockAcquisitionFunction, + "__call__", + wraps=mock_acqf.__call__, + ) as mock_acqf_call, warnings.catch_warnings(): + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + options = { + "topn": topn, + "sorted": is_sorted, + "seed": seed, + "init_batch_limit": init_batch_limit, + "sample_around_best": sample_around_best, + } + if largest is not None: + options["largest"] = largest + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=mock_acqf, + bounds=bounds, + q=1, + num_restarts=2, + raw_samples=10, + fixed_features=ffs, + options=options, + ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_highdim(self): d = 2200 # 2200 * 10 (q) > 21201 (sobol max dim) @@ -841,7 +849,9 @@ def generator(n: int, q: int, seed: int | None): dtype=bounds.dtype, device=self.device, ) - X_rnd = bounds[0] + (bounds[1] - bounds[0]) * X_rnd_nlzd + X_rnd = unnormalize( + X_rnd_nlzd, bounds, update_constant_bounds=False + ) X_rnd[..., -1] = 0.42 return X_rnd From e0202e2b4d1a1e6623b01955f44788aca5827344 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 2 Dec 2024 09:56:04 -0500 Subject: [PATCH 09/16] doc: add docstring for the new update_constant_bounds argument --- botorch/utils/transforms.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/botorch/utils/transforms.py b/botorch/utils/transforms.py index 5b60ec4ff1..b354821cfb 100644 --- a/botorch/utils/transforms.py +++ b/botorch/utils/transforms.py @@ -69,14 +69,15 @@ def _update_constant_bounds(bounds: Tensor) -> Tensor: def normalize(X: Tensor, bounds: Tensor, update_constant_bounds: bool = True) -> Tensor: r"""Min-max normalize X w.r.t. the provided bounds. - NOTE: If the upper and lower bounds are identical for a dimension, that dimension - will not be scaled. 
Such dimensions will only be shifted as - `new_X[..., i] = X[..., i] - bounds[0, i]`. This avoids division by zero issues. - Args: X: `... x d` tensor of data bounds: `2 x d` tensor of lower and upper bounds for each of the X's d columns. + update_constant_bounds: If `True`, update the constant bounds in order to + avoid division by zero issues. When the upper and lower bounds are + identical for a dimension, that dimension will not be scaled. Such + dimensions will only be shifted as + `new_X[..., i] = X[..., i] - bounds[0, i]`. Returns: A `... x d`-dim tensor of normalized data, given by @@ -100,14 +101,16 @@ def unnormalize( ) -> Tensor: r"""Un-normalizes X w.r.t. the provided bounds. - NOTE: If the upper and lower bounds are identical for a dimension, that dimension - will not be scaled. Such dimensions will only be shifted as - `new_X[..., i] = X[..., i] + bounds[0, i]`, matching the behavior of `normalize`. - Args: X: `... x d` tensor of data bounds: `2 x d` tensor of lower and upper bounds for each of the X's d columns. + update_constant_bounds: If `True`, update the constant bounds in order to + avoid division by zero issues. When the upper and lower bounds are + identical for a dimension, that dimension will not be scaled. Such + dimensions will only be shifted as + `new_X[..., i] = X[..., i] + bounds[0, i]`. This is the inverse of + the behavior of `normalize` when `update_constant_bounds=True`. Returns: A `... x d`-dim tensor of unnormalized data, given by From 21bbc278335f0a10397f1245203edf26c1995cc2 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 2 Dec 2024 18:58:48 -0500 Subject: [PATCH 10/16] fix: assert warns rather than catch and check Co-authored-by: Elizabeth Santorella --- test/optim/test_initializers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index e9145eb59f..97879c2165 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -181,10 +181,10 @@ def test_initialize_q_batch_topn(self): self.assertEqual(ics.dtype, X.dtype) # ensure raises correct warning acq_vals = torch.zeros(5, device=self.device, dtype=dtype) - with warnings.catch_warnings(record=True) as w: + with self.assertWarns( + BadInitialCandidatesWarning, + ): ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2) - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) self.assertEqual(ics.shape, torch.Size([2, 3, 4])) with self.assertRaises(RuntimeError): initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=10) From 6e93eba27302759d77c18b4248727019214e9c58 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 2 Dec 2024 19:05:23 -0500 Subject: [PATCH 11/16] fix: nit limit scope of context managers --- test/optim/test_initializers.py | 332 ++++++++++++++++---------------- 1 file changed, 166 insertions(+), 166 deletions(-) diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index e9145eb59f..551a8bbb56 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -261,33 +261,33 @@ def test_gen_batch_initial_conditions(self): "sample_around_best": sample_around_best, }, ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - 
_get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_topn(self): bounds = torch.stack([torch.zeros(2), torch.ones(2)]) @@ -340,33 +340,33 @@ def test_gen_batch_initial_conditions_topn(self): fixed_features=ffs, options=options, ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_highdim(self): d = 2200 # 2200 
* 10 (q) > 21201 (sobol max dim) @@ -703,51 +703,51 @@ def test_gen_batch_initial_conditions_constraints(self): inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - inequality_constraints, - equality=False, - ), - 1e-6, - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - equality_constraints, - equality=True, - ), - 1e-6, - ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + inequality_constraints, + equality=False, + ), + 1e-6, + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + equality_constraints, + equality=True, + ), + 1e-6, + ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - self.assertTrue((raw_samps[..., 0] == 0.5).all()) - self.assertTrue((-4 * raw_samps[..., 1] >= -3).all()) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + self.assertTrue((raw_samps[..., 0] == 0.5).all()) + self.assertTrue((-4 * raw_samps[..., 1] >= -3).all()) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_interpoint_constraints(self): for dtype in (torch.float, torch.double): @@ -801,33 +801,33 @@ def test_gen_batch_initial_conditions_interpoint_constraints(self): inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, ) - expected_shape = torch.Size([2, 3, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - - self.assertTrue((batch_initial_conditions.sum(dim=-1) <= 1).all()) - - self.assertAllClose( - batch_initial_conditions[0, 0, 0], - batch_initial_conditions[0, 1, 0], - batch_initial_conditions[0, 2, 0], - atol=1e-7, - ) + expected_shape = torch.Size([2, 3, 2]) + self.assertEqual(batch_initial_conditions.shape, 
expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertAllClose( - batch_initial_conditions[1, 0, 0], - batch_initial_conditions[1, 1, 0], - batch_initial_conditions[1, 2, 0], - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - inequality_constraints, - equality=False, - ), - 1e-6, - ) + self.assertTrue((batch_initial_conditions.sum(dim=-1) <= 1).all()) + + self.assertAllClose( + batch_initial_conditions[0, 0, 0], + batch_initial_conditions[0, 1, 0], + batch_initial_conditions[0, 2, 0], + atol=1e-7, + ) + + self.assertAllClose( + batch_initial_conditions[1, 0, 0], + batch_initial_conditions[1, 1, 0], + batch_initial_conditions[1, 2, 0], + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + inequality_constraints, + equality=False, + ), + 1e-6, + ) def test_gen_batch_initial_conditions_generator(self): mock_acqf = MockAcquisitionFunction() @@ -880,20 +880,20 @@ def generator(n: int, q: int, seed: int | None): "init_batch_limit": init_batch_limit, }, ) - expected_shape = torch.Size([4, 2, 3]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertTrue((batch_initial_conditions[..., -1] == 0.42).all()) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + expected_shape = torch.Size([4, 2, 3]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertTrue((batch_initial_conditions[..., -1] == 0.42).all()) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_error_generator_with_sample_around_best(self): tkwargs = {"device": self.device, "dtype": torch.double} @@ -976,39 +976,39 @@ def test_gen_batch_initial_conditions_fixed_X_fantasies(self): }, fixed_X_fantasies=fixed_X_fantasies, ) - expected_shape = torch.Size([2, 4, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([4, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + expected_shape = torch.Size([2, 4, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + 
batch_shape = (
+                    torch.Size([])
+                    if init_batch_limit is None
+                    else torch.Size([init_batch_limit])
+                )
+                raw_samps = mock_acqf_call.call_args[0][0]
+                batch_shape = (
+                    torch.Size([20 if sample_around_best else 10])
+                    if init_batch_limit is None
+                    else torch.Size([init_batch_limit])
+                )
+                expected_raw_samps_shape = batch_shape + torch.Size([4, 2])
+                self.assertEqual(raw_samps.shape, expected_raw_samps_shape)

-                    if ffs is not None:
-                        for idx, val in ffs.items():
-                            self.assertTrue(
-                                torch.all(batch_initial_conditions[..., 0, idx] == val)
+                if ffs is not None:
+                    for idx, val in ffs.items():
+                        self.assertTrue(
+                            torch.all(batch_initial_conditions[..., 0, idx] == val)
                         )
-                    self.assertTrue(
-                        torch.equal(
-                            batch_initial_conditions[:, 1:],
-                            fixed_X_fantasies.unsqueeze(0).expand(2, 3, 2),
+                self.assertTrue(
+                    torch.equal(
+                        batch_initial_conditions[:, 1:],
+                        fixed_X_fantasies.unsqueeze(0).expand(2, 3, 2),
                     )
+                )
         # test wrong shape
         msg = (
             "`fixed_X_fantasies` and `bounds` must both have the same trailing"

From f364fe14592bfa1c36640949d4e0e88405fe06a2 Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Mon, 2 Dec 2024 19:11:02 -0500
Subject: [PATCH 12/16] doc: update the gen_batch_initial_conditions docstring

---
 botorch/optim/initializers.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py
index fbf975cedc..b8584eab99 100644
--- a/botorch/optim/initializers.py
+++ b/botorch/optim/initializers.py
@@ -271,13 +271,15 @@ def gen_batch_initial_conditions(
         fixed_features: A map `{feature_index: value}` for features that
             should be fixed to a particular value during generation.
         options: Options for initial condition generation. For valid options see
-            `initialize_q_batch` and `initialize_q_batch_nonneg`. If `options`
-            contains a `nonnegative=True` entry, then `acq_function` is
-            assumed to be non-negative (useful when using custom acquisition
-            functions). In addition, an "init_batch_limit" option can be passed
-            to specify the batch limit for the initialization. This is useful
-            for avoiding memory limits when computing the batch posterior over
-            raw samples.
+            `initialize_q_batch_topn`, `initialize_q_batch_nonneg`, and
+            `initialize_q_batch`. If `options` contains a `topn=True` entry, then
+            `initialize_q_batch_topn` will be used. Otherwise, if `options` contains
+            a `nonnegative=True` entry, then `acq_function` is assumed to be
+            non-negative (useful with custom acquisition functions) and
+            `initialize_q_batch_nonneg` will be used; `initialize_q_batch` is used
+            otherwise. In addition, an "init_batch_limit" option can be passed to
+            specify the batch limit for the initialization. This is useful for
+            avoiding memory limits when computing the batch posterior over raw samples.
         inequality constraints: A list of tuples (indices, coefficients, rhs),
             with each tuple encoding an inequality constraint of the form
             `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
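
For reference, the option routing described in the docstring above can be exercised end to end. The following is a minimal sketch, assuming the patches in this series are applied; the model, training data, and hyperparameters below are illustrative placeholders, not taken from the patches:

import torch
from botorch.acquisition import qUpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.optim.initializers import gen_batch_initial_conditions

# Toy model and acquisition function (illustrative only).
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
acqf = qUpperConfidenceBound(model, beta=0.1)
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)

# With `topn=True`, the `num_restarts` raw samples with the highest
# (`largest=True`) acquisition values are kept as initial conditions,
# instead of the Boltzmann-weighted random selection of `initialize_q_batch`.
ics = gen_batch_initial_conditions(
    acq_function=acqf,
    bounds=bounds,
    q=3,
    num_restarts=5,
    raw_samples=100,
    options={"topn": True, "largest": True, "sorted": True},
)
assert ics.shape == torch.Size([5, 3, 2])  # num_restarts x q x d
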
From 1e0828c9b0f65a751e76fd488c536dee7b90e24c Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 3 Dec 2024 11:30:44 -0500 Subject: [PATCH 13/16] test: reduce the number of tests --- test/optim/test_initializers.py | 494 ++++++++++++++++---------------- 1 file changed, 247 insertions(+), 247 deletions(-) diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 909c342321..902ecfc449 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -131,31 +131,36 @@ def test_initialize_q_batch_nonneg(self): self.assertEqual(ics.dtype, X.dtype) def test_initialize_q_batch(self): - for dtype in (torch.float, torch.double): - for batch_shape in (torch.Size(), [3, 2], (2,), torch.Size([2, 3, 4]), []): - # basic test - X = torch.rand(5, *batch_shape, 3, 4, device=self.device, dtype=dtype) - acq_vals = torch.rand(5, *batch_shape, device=self.device, dtype=dtype) - ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) - self.assertEqual(ics_X.shape, torch.Size([2, *batch_shape, 3, 4])) - self.assertEqual(ics_X.device, X.device) - self.assertEqual(ics_X.dtype, X.dtype) - self.assertEqual(ics_acq_vals.shape, torch.Size([2, *batch_shape])) - self.assertEqual(ics_acq_vals.device, acq_vals.device) - self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) - # ensure nothing happens if we want all samples - ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=5) - self.assertTrue(torch.equal(X, ics_X)) - self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) - # ensure raises correct warning - acq_vals = torch.zeros(5, device=self.device, dtype=dtype) - with warnings.catch_warnings(record=True) as w: - ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) - self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4])) - with self.assertRaises(RuntimeError): - initialize_q_batch(X=X, acq_vals=acq_vals, n=10) + for dtype, batch_shape in ( + (torch.float, torch.Size()), + (torch.double, [3, 2]), + (torch.float, (2,)), + (torch.double, torch.Size([2, 3, 4])), + (torch.float, []), + ): + # basic test + X = torch.rand(5, *batch_shape, 3, 4, device=self.device, dtype=dtype) + acq_vals = torch.rand(5, *batch_shape, device=self.device, dtype=dtype) + ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics_X.shape, torch.Size([2, *batch_shape, 3, 4])) + self.assertEqual(ics_X.device, X.device) + self.assertEqual(ics_X.dtype, X.dtype) + self.assertEqual(ics_acq_vals.shape, torch.Size([2, *batch_shape])) + self.assertEqual(ics_acq_vals.device, acq_vals.device) + self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) + # ensure nothing happens if we want all samples + ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=5) + self.assertTrue(torch.equal(X, ics_X)) + self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) + # ensure raises correct warning + acq_vals = torch.zeros(5, device=self.device, dtype=dtype) + with warnings.catch_warnings(record=True) as w: + ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) + self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4])) + with self.assertRaises(RuntimeError): + initialize_q_batch(X=X, acq_vals=acq_vals, n=10) def test_initialize_q_batch_topn(self): for dtype in (torch.float, torch.double): @@ -181,10 +186,10 @@ def 
test_initialize_q_batch_topn(self): self.assertEqual(ics.dtype, X.dtype) # ensure raises correct warning acq_vals = torch.zeros(5, device=self.device, dtype=dtype) - with self.assertWarns( - BadInitialCandidatesWarning, - ): + with warnings.catch_warnings(record=True) as w: ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) self.assertEqual(ics.shape, torch.Size([2, 3, 4])) with self.assertRaises(RuntimeError): initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=10) @@ -228,13 +233,10 @@ def test_gen_batch_initial_conditions(self): init_batch_limit, ffs, sample_around_best, - ) in product( - (torch.float, torch.double), - [True, False], - [None, 1234], - [None, 1], - [None, {0: 0.5}], - [True, False], + ) in ( + (torch.float, True, None, None, None, True), + (torch.double, False, 1234, 1, {0: 0.5}, False), + (torch.double, True, 1234, None, {0: 0.5}, True), ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best @@ -261,33 +263,33 @@ def test_gen_batch_initial_conditions(self): "sample_around_best": sample_around_best, }, ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_topn(self): bounds = torch.stack([torch.zeros(2), torch.ones(2)]) @@ -303,15 +305,15 @@ def test_gen_batch_initial_conditions_topn(self): init_batch_limit, ffs, sample_around_best, - ) in product( - [torch.float, torch.double], - [True, False], - [True, False, None], - [True, False], - [None, 1234], - [None, 1], - [None, {0: 0.5}], - [True, False], + ) in ( + (torch.float, True, True, True, None, None, None, True), + (torch.double, False, False, False, 1234, 1, {0: 0.5}, False), + (torch.float, True, None, True, 1234, None, 
None, False), + (torch.double, False, True, False, None, 1, {0: 0.5}, True), + (torch.float, True, False, False, 1234, None, {0: 0.5}, True), + (torch.double, False, None, True, None, 1, None, False), + (torch.float, True, True, False, 1234, 1, {0: 0.5}, True), + (torch.double, False, False, True, None, None, None, False), ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best @@ -340,33 +342,33 @@ def test_gen_batch_initial_conditions_topn(self): fixed_features=ffs, options=options, ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_highdim(self): d = 2200 # 2200 * 10 (q) > 21201 (sobol max dim) @@ -374,48 +376,46 @@ def test_gen_batch_initial_conditions_highdim(self): ffs_map = {i: random() for i in range(0, d, 2)} mock_acqf = MockAcquisitionFunction() mock_acqf.objective = lambda y: y.squeeze(-1) - for dtype in (torch.float, torch.double): + for dtype, nonnegative, seed, ffs, sample_around_best in ( + (torch.float, True, None, None, True), + (torch.double, False, 1234, ffs_map, False), + (torch.double, True, 1234, ffs_map, True), + ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) - - for nonnegative, seed, ffs, sample_around_best in product( - [True, False], [None, 1234], [None, ffs_map], [True, False] - ): - with warnings.catch_warnings(record=True) as ws: - warnings.simplefilter( - "ignore", category=BadInitialCandidatesWarning - ) - batch_initial_conditions = gen_batch_initial_conditions( - acq_function=MockAcquisitionFunction(), - bounds=bounds, - q=10, - num_restarts=1, - raw_samples=2, - fixed_features=ffs, - options={ - "nonnegative": nonnegative, - "eta": 0.01, - 
"alpha": 0.1, - "seed": seed, - "sample_around_best": sample_around_best, - }, - ) + with warnings.catch_warnings(record=True) as ws: + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=MockAcquisitionFunction(), + bounds=bounds, + q=10, + num_restarts=1, + raw_samples=2, + fixed_features=ffs, + options={ + "nonnegative": nonnegative, + "eta": 0.01, + "alpha": 0.1, + "seed": seed, + "sample_around_best": sample_around_best, + }, + ) + self.assertTrue( + any(issubclass(w.category, SamplingWarning) for w in ws) + ) + expected_shape = torch.Size([1, 10, d]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), 1e-6 + ) + if ffs is not None: + for idx, val in ffs.items(): self.assertTrue( - any(issubclass(w.category, SamplingWarning) for w in ws) + torch.all(batch_initial_conditions[..., idx] == val) ) - expected_shape = torch.Size([1, 10, d]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), 1e-6 - ) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) def test_gen_batch_initial_conditions_warning(self) -> None: for dtype in (torch.float, torch.double): @@ -703,51 +703,51 @@ def test_gen_batch_initial_conditions_constraints(self): inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - inequality_constraints, - equality=False, - ), - 1e-6, - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - equality_constraints, - equality=True, - ), - 1e-6, - ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + inequality_constraints, + equality=False, + ), + 1e-6, + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + equality_constraints, + equality=True, + ), + 1e-6, + ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - self.assertTrue((raw_samps[..., 0] == 
0.5).all()) - self.assertTrue((-4 * raw_samps[..., 1] >= -3).all()) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + self.assertTrue((raw_samps[..., 0] == 0.5).all()) + self.assertTrue((-4 * raw_samps[..., 1] >= -3).all()) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_interpoint_constraints(self): for dtype in (torch.float, torch.double): @@ -801,33 +801,33 @@ def test_gen_batch_initial_conditions_interpoint_constraints(self): inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, ) - expected_shape = torch.Size([2, 3, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - - self.assertTrue((batch_initial_conditions.sum(dim=-1) <= 1).all()) - - self.assertAllClose( - batch_initial_conditions[0, 0, 0], - batch_initial_conditions[0, 1, 0], - batch_initial_conditions[0, 2, 0], - atol=1e-7, - ) + expected_shape = torch.Size([2, 3, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + + self.assertTrue((batch_initial_conditions.sum(dim=-1) <= 1).all()) + + self.assertAllClose( + batch_initial_conditions[0, 0, 0], + batch_initial_conditions[0, 1, 0], + batch_initial_conditions[0, 2, 0], + atol=1e-7, + ) - self.assertAllClose( - batch_initial_conditions[1, 0, 0], - batch_initial_conditions[1, 1, 0], - batch_initial_conditions[1, 2, 0], - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - inequality_constraints, - equality=False, - ), - 1e-6, - ) + self.assertAllClose( + batch_initial_conditions[1, 0, 0], + batch_initial_conditions[1, 1, 0], + batch_initial_conditions[1, 2, 0], + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + inequality_constraints, + equality=False, + ), + 1e-6, + ) def test_gen_batch_initial_conditions_generator(self): mock_acqf = MockAcquisitionFunction() @@ -880,20 +880,20 @@ def generator(n: int, q: int, seed: int | None): "init_batch_limit": init_batch_limit, }, ) - expected_shape = torch.Size([4, 2, 3]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertTrue((batch_initial_conditions[..., -1] == 0.42).all()) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + expected_shape = torch.Size([4, 2, 3]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + 
self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertTrue((batch_initial_conditions[..., -1] == 0.42).all()) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_error_generator_with_sample_around_best(self): tkwargs = {"device": self.device, "dtype": torch.double} @@ -976,39 +976,39 @@ def test_gen_batch_initial_conditions_fixed_X_fantasies(self): }, fixed_X_fantasies=fixed_X_fantasies, ) - expected_shape = torch.Size([2, 4, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([4, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + expected_shape = torch.Size([2, 4, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([4, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., 0, idx] == val) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., 0, idx] == val) + ) + self.assertTrue( + torch.equal( + batch_initial_conditions[:, 1:], + fixed_X_fantasies.unsqueeze(0).expand(2, 3, 2), ) - self.assertTrue( - torch.equal( - batch_initial_conditions[:, 1:], - fixed_X_fantasies.unsqueeze(0).expand(2, 3, 2), ) - ) # test wrong shape msg = ( "`fixed_X_fantasies` and `bounds` must both have the same trailing" From 1ddc929a8c60b7d3b131da49df3a500c6d341ea9 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 3 Dec 2024 11:34:58 -0500 Subject: [PATCH 14/16] revert: redo the changes to reduce context manager scope --- test/optim/test_initializers.py | 333 ++++++++++++++++---------------- 1 file changed, 167 insertions(+), 166 deletions(-) diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 902ecfc449..16fe1c675a 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -263,33 +263,33 @@ def test_gen_batch_initial_conditions(self): "sample_around_best": sample_around_best, }, ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - 
self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_topn(self): bounds = torch.stack([torch.zeros(2), torch.ones(2)]) @@ -342,33 +342,33 @@ def test_gen_batch_initial_conditions_topn(self): fixed_features=ffs, options=options, ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., 
idx] == val) + ) def test_gen_batch_initial_conditions_highdim(self): d = 2200 # 2200 * 10 (q) > 21201 (sobol max dim) @@ -703,51 +703,51 @@ def test_gen_batch_initial_conditions_constraints(self): inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - inequality_constraints, - equality=False, - ), - 1e-6, - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - equality_constraints, - equality=True, - ), - 1e-6, - ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + inequality_constraints, + equality=False, + ), + 1e-6, + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + equality_constraints, + equality=True, + ), + 1e-6, + ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - self.assertTrue((raw_samps[..., 0] == 0.5).all()) - self.assertTrue((-4 * raw_samps[..., 1] >= -3).all()) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + self.assertTrue((raw_samps[..., 0] == 0.5).all()) + self.assertTrue((-4 * raw_samps[..., 1] >= -3).all()) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_interpoint_constraints(self): for dtype in (torch.float, torch.double): @@ -801,33 +801,33 @@ def test_gen_batch_initial_conditions_interpoint_constraints(self): inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, ) - expected_shape = torch.Size([2, 3, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - - self.assertTrue((batch_initial_conditions.sum(dim=-1) <= 1).all()) - - self.assertAllClose( - batch_initial_conditions[0, 0, 0], - batch_initial_conditions[0, 1, 0], - batch_initial_conditions[0, 2, 0], - atol=1e-7, - ) + 
expected_shape = torch.Size([2, 3, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertAllClose( - batch_initial_conditions[1, 0, 0], - batch_initial_conditions[1, 1, 0], - batch_initial_conditions[1, 2, 0], - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - inequality_constraints, - equality=False, - ), - 1e-6, - ) + self.assertTrue((batch_initial_conditions.sum(dim=-1) <= 1).all()) + + self.assertAllClose( + batch_initial_conditions[0, 0, 0], + batch_initial_conditions[0, 1, 0], + batch_initial_conditions[0, 2, 0], + atol=1e-7, + ) + + self.assertAllClose( + batch_initial_conditions[1, 0, 0], + batch_initial_conditions[1, 1, 0], + batch_initial_conditions[1, 2, 0], + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + inequality_constraints, + equality=False, + ), + 1e-6, + ) def test_gen_batch_initial_conditions_generator(self): mock_acqf = MockAcquisitionFunction() @@ -880,20 +880,20 @@ def generator(n: int, q: int, seed: int | None): "init_batch_limit": init_batch_limit, }, ) - expected_shape = torch.Size([4, 2, 3]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertTrue((batch_initial_conditions[..., -1] == 0.42).all()) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + expected_shape = torch.Size([4, 2, 3]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertTrue((batch_initial_conditions[..., -1] == 0.42).all()) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_error_generator_with_sample_around_best(self): tkwargs = {"device": self.device, "dtype": torch.double} @@ -976,39 +976,40 @@ def test_gen_batch_initial_conditions_fixed_X_fantasies(self): }, fixed_X_fantasies=fixed_X_fantasies, ) - expected_shape = torch.Size([2, 4, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([4, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + expected_shape = torch.Size([2, 4, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + 
self.assertLess(
+                    _get_max_violation_of_bounds(batch_initial_conditions, bounds),
+                    1e-6,
+                )
+                batch_shape = (
+                    torch.Size([])
+                    if init_batch_limit is None
+                    else torch.Size([init_batch_limit])
+                )
+                raw_samps = mock_acqf_call.call_args[0][0]
+                batch_shape = (
+                    torch.Size([20 if sample_around_best else 10])
+                    if init_batch_limit is None
+                    else torch.Size([init_batch_limit])
+                )
+                expected_raw_samps_shape = batch_shape + torch.Size([4, 2])
+                self.assertEqual(raw_samps.shape, expected_raw_samps_shape)
-            if ffs is not None:
-                for idx, val in ffs.items():
-                    self.assertTrue(
-                        torch.all(batch_initial_conditions[..., 0, idx] == val)
+                if ffs is not None:
+                    for idx, val in ffs.items():
+                        self.assertTrue(
+                            torch.all(batch_initial_conditions[..., 0, idx] == val)
                         )
+                self.assertTrue(
+                    torch.equal(
+                        batch_initial_conditions[:, 1:],
+                        fixed_X_fantasies.unsqueeze(0).expand(2, 3, 2),
                     )
+                )
+
             # test wrong shape
             msg = (
                 "`fixed_X_fantasies` and `bounds` must both have the same trailing"

From 61f6ffb3caa49ccf04ab5af993364af6073e1d6e Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Tue, 3 Dec 2024 11:47:21 -0500
Subject: [PATCH 15/16] nit: change to assertWarns

---
 test/optim/test_initializers.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py
index 16fe1c675a..65c3e2b6bb 100644
--- a/test/optim/test_initializers.py
+++ b/test/optim/test_initializers.py
@@ -110,10 +110,8 @@ def test_initialize_q_batch_nonneg(self):
             self.assertEqual(ics.dtype, X.dtype)
             # ensure raises correct warning
             acq_vals = torch.zeros(5, device=self.device, dtype=dtype)
-            with warnings.catch_warnings(record=True) as w:
+            with self.assertWarns(BadInitialCandidatesWarning):
                 ics, _ = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2)
-            self.assertEqual(len(w), 1)
-            self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
             self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
             with self.assertRaises(RuntimeError):
                 initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=10)
@@ -154,10 +152,8 @@ def test_initialize_q_batch(self):
             self.assertTrue(torch.equal(acq_vals, ics_acq_vals))
             # ensure raises correct warning
             acq_vals = torch.zeros(5, device=self.device, dtype=dtype)
-            with warnings.catch_warnings(record=True) as w:
+            with self.assertWarns(BadInitialCandidatesWarning):
                 ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2)
-            self.assertEqual(len(w), 1)
-            self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
             self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4]))
             with self.assertRaises(RuntimeError):
                 initialize_q_batch(X=X, acq_vals=acq_vals, n=10)
@@ -186,10 +182,8 @@ def test_initialize_q_batch_topn(self):
             self.assertEqual(ics.dtype, X.dtype)
             # ensure raises correct warning
             acq_vals = torch.zeros(5, device=self.device, dtype=dtype)
-            with warnings.catch_warnings(record=True) as w:
+            with self.assertWarns(BadInitialCandidatesWarning):
                 ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
-            self.assertEqual(len(w), 1)
-            self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
             self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
             with self.assertRaises(RuntimeError):
                 initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=10)

From 975bb293e54d77c1d0e656d67c2073a6e476863b Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Thu, 5 Dec 2024 20:30:16 -0500
Subject: [PATCH 16/16] Update botorch/optim/initializers.py

Co-authored-by: Sait Cakmak
---
 botorch/optim/initializers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py
index b8584eab99..753ca86124 100644
--- a/botorch/optim/initializers.py
+++ b/botorch/optim/initializers.py
@@ -376,7 +376,7 @@ def gen_batch_initial_conditions(
                     n, q, bounds_cpu.shape[-1], dtype=bounds.dtype
                 )
                 X_rnd = unnormalize(
-                    X_rnd_nlzd, bounds, update_constant_bounds=False
+                    X_rnd_nlzd, bounds_cpu, update_constant_bounds=False
                 )
             else:
                 X_rnd = sample_q_batches_from_polytope(
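Usage note (illustrative, not part of the patch series): the tests above exercise the
new top-n initialization path end to end through `gen_batch_initial_conditions`. The
sketch below shows how a caller would opt into it. It assumes the option spellings
that the tests pass through ("topn", with "largest" and "sorted" mirroring
`torch.topk`, and "seed" for reproducibility); the model and acquisition function are
placeholders chosen only to make the snippet runnable.

    import torch
    from botorch.acquisition import qUpperConfidenceBound
    from botorch.models import SingleTaskGP
    from botorch.optim.initializers import gen_batch_initial_conditions

    # Placeholder data/model: a 2-d toy problem with a SingleTaskGP surrogate.
    train_X = torch.rand(20, 2, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)
    acqf = qUpperConfidenceBound(model, beta=0.1)

    bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
    ics = gen_batch_initial_conditions(
        acq_function=acqf,
        bounds=bounds,
        q=3,
        num_restarts=10,
        raw_samples=256,
        # "topn" routes initialization through the deterministic top-n
        # selector instead of the default probabilistic (softmax-weighted)
        # selection; "largest" and "sorted" are forwarded to torch.topk.
        options={"topn": True, "largest": True, "sorted": True, "seed": 0},
    )
    print(ics.shape)  # torch.Size([10, 3, 2]) == num_restarts x q x d

Relative to the default stochastic selection, the top-n selector is deterministic
given the raw samples, which makes restarts reproducible for a fixed seed but can
reduce the diversity of the starting points.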