Skip to content

Commit

Permalink
fea: apply flake8 fixes, update test to account for f"{xyz=}" != f"{xyz = }"
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Nov 22, 2024
1 parent 5eefe8e commit 4641d07
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion botorch/acquisition/analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ def _log_ei_helper(u: Tensor) -> Tensor:
if not (u.dtype == torch.float32 or u.dtype == torch.float64):
raise TypeError(
f"LogExpectedImprovement only supports torch.float32 and torch.float64 "
f"dtypes, but received {u.dtype = }."
f"dtypes, but received {u.dtype=}."
)
# The function has two branching decisions. The first is u < bound, and in this
# case, just taking the logarithm of the naive _ei_helper implementation works.
Expand Down
4 changes: 2 additions & 2 deletions botorch/acquisition/logei.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def check_tau(tau: FloatOrTensor, name: str) -> FloatOrTensor:
"""Checks the validity of the tau arguments of the functions below, and returns
`tau` if it is valid."""
if isinstance(tau, Tensor) and tau.numel() != 1:
raise ValueError(name + f" is not a scalar: {tau.numel() = }.")
raise ValueError(f"{name} is not a scalar: {tau.numel()=}.")
if not (tau > 0):
raise ValueError(name + f" is non-positive: {tau = }.")
raise ValueError(f"{name} is non-positive: {tau=}.")
return tau
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def _split_hvkg_fantasy_points(
"""
if n_f * num_pareto > X.size(-2):
raise ValueError(
f"`n_f*num_pareto` ({n_f*num_pareto}) must be less than"
f"`n_f*num_pareto` ({n_f * num_pareto}) must be less than"
f" the `q`-batch dimension of `X` ({X.size(-2)})."
)
split_sizes = [X.size(-2) - n_f * num_pareto, n_f * num_pareto]
Expand Down
4 changes: 2 additions & 2 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def _optimize_acqf_sequential_q(
if base_X_pending is not None
else candidates
)
logger.info(f"Generated sequential candidate {i+1} of {opt_inputs.q}")
logger.info(f"Generated sequential candidate {i + 1} of {opt_inputs.q}")
opt_inputs.acq_function.set_X_pending(base_X_pending)
return candidates, torch.stack(acq_value_list)

Expand Down Expand Up @@ -325,7 +325,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
opt_warnings += ws
batch_candidates_list.append(batch_candidates_curr)
batch_acq_values_list.append(batch_acq_values_curr)
logger.info(f"Generated candidate batch {i+1} of {len(batched_ics)}.")
logger.info(f"Generated candidate batch {i + 1} of {len(batched_ics)}.")

batch_candidates = torch.cat(batch_candidates_list)
has_scalars = batch_acq_values_list[0].ndim == 0
Expand Down
4 changes: 2 additions & 2 deletions botorch/utils/probability/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def log_ndtr(x: Tensor) -> Tensor:
if not (x.dtype == torch.float32 or x.dtype == torch.float64):
raise TypeError(
f"log_Phi only supports torch.float32 and torch.float64 "
f"dtypes, but received {x.dtype = }."
f"dtypes, but received {x.dtype=}."
)
neg_inv_sqrt_2, log_2 = get_constants_like((_neg_inv_sqrt_2, _log_2), x)
return log_erfc(neg_inv_sqrt_2 * x) - log_2
Expand All @@ -181,7 +181,7 @@ def log_erfc(x: Tensor) -> Tensor:
if not (x.dtype == torch.float32 or x.dtype == torch.float64):
raise TypeError(
f"log_erfc only supports torch.float32 and torch.float64 "
f"dtypes, but received {x.dtype = }."
f"dtypes, but received {x.dtype=}."
)
is_pos = x > 0
x_pos = x.masked_fill(~is_pos, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def test_split_hvkg_fantasy_points(self):
n_f = 100
num_pareto = 3
msg = (
rf".*\({n_f*num_pareto}\) must be less than"
rf".*\({n_f * num_pareto}\) must be less than"
rf" the `q`-batch dimension of `X` \({X.size(-2)}\)\."
)
with self.assertRaisesRegex(ValueError, msg):
Expand Down
2 changes: 1 addition & 1 deletion test/utils/probability/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def test_gaussian_probabilities(self) -> None:

float16_msg = (
"only supports torch.float32 and torch.float64 dtypes, but received "
"x.dtype = torch.float16."
"x.dtype=torch.float16."
)
with self.assertRaisesRegex(TypeError, expected_regex=float16_msg):
log_erfc(torch.tensor(1.0, dtype=torch.float16, device=self.device))
Expand Down

0 comments on commit 4641d07

Please sign in to comment.