Skip to content

Commit

Permalink
Migrate rest of the package to use the new two-step benchmark setup i…
Browse files Browse the repository at this point in the history
…diom

We had relatively few consumers, namely the main module and the CLI, so this was
not as hard as previously thought.
  • Loading branch information
nicholasjng committed Dec 21, 2024
1 parent c7ba2df commit 596dc81
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/nnbench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .core import benchmark, parametrize, product
from .reporter import BenchmarkReporter, ConsoleReporter, FileReporter
from .runner import BenchmarkRunner
from .runner import collect, run
from .types import Benchmark, BenchmarkRecord, Memo, Parameters

__version__ = "0.4.0"
8 changes: 4 additions & 4 deletions src/nnbench/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
from typing import Any

from nnbench import BenchmarkRunner, ConsoleReporter, __version__
from nnbench import ConsoleReporter, __version__, collect, run
from nnbench.config import NNBenchConfig, parse_nnbench_config
from nnbench.reporter import FileReporter

Expand Down Expand Up @@ -205,9 +205,9 @@ def main() -> int:
else:
context[k] = v

record = BenchmarkRunner().run(
args.benchmarks,
tags=tuple(args.tags),
benchmarks = collect(args.benchmarks, tags=tuple(args.tags))
record = run(
benchmarks,
context=[lambda: context],
)

Expand Down
4 changes: 2 additions & 2 deletions src/nnbench/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ class PythonInfo:
Parameters
----------
packages: str
packages: Sequence[str]
Names of the requested packages under which they exist in the current environment.
For packages installed through ``pip``, this equals the PyPI package name.
"""

key = "python"

def __init__(self, packages: Sequence[str] = ()):
self.packages = packages
self.packages = tuple(packages)

def __call__(self) -> dict[str, Any]:
from importlib.metadata import PackageNotFoundError, version
Expand Down
56 changes: 18 additions & 38 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,29 @@


def test_runner_collection(testfolder: str) -> None:
r = nnbench.BenchmarkRunner()
benchmarks = nnbench.collect(os.path.join(testfolder, "standard.py"), tags=("runner-collect",))
assert len(benchmarks) == 1

r.collect(os.path.join(testfolder, "standard.py"), tags=("runner-collect",))
assert len(r.benchmarks) == 1
r.clear()
benchmarks = nnbench.collect(testfolder, tags=("non-existing-tag",))
assert len(benchmarks) == 0

r.collect(testfolder, tags=("non-existing-tag",))
assert len(r.benchmarks) == 0
r.clear()

r.collect(testfolder, tags=("runner-collect",))
assert len(r.benchmarks) == 1
benchmarks = nnbench.collect(testfolder, tags=("runner-collect",))
assert len(benchmarks) == 1


def test_tag_selection(testfolder: str) -> None:
PATH = os.path.join(testfolder, "tags.py")

r = nnbench.BenchmarkRunner()

r.collect(PATH)
assert len(r.benchmarks) == 3
r.clear()

r.collect(PATH, tags=("tag1",))
assert len(r.benchmarks) == 2
r.clear()

r.collect(PATH, tags=("tag2",))
assert len(r.benchmarks) == 1
r.clear()
assert len(nnbench.collect(PATH)) == 3
assert len(nnbench.collect(PATH, tags=("tag1",))) == 2
assert len(nnbench.collect(PATH, tags=("tag2",))) == 1


def test_context_assembly(testfolder: str) -> None:
r = nnbench.BenchmarkRunner()

context_providers = [system, cpuarch, python_version]
result = r.run(
testfolder,
tags=("standard",),
benchmarks = nnbench.collect(testfolder, tags=("standard",))
result = nnbench.run(
benchmarks,
params={"x": 1, "y": 1},
context=context_providers,
)
Expand All @@ -57,17 +41,15 @@ def test_context_assembly(testfolder: str) -> None:


def test_error_on_duplicate_context_keys_in_runner(testfolder: str) -> None:
r = nnbench.BenchmarkRunner()

def duplicate_context_provider() -> dict[str, str]:
return {"system": "DuplicateSystem"}

context_providers = [system, duplicate_context_provider]

benchmarks = nnbench.collect(testfolder, tags=("standard",))
with pytest.raises(ValueError, match="got multiple values for context key 'system'"):
r.run(
testfolder,
tags=("standard",),
nnbench.run(
benchmarks,
params={"x": 1, "y": 1},
context=context_providers,
)
Expand All @@ -78,11 +60,9 @@ def test_filter_benchmarks_on_params(testfolder: str) -> None:
def prod(a: int, b: int = 1) -> int:
return a * b

r = nnbench.BenchmarkRunner()
r.benchmarks.append(prod)
# TODO (nicholasjng): This is hacky
rec1 = r.run("", params={"a": 1, "b": 2})
benchmarks = [prod]
rec1 = nnbench.run(benchmarks, params={"a": 1, "b": 2})
assert rec1.benchmarks[0]["parameters"] == {"a": 1, "b": 2}
# Assert that the defaults are also present if not overridden.
rec2 = r.run("", params={"a": 1})
rec2 = nnbench.run(benchmarks, params={"a": 1})
assert rec2.benchmarks[0]["parameters"] == {"a": 1, "b": 1}

0 comments on commit 596dc81

Please sign in to comment.