Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pre-commit hook to stop linting errors being pushed. #2632

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Hooks run by pre-commit at commit time: a version-consistency check, the
# ufmt formatter, and the flake8 linter.
repos:
  # Local hook that runs scripts/check_pre_commit_reqs.py. It runs on every
  # commit (always_run) and is not passed the staged file names.
  - repo: local
    hooks:
      - id: check-requirements-versions
        name: Check pre-commit formatting versions
        entry: python scripts/check_pre_commit_reqs.py
        language: python
        always_run: true
        pass_filenames: false
        additional_dependencies:
          - PyYAML  # provides the `yaml` module the check script imports

  # Formatter. The `rev` below and the `==` pins of the ufmt hook are the
  # values the check script compares against requirements-fmt.txt.
  - repo: https://github.com/omnilib/ufmt
    rev: v2.8.0
    hooks:
      - id: ufmt
        additional_dependencies:
          - black==24.4.2
          - usort==1.0.8.post1
          - ruff-api==0.1.0
        args: [format]

  # Linter, with the docstring plugin installed alongside it.
  - repo: https://github.com/pycqa/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
        additional_dependencies:
          - flake8-docstrings
7 changes: 7 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ flake8 .

from the repository root.

#### Pre-commit hooks

Contributors can use [pre-commit](https://pre-commit.com/) to run `ufmt` and
`flake8` as part of the commit process. To install the hooks, install `pre-commit`
via `pip install pre-commit` and run `pre-commit install` from the repository
root.

#### Docstring formatting

BoTorch uses
Expand Down
2 changes: 1 addition & 1 deletion botorch/acquisition/analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ def _log_ei_helper(u: Tensor) -> Tensor:
if not (u.dtype == torch.float32 or u.dtype == torch.float64):
raise TypeError(
f"LogExpectedImprovement only supports torch.float32 and torch.float64 "
f"dtypes, but received {u.dtype = }."
f"dtypes, but received {u.dtype=}."
)
# The function has two branching decisions. The first is u < bound, and in this
# case, just taking the logarithm of the naive _ei_helper implementation works.
Expand Down
4 changes: 2 additions & 2 deletions botorch/acquisition/logei.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def check_tau(tau: FloatOrTensor, name: str) -> FloatOrTensor:
"""Checks the validity of the tau arguments of the functions below, and returns
`tau` if it is valid."""
if isinstance(tau, Tensor) and tau.numel() != 1:
raise ValueError(name + f" is not a scalar: {tau.numel() = }.")
raise ValueError(f"{name} is not a scalar: {tau.numel()=}.")
if not (tau > 0):
raise ValueError(name + f" is non-positive: {tau = }.")
raise ValueError(f"{name} is non-positive: {tau=}.")
return tau
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def _split_hvkg_fantasy_points(
"""
if n_f * num_pareto > X.size(-2):
raise ValueError(
f"`n_f*num_pareto` ({n_f*num_pareto}) must be less than"
f"`n_f*num_pareto` ({n_f * num_pareto}) must be less than"
f" the `q`-batch dimension of `X` ({X.size(-2)})."
)
split_sizes = [X.size(-2) - n_f * num_pareto, n_f * num_pareto]
Expand Down
4 changes: 2 additions & 2 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def _optimize_acqf_sequential_q(
if base_X_pending is not None
else candidates
)
logger.info(f"Generated sequential candidate {i+1} of {opt_inputs.q}")
logger.info(f"Generated sequential candidate {i + 1} of {opt_inputs.q}")
opt_inputs.acq_function.set_X_pending(base_X_pending)
return candidates, torch.stack(acq_value_list)

Expand Down Expand Up @@ -325,7 +325,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
opt_warnings += ws
batch_candidates_list.append(batch_candidates_curr)
batch_acq_values_list.append(batch_acq_values_curr)
logger.info(f"Generated candidate batch {i+1} of {len(batched_ics)}.")
logger.info(f"Generated candidate batch {i + 1} of {len(batched_ics)}.")

batch_candidates = torch.cat(batch_candidates_list)
has_scalars = batch_acq_values_list[0].ndim == 0
Expand Down
8 changes: 5 additions & 3 deletions botorch/posteriors/posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from __future__ import annotations

from abc import ABC, abstractmethod, abstractproperty
from abc import ABC, abstractmethod

import torch
from torch import Tensor
Expand Down Expand Up @@ -77,12 +77,14 @@ def sample(self, sample_shape: torch.Size | None = None) -> Tensor:
with torch.no_grad():
return self.rsample(sample_shape=sample_shape)

@abstractproperty
@property
@abstractmethod
def device(self) -> torch.device:
r"""The torch device of the distribution."""
pass # pragma: no cover

@abstractproperty
@property
@abstractmethod
def dtype(self) -> torch.dtype:
r"""The torch dtype of the distribution."""
pass # pragma: no cover
Expand Down
4 changes: 2 additions & 2 deletions botorch/utils/probability/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def log_ndtr(x: Tensor) -> Tensor:
if not (x.dtype == torch.float32 or x.dtype == torch.float64):
raise TypeError(
f"log_Phi only supports torch.float32 and torch.float64 "
f"dtypes, but received {x.dtype = }."
f"dtypes, but received {x.dtype=}."
)
neg_inv_sqrt_2, log_2 = get_constants_like((_neg_inv_sqrt_2, _log_2), x)
return log_erfc(neg_inv_sqrt_2 * x) - log_2
Expand All @@ -181,7 +181,7 @@ def log_erfc(x: Tensor) -> Tensor:
if not (x.dtype == torch.float32 or x.dtype == torch.float64):
raise TypeError(
f"log_erfc only supports torch.float32 and torch.float64 "
f"dtypes, but received {x.dtype = }."
f"dtypes, but received {x.dtype=}."
)
is_pos = x > 0
x_pos = x.masked_fill(~is_pos, 0)
Expand Down
5 changes: 3 additions & 2 deletions botorch/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import math
import warnings
from abc import abstractproperty
from abc import abstractmethod
from collections import OrderedDict
from collections.abc import Sequence
from itertools import product
Expand Down Expand Up @@ -138,7 +138,8 @@ def test_forward_and_evaluate_true(self):
)
self.assertEqual(res.shape, batch_shape + tail_shape)

@abstractproperty
@property
@abstractmethod
def functions(self) -> Sequence[BaseTestProblem]:
# The functions that should be tested. Typically defined as a class
# attribute on the test case subclassing this class.
Expand Down
98 changes: 98 additions & 0 deletions scripts/check_pre_commit_reqs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
CompRhys marked this conversation as resolved.
Show resolved Hide resolved

import sys
from pathlib import Path

import yaml


def parse_requirements(filepath):
    """Parse a pip requirements file and return its exactly-pinned versions.

    Only requirements pinned with ``==`` are collected. Blank lines, comment
    lines, and requirements that use another specifier (or none) are skipped.

    Args:
        filepath: Path to the requirements file.

    Returns:
        A dict mapping each lower-cased package name to its pinned version
        string.
    """
    versions = {}
    with open(filepath, encoding="utf-8") as f:
        for raw_line in f:
            # Drop a trailing inline comment (whitespace followed by `#`) and
            # any environment marker (`; python_version == "3.10"`). A marker
            # can contain its own `==`, which must not be read as the pin.
            # `expandtabs` lets a tab before `#` count as whitespace too.
            requirement = raw_line.expandtabs().split(" #", 1)[0]
            requirement = requirement.split(";", 1)[0].strip()
            if not requirement or requirement.startswith("#"):
                continue
            # `partition` never raises, unlike unpacking the result of `split`.
            pkg, sep, version = requirement.partition("==")
            if sep:
                versions[pkg.strip().lower()] = version.strip()
    return versions


def parse_precommit_config(filepath):
    """Parse the pre-commit config and extract the ufmt rev and hook pins.

    Args:
        filepath: Path to the ``.pre-commit-config.yaml`` file.

    Returns:
        A dict mapping lower-cased package names to version strings. The
        ``"ufmt"`` entry is the ``rev`` of the ufmt repo with a leading ``v``
        removed. The other entries are the ``==``-pinned
        ``additional_dependencies`` of that repo's ``ufmt`` hook. The dict is
        empty when the config lists no ufmt repo.
    """
    with open(filepath) as f:
        config = yaml.safe_load(f)

    versions = {}
    # An empty YAML document loads as None; treat it as a config without repos.
    repos = (config or {}).get("repos") or []
    for repo in repos:
        if "https://github.com/omnilib/ufmt" not in repo.get("repo", ""):
            continue

        # The rev is a tag of the form vX.Y.Z. Strip only the leading "v" so
        # that a "v" later in the tag (e.g. "v2.8.0.dev1") is preserved.
        rev = repo.get("rev", "")
        versions["ufmt"] = rev[1:] if rev.startswith("v") else rev

        # Collect the exact pins attached to the ufmt hook itself.
        for hook in repo["hooks"]:
            if hook["id"] != "ufmt":
                continue
            for dep in hook.get("additional_dependencies", []):
                # `partition` tolerates entries holding more than one "==".
                pkg, sep, version = dep.partition("==")
                if sep:
                    versions[pkg.strip().lower()] = version.strip()
        # Only the first ufmt repo entry is consulted.
        break
    return versions


def main():
    """Check that formatter versions agree between the two pinning files.

    Compares the exact pins of ``ufmt``, ``black``, ``usort`` and ``ruff-api``
    in ``requirements-fmt.txt`` with the ufmt ``rev`` and hook dependencies in
    ``.pre-commit-config.yaml``. Both paths are relative to the current
    working directory.

    Exits with status 0 when every package is pinned to the same version in
    both files. Exits with status 1 when either file is missing, a version
    differs, or a package is not pinned in one or both files.
    """
    config_file = Path(".pre-commit-config.yaml")
    requirements_file = Path("requirements-fmt.txt")

    for required_file in (config_file, requirements_file):
        if not required_file.exists():
            print(f"Error: Could not find {required_file}")
            sys.exit(1)

    req_versions = parse_requirements(requirements_file)
    config_versions = parse_precommit_config(config_file)

    packages = ("ufmt", "black", "usort", "ruff-api")

    mismatches = []
    for pkg in packages:
        req_ver = req_versions.get(pkg)
        config_ver = config_versions.get(pkg)

        # A package pinned in neither file yields None on both sides. That is
        # not a match: every package must be pinned exactly in both files.
        if req_ver is not None and req_ver == config_ver:
            continue

        req_desc = "no pinned version" if req_ver is None else req_ver
        if config_ver is None:
            config_desc = "has no pinned version"
        elif pkg == "ufmt":
            # ufmt is pinned through the repo rev, which carries a "v" prefix.
            config_desc = f"rev has v{config_ver}"
        else:
            config_desc = f"has {config_ver}"
        mismatches.append(
            f"{pkg}: {requirements_file} has {req_desc},"
            f" pre-commit config {config_desc}"
        )

    if mismatches:
        msg_str = "".join("\n\t" + msg for msg in mismatches)
        print(
            f"Version mismatches found:{msg_str}"
            "\nPlease update the versions in `.pre-commit-config.yaml` to be "
            "consistent with those in `requirements-fmt.txt` (source of truth)."
            "\nNote: all versions must be pinned exactly ('==X.Y.Z') in both files."
        )
        sys.exit(1)

    print("All versions match!")
    sys.exit(0)


if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def test_split_hvkg_fantasy_points(self):
n_f = 100
num_pareto = 3
msg = (
rf".*\({n_f*num_pareto}\) must be less than"
rf".*\({n_f * num_pareto}\) must be less than"
rf" the `q`-batch dimension of `X` \({X.size(-2)}\)\."
)
with self.assertRaisesRegex(ValueError, msg):
Expand Down
2 changes: 1 addition & 1 deletion test/utils/probability/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def test_gaussian_probabilities(self) -> None:

float16_msg = (
"only supports torch.float32 and torch.float64 dtypes, but received "
"x.dtype = torch.float16."
"x.dtype=torch.float16."
)
with self.assertRaisesRegex(TypeError, expected_regex=float16_msg):
log_erfc(torch.tensor(1.0, dtype=torch.float16, device=self.device))
Expand Down