rebase: correct the merge issues with rebase
CompRhys committed Nov 18, 2024
1 parent 92d73e4 commit 4b1196b
Showing 4 changed files with 356 additions and 80 deletions.
244 changes: 195 additions & 49 deletions botorch/optim/optimize.py
@@ -109,12 +109,44 @@ def __post_init__(self) -> None:
"3-dimensional. Its shape is "
f"{batch_initial_conditions_shape}."
)

if batch_initial_conditions_shape[-1] != d:
raise ValueError(
f"batch_initial_conditions.shape[-1] must be {d}. The "
f"shape is {batch_initial_conditions_shape}."
)

if (
self.raw_samples is not None
and (self.raw_samples - batch_initial_conditions_shape[-2]) > 0
):
if len(batch_initial_conditions_shape) == 2:
warnings.warn(
"If using `batch_initial_conditions` together with "
"`raw_samples`, the `batch_initial_conditions` must "
"have 3 dimensions. Defaulting to old behavior of ignoring "
"`raw_samples` by setting it to None.",
RuntimeWarning,
)
# Use object.__setattr__ to bypass immutability and set a value
object.__setattr__(self, "raw_samples", None)

elif (
len(batch_initial_conditions_shape) == 3 # should always be true
and self.num_restarts is not None
and batch_initial_conditions_shape[0] not in [1, self.num_restarts]
):
warnings.warn(
"If using `batch_initial_conditions` together with "
"`raw_samples`, the first repeat dimension of "
"`batch_initial_conditions` must match `num_restarts` "
"or be 1 to allow repeat matching. Defaulting to old "
"behavior of ignoring `raw_samples` by setting it to None.",
RuntimeWarning,
)
# Use object.__setattr__ to bypass immutability and set a value
object.__setattr__(self, "raw_samples", None)

elif self.ic_generator is None:
if self.nonlinear_inequality_constraints is not None:
raise RuntimeError(
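For reference, a minimal sketch (hypothetical shapes, not part of this commit) of what the new `__post_init__` checks accept or warn about, assuming `raw_samples` exceeds the number of provided points:

```python
import torch

num_restarts, q, d = 4, 2, 3  # hypothetical sizes

torch.rand(1, q, d)             # ok: leading dim 1 is broadcast to num_restarts
torch.rand(num_restarts, q, d)  # ok: leading dim matches num_restarts
torch.rand(q, d)                # 2-d: warns and falls back to raw_samples=None
torch.rand(3, q, d)             # 3 not in {1, num_restarts}: warns, same fallback
```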
@@ -248,32 +280,78 @@ def _optimize_acqf_sequential_q(
if base_X_pending is not None
else candidates
)
logger.info(f"Generated sequential candidate {i+1} of {opt_inputs.q}")
logger.info(f"Generated sequential candidate {i + 1} of {opt_inputs.q}")
opt_inputs.acq_function.set_X_pending(base_X_pending)
return candidates, torch.stack(acq_value_list)


def _combine_initial_conditions(
provided_initial_conditions: Tensor | None = None,
generated_initial_conditions: Tensor | None = None,
num_restarts: int | None = None,
) -> Tensor:
if (
provided_initial_conditions is not None
and generated_initial_conditions is not None
):
if ( # Repeat the provided initial conditions to match the number of restarts
provided_initial_conditions.shape[0] == 1
and num_restarts is not None
and num_restarts > 1
):
provided_initial_conditions = provided_initial_conditions.repeat(
num_restarts, *([1] * (provided_initial_conditions.dim() - 1))
)
initial_conditions = torch.cat(
[provided_initial_conditions, generated_initial_conditions], dim=-2
)
perm = torch.randperm(
initial_conditions.shape[-2], device=initial_conditions.device
)
return initial_conditions.gather(
-2, perm.unsqueeze(-1).expand_as(initial_conditions)
)
elif provided_initial_conditions is not None:
return provided_initial_conditions
elif generated_initial_conditions is not None:
return generated_initial_conditions
else:
raise ValueError(
"Either `batch_initial_conditions` or `raw_samples` must be set."
)
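A usage sketch for the new helper (hypothetical sizes; assumes the function above is importable from `botorch.optim.optimize`): a single provided restart is broadcast across `num_restarts`, pooled with the generated conditions along dim=-2, and shuffled along that dimension.

```python
import torch
from botorch.optim.optimize import _combine_initial_conditions

provided = torch.rand(1, 2, 3)   # 1 x n x d
generated = torch.rand(4, 6, 3)  # num_restarts x m x d
combined = _combine_initial_conditions(
    provided_initial_conditions=provided,
    generated_initial_conditions=generated,
    num_restarts=4,
)
assert combined.shape == (4, 8, 3)  # points pooled and shuffled along dim -2
```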


def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]:
options = opt_inputs.options or {}

initial_conditions_provided = opt_inputs.batch_initial_conditions is not None
required_raw_samples = opt_inputs.raw_samples
generated_initial_conditions = None

if initial_conditions_provided:
batch_initial_conditions = opt_inputs.batch_initial_conditions
else:
# pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
batch_initial_conditions = opt_inputs.get_ic_generator()(
acq_function=opt_inputs.acq_function,
bounds=opt_inputs.bounds,
q=opt_inputs.q,
num_restarts=opt_inputs.num_restarts,
raw_samples=opt_inputs.raw_samples,
fixed_features=opt_inputs.fixed_features,
options=options,
inequality_constraints=opt_inputs.inequality_constraints,
equality_constraints=opt_inputs.equality_constraints,
**opt_inputs.ic_gen_kwargs,
)
if required_raw_samples is not None:
if opt_inputs.batch_initial_conditions is not None:
required_raw_samples -= opt_inputs.batch_initial_conditions.shape[-2]

if required_raw_samples > 0:
# pyre-ignore[28]: Unexpected keyword argument `acq_function`
# to anonymous call.
generated_initial_conditions = opt_inputs.get_ic_generator()(
acq_function=opt_inputs.acq_function,
bounds=opt_inputs.bounds,
q=opt_inputs.q,
num_restarts=opt_inputs.num_restarts,
raw_samples=required_raw_samples,
fixed_features=opt_inputs.fixed_features,
options=options,
inequality_constraints=opt_inputs.inequality_constraints,
equality_constraints=opt_inputs.equality_constraints,
**opt_inputs.ic_gen_kwargs,
)

batch_initial_conditions = _combine_initial_conditions(
provided_initial_conditions=opt_inputs.batch_initial_conditions,
generated_initial_conditions=generated_initial_conditions,
num_restarts=opt_inputs.num_restarts,
)
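A worked example of the bookkeeping above (hypothetical numbers):

```python
raw_samples, provided_points = 16, 6  # provided_points = shape[-2]
required_raw_samples = raw_samples - provided_points  # 10 > 0 -> generate 10
# With provided_points >= 16, generation is skipped; with raw_samples=None,
# the provided conditions are used as-is.
```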

batch_limit: int = options.get(
"batch_limit",
@@ -325,7 +403,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
opt_warnings += ws
batch_candidates_list.append(batch_candidates_curr)
batch_acq_values_list.append(batch_acq_values_curr)
logger.info(f"Generated candidate batch {i+1} of {len(batched_ics)}.")
logger.info(f"Generated candidate batch {i + 1} of {len(batched_ics)}.")

batch_candidates = torch.cat(batch_candidates_list)
has_scalars = batch_acq_values_list[0].ndim == 0
@@ -344,31 +422,38 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
first_warn_msg = (
"Optimization failed in `gen_candidates_scipy` with the following "
f"warning(s):\n{[w.message for w in ws]}\nBecause you specified "
"`batch_initial_conditions`, optimization will not be retried with "
"new initial conditions and will proceed with the current solution."
" Suggested remediation: Try again with different "
"`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
if initial_conditions_provided
"`batch_initial_conditions`>`raw_samples`, optimization will not "
"be retried with new initial conditions and will proceed with the "
"current solution. Suggested remediation: Try again with different "
"`batch_initial_conditions`, don't provide `batch_initial_conditions`, "
"or increase `raw_samples`.`"
if required_raw_samples is not None and required_raw_samples <= 0
else "Optimization failed in `gen_candidates_scipy` with the following "
f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new "
"set of initial conditions."
)
warnings.warn(first_warn_msg, RuntimeWarning, stacklevel=2)

if not initial_conditions_provided:
batch_initial_conditions = opt_inputs.get_ic_generator()(
if required_raw_samples is not None and required_raw_samples > 0:
generated_initial_conditions = opt_inputs.get_ic_generator()(
acq_function=opt_inputs.acq_function,
bounds=opt_inputs.bounds,
q=opt_inputs.q,
num_restarts=opt_inputs.num_restarts,
raw_samples=opt_inputs.raw_samples,
raw_samples=required_raw_samples,
fixed_features=opt_inputs.fixed_features,
options=options,
inequality_constraints=opt_inputs.inequality_constraints,
equality_constraints=opt_inputs.equality_constraints,
**opt_inputs.ic_gen_kwargs,
)

batch_initial_conditions = _combine_initial_conditions(
provided_initial_conditions=opt_inputs.batch_initial_conditions,
generated_initial_conditions=generated_initial_conditions,
num_restarts=opt_inputs.num_restarts,
)

batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()

optimization_warning_raised = any(
@@ -1177,7 +1262,7 @@ def _gen_batch_initial_conditions_local_search(
inequality_constraints: list[tuple[Tensor, Tensor, float]],
min_points: int,
max_tries: int = 100,
):
) -> Tensor:
"""Generate initial conditions for local search."""
device = discrete_choices[0].device
dtype = discrete_choices[0].dtype
@@ -1197,6 +1282,63 @@ def _gen_batch_initial_conditions_local_search(
raise RuntimeError(f"Failed to generate at least {min_points} initial conditions")


def _gen_starting_points_local_search(
discrete_choices: list[Tensor],
raw_samples: int,
batch_initial_conditions: Tensor,
X_avoid: Tensor,
inequality_constraints: list[tuple[Tensor, Tensor, float]],
min_points: int,
acq_function: AcquisitionFunction,
max_batch_size: int = 2048,
max_tries: int = 100,
) -> Tensor:
required_min_points = min_points
provided_X0 = None
generated_X0 = None

if batch_initial_conditions is not None:
provided_X0 = _filter_invalid(
X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid
)
provided_X0 = _filter_infeasible(
X=provided_X0, inequality_constraints=inequality_constraints
).unsqueeze(1)
required_min_points -= batch_initial_conditions.shape[0]

if required_min_points > 0:
generated_X0 = _gen_batch_initial_conditions_local_search(
discrete_choices=discrete_choices,
raw_samples=raw_samples,
X_avoid=X_avoid,
inequality_constraints=inequality_constraints,
min_points=min_points,
max_tries=max_tries,
)

# pick the best starting points
with torch.no_grad():
acqvals_init = _split_batch_eval_acqf(
acq_function=acq_function,
X=generated_X0.unsqueeze(1),
max_batch_size=max_batch_size,
).unsqueeze(-1)

generated_X0 = generated_X0[
acqvals_init.topk(k=min_points, largest=True, dim=0).indices
]

# permute to match the required behavior of _combine_initial_conditions
return _combine_initial_conditions(
provided_initial_conditions=provided_X0.permute(1, 0, 2)
if provided_X0 is not None
else None,
generated_initial_conditions=generated_X0.permute(1, 0, 2)
if generated_X0 is not None
else None,
).permute(1, 0, 2)
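The permute round-trip above, replayed with plain tensors (hypothetical sizes; the real helper also shuffles along dim=-2): local-search points are `n x 1 x d`, while `_combine_initial_conditions` pools along dim=-2, hence the transposes in and out.

```python
import torch

provided_X0 = torch.rand(3, 1, 5)   # n x 1 x d
generated_X0 = torch.rand(8, 1, 5)  # min_points x 1 x d
merged = torch.cat(
    [provided_X0.permute(1, 0, 2), generated_X0.permute(1, 0, 2)], dim=-2
)                                   # 1 x (3 + 8) x 5
X0 = merged.permute(1, 0, 2)        # back to (n + min_points) x 1 x d
assert X0.shape == (11, 1, 5)
```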


def optimize_acqf_discrete_local_search(
acq_function: AcquisitionFunction,
discrete_choices: list[Tensor],
@@ -1207,6 +1349,7 @@
X_avoid: Tensor | None = None,
batch_initial_conditions: Tensor | None = None,
max_batch_size: int = 2048,
max_tries: int = 100,
unique: bool = True,
) -> tuple[Tensor, Tensor]:
r"""Optimize acquisition function over a lattice.
Expand Down Expand Up @@ -1238,6 +1381,8 @@ def optimize_acqf_discrete_local_search(
max_batch_size: The maximum number of choices to evaluate in batch.
A large limit can cause excessive memory usage if the model has
a large training set.
max_tries: Maximum number of iterations to try when generating initial
conditions.
unique: If True return unique choices, o/w choices may be repeated
(only relevant if `q > 1`).
@@ -1247,6 +1392,16 @@
- a `q x d`-dim tensor of generated candidates.
- an associated acquisition value.
"""
if batch_initial_conditions is not None:
if not (
len(batch_initial_conditions.shape) == 3
and batch_initial_conditions.shape[-2] == 1
):
raise ValueError(
"batch_initial_conditions must have shape `n x 1 x d` if "
f"given (recieved {batch_initial_conditions})."
)
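A sketch (hypothetical values) of the `n x 1 x d` layout the new check requires:

```python
import torch

seeds = torch.tensor([[0.0, 1.0, 2.0],
                      [1.0, 0.0, 2.0]])        # n x d discrete seed points
batch_initial_conditions = seeds.unsqueeze(1)  # n x 1 x d -> passes the check
```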

candidate_list = []
base_X_pending = acq_function.X_pending if q > 1 else None
base_X_avoid = X_avoid
@@ -1259,27 +1414,18 @@
inequality_constraints = inequality_constraints or []
for i in range(q):
# generate some starting points
if i == 0 and batch_initial_conditions is not None:
X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
X0 = _filter_infeasible(
X=X0, inequality_constraints=inequality_constraints
).unsqueeze(1)
else:
X_init = _gen_batch_initial_conditions_local_search(
discrete_choices=discrete_choices,
raw_samples=raw_samples,
X_avoid=X_avoid,
inequality_constraints=inequality_constraints,
min_points=num_restarts,
)
# pick the best starting points
with torch.no_grad():
acqvals_init = _split_batch_eval_acqf(
acq_function=acq_function,
X=X_init.unsqueeze(1),
max_batch_size=max_batch_size,
).unsqueeze(-1)
X0 = X_init[acqvals_init.topk(k=num_restarts, largest=True, dim=0).indices]
X0 = _gen_starting_points_local_search(
discrete_choices=discrete_choices,
raw_samples=raw_samples,
batch_initial_conditions=batch_initial_conditions,
X_avoid=X_avoid,
inequality_constraints=inequality_constraints,
min_points=num_restarts,
acq_function=acq_function,
max_batch_size=max_batch_size,
max_tries=max_tries,
)
batch_initial_conditions = None

# optimize from the best starting points
best_xs = torch.zeros(len(X0), dim, device=device, dtype=dtype)
5 changes: 4 additions & 1 deletion botorch/optim/optimize_homotopy.py
@@ -157,7 +157,6 @@ def optimize_acqf_homotopy(
"""
shared_optimize_acqf_kwargs = {
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"inequality_constraints": inequality_constraints,
"equality_constraints": equality_constraints,
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
@@ -181,11 +180,14 @@
homotopy.restart()

while not homotopy.should_stop:
# After the first iteration we don't want to generate new raw samples
requested_raw_samples = raw_samples if candidates is None else None
candidates, acq_values = optimize_acqf(
acq_function=acq_function,
bounds=bounds,
q=1,
options=options,
raw_samples=requested_raw_samples,
batch_initial_conditions=candidates,
**shared_optimize_acqf_kwargs,
)
@@ -204,6 +206,7 @@
bounds=bounds,
q=1,
options=final_options,
raw_samples=None,
batch_initial_conditions=candidates,
**shared_optimize_acqf_kwargs,
)
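The warm-start schedule introduced above, reduced to a standalone skeleton (hypothetical step and sample counts; the `optimize_acqf` call is elided): raw samples are requested only while there are no candidates to warm-start from, i.e. on the first homotopy step.

```python
raw_samples, candidates = 32, None  # hypothetical values
schedule = []
for _ in range(3):  # three hypothetical homotopy steps
    schedule.append(raw_samples if candidates is None else None)
    candidates = "candidates from this step"  # placeholder warm start
print(schedule)  # [32, None, None]: fresh samples only on the first step
```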
1 change: 1 addition & 0 deletions botorch/optim/optimize_mixed.py
@@ -496,6 +496,7 @@ def continuous_step(
updated_opt_inputs = dataclasses.replace(
opt_inputs,
q=1,
raw_samples=None,
num_restarts=1,
batch_initial_conditions=current_x.unsqueeze(0),
fixed_features={
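A minimal illustration (hypothetical dataclass standing in for `OptimizeAcqfInputs`) of the `dataclasses.replace` pattern used here: the continuation step overrides `raw_samples` to None on an otherwise shared, frozen config.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class Inputs:  # hypothetical stand-in for OptimizeAcqfInputs
    q: int = 4
    raw_samples: int | None = 32
    num_restarts: int = 8

step_inputs = dataclasses.replace(Inputs(), q=1, raw_samples=None, num_restarts=1)
assert step_inputs.raw_samples is None and step_inputs.num_restarts == 1
```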