diff --git a/src/nnbench/__init__.py b/src/nnbench/__init__.py
index 0315052..fc35ed9 100644
--- a/src/nnbench/__init__.py
+++ b/src/nnbench/__init__.py
@@ -2,7 +2,7 @@
 from .core import benchmark, parametrize, product
 from .reporter import BenchmarkReporter, ConsoleReporter, FileReporter
-from .runner import BenchmarkRunner
+from .runner import collect, run
 from .types import Benchmark, BenchmarkRecord, Memo, Parameters
 
 __version__ = "0.4.0"
 
diff --git a/src/nnbench/cli.py b/src/nnbench/cli.py
index bee2d4a..a90cea8 100644
--- a/src/nnbench/cli.py
+++ b/src/nnbench/cli.py
@@ -6,7 +6,7 @@
 import sys
 from typing import Any
 
-from nnbench import BenchmarkRunner, ConsoleReporter, __version__
+from nnbench import ConsoleReporter, __version__, collect, run
 from nnbench.config import NNBenchConfig, parse_nnbench_config
 from nnbench.reporter import FileReporter
 
@@ -205,9 +205,9 @@ def main() -> int:
         else:
             context[k] = v
 
-    record = BenchmarkRunner().run(
-        args.benchmarks,
-        tags=tuple(args.tags),
+    benchmarks = collect(args.benchmarks, tags=tuple(args.tags))
+    record = run(
+        benchmarks,
         context=[lambda: context],
     )
 
diff --git a/src/nnbench/context.py b/src/nnbench/context.py
index 27a3c33..be8cfb2 100644
--- a/src/nnbench/context.py
+++ b/src/nnbench/context.py
@@ -29,7 +29,7 @@ class PythonInfo:
 
     Parameters
     ----------
-    packages: str
+    packages: Sequence[str]
        Names of the requested packages under which they exist in the current
        environment. For packages installed through ``pip``, this equals the PyPI package name.
    """
@@ -37,7 +37,7 @@
    key = "python"
 
    def __init__(self, packages: Sequence[str] = ()):
-        self.packages = packages
+        self.packages = tuple(packages)
 
    def __call__(self) -> dict[str, Any]:
        from importlib.metadata import PackageNotFoundError, version
diff --git a/tests/test_runner.py b/tests/test_runner.py
index c1aa2ca..3c138cc 100644
--- a/tests/test_runner.py
+++ b/tests/test_runner.py
@@ -7,45 +7,29 @@
 
 
 def test_runner_collection(testfolder: str) -> None:
-    r = nnbench.BenchmarkRunner()
+    benchmarks = nnbench.collect(os.path.join(testfolder, "standard.py"), tags=("runner-collect",))
+    assert len(benchmarks) == 1
 
-    r.collect(os.path.join(testfolder, "standard.py"), tags=("runner-collect",))
-    assert len(r.benchmarks) == 1
-    r.clear()
+    benchmarks = nnbench.collect(testfolder, tags=("non-existing-tag",))
+    assert len(benchmarks) == 0
 
-    r.collect(testfolder, tags=("non-existing-tag",))
-    assert len(r.benchmarks) == 0
-    r.clear()
-
-    r.collect(testfolder, tags=("runner-collect",))
-    assert len(r.benchmarks) == 1
+    benchmarks = nnbench.collect(testfolder, tags=("runner-collect",))
+    assert len(benchmarks) == 1
 
 
 def test_tag_selection(testfolder: str) -> None:
     PATH = os.path.join(testfolder, "tags.py")
 
-    r = nnbench.BenchmarkRunner()
-
-    r.collect(PATH)
-    assert len(r.benchmarks) == 3
-    r.clear()
-
-    r.collect(PATH, tags=("tag1",))
-    assert len(r.benchmarks) == 2
-    r.clear()
-
-    r.collect(PATH, tags=("tag2",))
-    assert len(r.benchmarks) == 1
-    r.clear()
+    assert len(nnbench.collect(PATH)) == 3
+    assert len(nnbench.collect(PATH, tags=("tag1",))) == 2
+    assert len(nnbench.collect(PATH, tags=("tag2",))) == 1
 
 
 def test_context_assembly(testfolder: str) -> None:
-    r = nnbench.BenchmarkRunner()
-
     context_providers = [system, cpuarch, python_version]
-    result = r.run(
-        testfolder,
-        tags=("standard",),
+    benchmarks = nnbench.collect(testfolder, tags=("standard",))
+    result = nnbench.run(
+        benchmarks,
         params={"x": 1, "y": 1},
         context=context_providers,
     )
@@ -57,17 +41,15 @@ def test_context_assembly(testfolder: str) -> None:
 
 
 def test_error_on_duplicate_context_keys_in_runner(testfolder: str) -> None:
-    r = nnbench.BenchmarkRunner()
-
     def duplicate_context_provider() -> dict[str, str]:
         return {"system": "DuplicateSystem"}
 
     context_providers = [system, duplicate_context_provider]
+    benchmarks = nnbench.collect(testfolder, tags=("standard",))
 
     with pytest.raises(ValueError, match="got multiple values for context key 'system'"):
-        r.run(
-            testfolder,
-            tags=("standard",),
+        nnbench.run(
+            benchmarks,
             params={"x": 1, "y": 1},
             context=context_providers,
         )
@@ -78,11 +60,9 @@ def test_filter_benchmarks_on_params(testfolder: str) -> None:
     def prod(a: int, b: int = 1) -> int:
         return a * b
 
-    r = nnbench.BenchmarkRunner()
-    r.benchmarks.append(prod)
-    # TODO (nicholasjng): This is hacky
-    rec1 = r.run("", params={"a": 1, "b": 2})
+    benchmarks = [prod]
+    rec1 = nnbench.run(benchmarks, params={"a": 1, "b": 2})
     assert rec1.benchmarks[0]["parameters"] == {"a": 1, "b": 2}
     # Assert that the defaults are also present if not overridden.
-    rec2 = r.run("", params={"a": 1})
+    rec2 = nnbench.run(benchmarks, params={"a": 1})
     assert rec2.benchmarks[0]["parameters"] == {"a": 1, "b": 1}
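
For reference, a minimal usage sketch of the new functional API as it reads after this change. The directory path, tags, params, and the "numpy" package name below are illustrative placeholders, not values from the patch; the collect/run signatures and the record.benchmarks structure are inferred from the call sites in the diff above.

import nnbench
from nnbench.context import PythonInfo

# Collect benchmarks from a directory, filtered by tag (replaces
# BenchmarkRunner().collect()). "benchmarks/" is a placeholder path.
benchmarks = nnbench.collect("benchmarks/", tags=("standard",))

# Run them with explicit parameters and context providers (replaces
# BenchmarkRunner().run()). Context providers are zero-argument
# callables returning dicts, e.g. the PythonInfo provider touched
# in this diff.
record = nnbench.run(
    benchmarks,
    params={"x": 1, "y": 1},
    context=[PythonInfo(packages=("numpy",))],
)

# Per the tests above, each record entry carries the resolved
# parameters, with defaults filled in for anything not overridden.
for bench in record.benchmarks:
    print(bench["parameters"])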