diff --git a/botorch/acquisition/input_constructors.py b/botorch/acquisition/input_constructors.py index deb8433321..2de8febc10 100644 --- a/botorch/acquisition/input_constructors.py +++ b/botorch/acquisition/input_constructors.py @@ -1779,7 +1779,7 @@ def optimize_objective( bounds=free_feature_bounds, q=q, num_restarts=optimizer_options.get("num_restarts", 60), - raw_samples=optimizer_options.get("raw_samples", 1024), + raw_samples=optimizer_options.get("raw_samples", 1024), # NOTE potential behaviour change options={ "batch_limit": optimizer_options.get("batch_limit", 8), "maxiter": optimizer_options.get("maxiter", 200), diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py index 43b2c634d2..e6fdf2e099 100644 --- a/botorch/optim/optimize.py +++ b/botorch/optim/optimize.py @@ -114,6 +114,19 @@ def __post_init__(self) -> None: f"shape is {batch_initial_conditions_shape}." ) + if ( + self.raw_samples is not None + and (self.raw_samples - batch_initial_conditions_shape[-2]) > 0 + and len(batch_initial_conditions_shape) == 3 + and self.num_restarts is not None + and self.num_restarts != batch_initial_conditions_shape[0] + ): + raise ValueError( + "If using `batch_initial_conditions` together with `raw_samples`, " + "the first repeat dimension of `batch_initial_conditions` must " + "match `num_restarts`." + ) + elif self.ic_generator is None: if self.nonlinear_inequality_constraints is not None: raise RuntimeError( @@ -253,22 +266,44 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor initial_conditions_provided = opt_inputs.batch_initial_conditions is not None + required_raw_samples = opt_inputs.raw_samples if initial_conditions_provided: - batch_initial_conditions = opt_inputs.batch_initial_conditions + provided_initial_conditions = opt_inputs.batch_initial_conditions + if opt_inputs.raw_samples is not None: + required_raw_samples -= provided_initial_conditions.shape[-2] else: + provided_initial_conditions = None + + if required_raw_samples is not None and required_raw_samples > 0: # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call. - batch_initial_conditions = opt_inputs.get_ic_generator()( + generated_initial_conditions = opt_inputs.get_ic_generator()( acq_function=opt_inputs.acq_function, bounds=opt_inputs.bounds, q=opt_inputs.q, num_restarts=opt_inputs.num_restarts, - raw_samples=opt_inputs.raw_samples, + raw_samples=required_raw_samples, fixed_features=opt_inputs.fixed_features, options=options, inequality_constraints=opt_inputs.inequality_constraints, equality_constraints=opt_inputs.equality_constraints, **opt_inputs.ic_gen_kwargs, ) + else: + generated_initial_conditions = None + + if provided_initial_conditions is not None and generated_initial_conditions is not None: + provided_initial_conditions = provided_initial_conditions.repeat( + opt_inputs.num_restarts, *([1] * (provided_initial_conditions.dim()-1)) + ) + batch_initial_conditions = torch.cat( + [provided_initial_conditions, generated_initial_conditions], dim=-2 + ) # should this be shuffled? + elif provided_initial_conditions is not None: + batch_initial_conditions = provided_initial_conditions + elif generated_initial_conditions is not None: + batch_initial_conditions = generated_initial_conditions + else: + raise ValueError("Either `batch_initial_conditions` or `raw_samples` must be set.") batch_limit: int = options.get( "batch_limit", @@ -339,24 +374,25 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]: first_warn_msg = ( "Optimization failed in `gen_candidates_scipy` with the following " f"warning(s):\n{[w.message for w in ws]}\nBecause you specified " - "`batch_initial_conditions`, optimization will not be retried with " - "new initial conditions and will proceed with the current solution." - " Suggested remediation: Try again with different " - "`batch_initial_conditions`, or don't provide `batch_initial_conditions.`" - if initial_conditions_provided + "`batch_initial_conditions`>`raw_samples`, optimization will not " + "be retried with new initial conditions and will proceed with the " + "current solution. Suggested remediation: Try again with different " + "`batch_initial_conditions`, don't provide `batch_initial_conditions, " + "or increase `raw_samples`.`" + if required_raw_samples is not None and required_raw_samples > 0 else "Optimization failed in `gen_candidates_scipy` with the following " f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new " "set of initial conditions." ) warnings.warn(first_warn_msg, RuntimeWarning, stacklevel=2) - if not initial_conditions_provided: - batch_initial_conditions = opt_inputs.get_ic_generator()( + if required_raw_samples is not None and required_raw_samples > 0: + generated_initial_conditions = opt_inputs.get_ic_generator()( acq_function=opt_inputs.acq_function, bounds=opt_inputs.bounds, q=opt_inputs.q, num_restarts=opt_inputs.num_restarts, - raw_samples=opt_inputs.raw_samples, + raw_samples=required_raw_samples, fixed_features=opt_inputs.fixed_features, options=options, inequality_constraints=opt_inputs.inequality_constraints, @@ -364,6 +400,13 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]: **opt_inputs.ic_gen_kwargs, ) + if provided_initial_conditions is not None: + batch_initial_conditions = torch.cat( + [provided_initial_conditions, generated_initial_conditions], dim=-2 + ) # should this be shuffled? + else: + batch_initial_conditions = generated_initial_conditions + batch_candidates, batch_acq_values, ws = _optimize_batch_candidates() optimization_warning_raised = any( @@ -1199,11 +1242,46 @@ def optimize_acqf_discrete_local_search( inequality_constraints = inequality_constraints or [] for i in range(q): # generate some starting points - if i == 0 and batch_initial_conditions is not None: - X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid) - X0 = _filter_infeasible( - X=X0, inequality_constraints=inequality_constraints - ).unsqueeze(1) + if i == 0: + + if batch_initial_conditions is not None: + provided_X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid) + provided_X0 = _filter_infeasible( + X=provided_X0, inequality_constraints=inequality_constraints + ).unsqueeze(1) + if raw_samples is not None: + required_raw_samples = raw_samples - batch_initial_conditions.shape[-2] + else: + required_raw_samples = raw_samples + provided_X0 = None + + if required_raw_samples > 0: + X_init = _gen_batch_initial_conditions_local_search( + discrete_choices=discrete_choices, + raw_samples=required_raw_samples, + X_avoid=X_avoid, + inequality_constraints=inequality_constraints, + min_points=num_restarts, + ) + # pick the best starting points + with torch.no_grad(): + acqvals_init = _split_batch_eval_acqf( + acq_function=acq_function, + X=X_init.unsqueeze(1), + max_batch_size=max_batch_size, + ).unsqueeze(-1) + generated_X0 = X_init[acqvals_init.topk(k=num_restarts, largest=True, dim=0).indices] + + if provided_X0 is not None and generated_X0 is not None: + provided_X0 = provided_X0.repeat(num_restarts, *([1] * (provided_X0.ndim - 1))) + X0 = torch.cat([provided_X0, generated_X0], dim=-2) + elif provided_X0 is not None: + X0 = provided_X0 + elif generated_X0 is not None: + X0 = generated_X0 + else: + raise ValueError("Either `batch_initial_conditions` or `raw_samples` must be set.") + else: X_init = _gen_batch_initial_conditions_local_search( discrete_choices=discrete_choices, diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index 23bc04d2b5..abb0e6e13c 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -167,7 +167,26 @@ def test_optimize_acqf_joint( cnt += 1 self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt) - # test generation with provided initial conditions + # test generation with provided initial conditions less than raw_samples + candidates, acq_vals = optimize_acqf( + acq_function=mock_acq_function, + bounds=bounds, + q=q, + num_restarts=num_restarts, + raw_samples=3, + options=options, + return_best_only=False, + batch_initial_conditions=torch.zeros( + num_restarts, q, 3, device=self.device, dtype=dtype + ), + gen_candidates=mock_gen_candidates, + ) + self.assertTrue(torch.equal(candidates, mock_candidates)) + self.assertTrue(torch.equal(acq_vals, mock_acq_values)) + cnt += 1 + self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt) + + # test generation with provided initial conditions greater than raw_samples candidates, acq_vals = optimize_acqf( acq_function=mock_acq_function, bounds=bounds, @@ -543,7 +562,15 @@ def test_optimize_acqf_batch_limit(self) -> None: gen_candidates=gen_candidates, batch_initial_conditions=ics, ) - expected_shape = (num_restarts,) if ics is None else (ics.shape[0],) + expected_shape = ( + (num_restarts,) + if ics is None + else ( + (ics.shape[0],) + if ics.shape[0] > raw_samples + else (ics.shape[0]*num_restarts,) + ) + ) self.assertEqual(acq_value_list.shape, expected_shape) def test_optimize_acqf_runs_given_batch_initial_conditions(self): @@ -635,11 +662,12 @@ def test_optimize_acqf_warns_on_opt_failure(self): "Optimization failed in `gen_candidates_scipy` with the following " "warning(s):\n[OptimizationWarning('Optimization failed within " "`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION" - "_IN_LNSRCH.')]\nBecause you specified `batch_initial_conditions`, " - "optimization will not be retried with new initial conditions and will " - "proceed with the current solution. Suggested remediation: Try again with " - "different `batch_initial_conditions`, or don't provide " - "`batch_initial_conditions.`" + "_IN_LNSRCH.')]\nBecause you specified " + "`batch_initial_conditions`>`raw_samples`, optimization will not " + "be retried with new initial conditions and will proceed with the " + "current solution. Suggested remediation: Try again with different " + "`batch_initial_conditions`, don't provide `batch_initial_conditions, " + "or increase `raw_samples`.`" ) expected_warning_raised = any( issubclass(w.category, RuntimeWarning) and message in str(w.message) @@ -1841,3 +1869,7 @@ def my_gen(): ) ic_generator = opt_inputs.get_ic_generator() self.assertIs(ic_generator, my_gen) + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) \ No newline at end of file