Skip to content

Commit

Permalink
fea: add pre-commit hook
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Nov 22, 2024
1 parent 5d37606 commit 5eefe8e
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 5 deletions.
28 changes: 28 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
repos:
- repo: local
hooks:
- id: check-requirements-versions
name: Check pre-commit formatting versions
entry: python scripts/check_pre_commit_reqs.py
language: python
always_run: true
pass_filenames: false
additional_dependencies:
- PyYAML

- repo: https://github.com/omnilib/ufmt
rev: v2.8.0
hooks:
- id: ufmt
additional_dependencies:
- black==24.4.2
- usort==1.0.8.post1
- ruff-api==0.1.0
args: [format]

- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies:
- flake8-docstrings
7 changes: 7 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ flake8 .

from the repository root.

#### Pre-commit hooks

Contributors can use [pre-commit](https://pre-commit.com/) to run `ufmt` and
`flake8` as part of the commit process. To install the hooks, install `pre-commit`
via `pip install pre-commit` and run `pre-commit install` from the repository
root.

#### Docstring formatting

BoTorch uses
Expand Down
8 changes: 5 additions & 3 deletions botorch/posteriors/posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from __future__ import annotations

from abc import ABC, abstractmethod, abstractproperty
from abc import ABC, abstractmethod

import torch
from torch import Tensor
Expand Down Expand Up @@ -77,12 +77,14 @@ def sample(self, sample_shape: torch.Size | None = None) -> Tensor:
with torch.no_grad():
return self.rsample(sample_shape=sample_shape)

@abstractproperty
@property
@abstractmethod
def device(self) -> torch.device:
r"""The torch device of the distribution."""
pass # pragma: no cover

@abstractproperty
@property
@abstractmethod
def dtype(self) -> torch.dtype:
r"""The torch dtype of the distribution."""
pass # pragma: no cover
Expand Down
5 changes: 3 additions & 2 deletions botorch/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import math
import warnings
from abc import abstractproperty
from abc import abstractmethod
from collections import OrderedDict
from collections.abc import Sequence
from itertools import product
Expand Down Expand Up @@ -138,7 +138,8 @@ def test_forward_and_evaluate_true(self):
)
self.assertEqual(res.shape, batch_shape + tail_shape)

@abstractproperty
@property
@abstractmethod
def functions(self) -> Sequence[BaseTestProblem]:
# The functions that should be tested. Typically defined as a class
# attribute on the test case subclassing this class.
Expand Down
98 changes: 98 additions & 0 deletions scripts/check_pre_commit_reqs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

import sys
from pathlib import Path

import yaml


def parse_requirements(filepath):
"""Parse requirements file and return a dict of package versions."""
versions = {}
with open(filepath) as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
# Handle different requirement formats
if "==" in line:
pkg, version = line.split("==")
versions[pkg.strip().lower()] = version.strip()
return versions


def parse_precommit_config(filepath):
"""Parse pre-commit config and extract ufmt repo rev and hook dependencies."""
with open(filepath) as f:
config = yaml.safe_load(f)

versions = {}
for repo in config["repos"]:
if "https://github.com/omnilib/ufmt" in repo.get("repo", ""):
# Get ufmt version from rev - assumes fixed format: vX.Y.Z
versions["ufmt"] = repo.get("rev", "").replace("v", "")

# Get dependency versions
for hook in repo["hooks"]:
if hook["id"] == "ufmt":
for dep in hook.get("additional_dependencies", []):
if "==" in dep:
pkg, version = dep.split("==")
versions[pkg.strip().lower()] = version.strip()
break
return versions


def main():
# Find the pre-commit config and requirements files
config_file = Path(".pre-commit-config.yaml")
requirements_file = Path("requirements-fmt.txt")

if not config_file.exists():
print(f"Error: Could not find {config_file}")
sys.exit(1)

if not requirements_file.exists():
print(f"Error: Could not find {requirements_file}")
sys.exit(1)

# Parse both files
req_versions = parse_requirements(requirements_file)
config_versions = parse_precommit_config(config_file)

# Packages to check
packages = ["ufmt", "black", "usort", "ruff-api"]

# Check versions
mismatches = []
for pkg in packages:
req_ver = req_versions.get(pkg, None)
config_ver = config_versions.get(pkg, None)

if req_ver != config_ver:
found_version_str = f"{pkg}: {requirements_file} has {req_ver},"
if pkg == "ufmt":
mismatches.append(
f"{found_version_str} pre-commit config rev has v{config_ver}"
)
else:
mismatches.append(
f"{found_version_str} pre-commit config has {config_ver}"
)

# Report results
if mismatches:
msg_str = "".join("\n\t" + msg for msg in mismatches)
print(
f"Version mismatches found:{msg_str}"
"\nPlease update the versions in `.pre-commit-config.yaml` to be "
"consistent with those in `requirements-fmt.txt` (source of truth)."
"\nNote: all versions must be pinned exactly ('==X.Y.Z') in both files."
)
sys.exit(1)
else:
print("All versions match!")
sys.exit(0)


if __name__ == "__main__":
main()

0 comments on commit 5eefe8e

Please sign in to comment.