Replace nnbench.BenchmarkRunner with modular nnbench.collect() and nnbench.run() APIs #193

Merged · 3 commits · Dec 21, 2024

Changes from all commits

8 changes: 4 additions & 4 deletions docs/guides/customization.md
@@ -79,8 +79,8 @@ To supply context to your benchmarks, you can give a sequence of context provide
import nnbench

# uses the `platinfo` context provider from above to log platform metadata.
runner = nnbench.BenchmarkRunner()
result = runner.run(__name__, params={}, context=[platinfo])
benchmarks = nnbench.collect(__name__)
result = nnbench.run(benchmarks, params={}, context=[platinfo])
```
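
For orientation, the `platinfo` provider referenced in the comment is defined earlier in the guide and is not part of this hunk. A minimal sketch of what such a context provider might look like (the guide's actual implementation may differ) is a plain callable returning a dictionary:

```python
import platform
from typing import Any


def platinfo() -> dict[str, Any]:
    """Context provider: collects basic platform metadata for the benchmark record."""
    return {
        "platform": {
            "system": platform.system(),
            "architecture": platform.machine(),
            "python_version": platform.python_version(),
        }
    }
```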

## Being type safe by using `nnbench.Parameters`
@@ -104,8 +104,8 @@ def prod(a: int, b: int) -> int:


params = MyParams(a=1, b=2)
runner = nnbench.BenchmarkRunner()
result = runner.run(__name__, params=params)
benchmarks = nnbench.collect(__name__)
result = nnbench.run(benchmarks, params=params)
```

While this does not have a concrete advantage in terms of type safety over a raw dictionary, it guards against accidental modification of parameters breaking reproducibility.
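
For reference, one possible definition of the `MyParams` class used above, assuming it is a frozen dataclass subclassing `nnbench.Parameters` (the guide's actual definition sits above this hunk and may differ):

```python
from dataclasses import dataclass

import nnbench


@dataclass(frozen=True)
class MyParams(nnbench.Parameters):
    a: int
    b: int
```

Freezing the dataclass is what guards against accidental modification of the parameters between runs.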
4 changes: 2 additions & 2 deletions docs/guides/organization.md
@@ -75,8 +75,8 @@ Now, to only run data quality benchmarks marked "foo", pass the corresponding ta
```python
import nnbench

runner = nnbench.BenchmarkRunner()
foo_data_metrics = runner.run("benchmarks/data_quality.py", params=..., tags=("foo",))
benchmarks = nnbench.collect("benchmarks/data_quality.py", tags=("foo",))
foo_data_metrics = nnbench.run(benchmarks, params=...)
```
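
For context, the `"foo"` tag is attached to benchmarks at definition time. A hedged sketch, assuming the `tags` keyword of the `@nnbench.benchmark` decorator (the benchmark body is illustrative):

```python
import nnbench


@nnbench.benchmark(tags=("foo",))
def null_fraction(column: list[float | None]) -> float:
    """A data quality check tagged "foo": fraction of missing values in a column."""
    return sum(v is None for v in column) / len(column)
```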

!!!tip
32 changes: 13 additions & 19 deletions docs/guides/runners.md
@@ -1,17 +1,8 @@
# Collecting and running benchmarks

nnbench provides the `BenchmarkRunner` as a compact interface to collect and run benchmarks selectively.
nnbench provides the `nnbench.collect` and `nnbench.run` APIs as a compact interface to collect and run benchmarks selectively.

## The abstract `BenchmarkRunner` class
Let's first instantiate and then walk through the base class.

```python
from nnbench import BenchmarkRunner

runner = BenchmarkRunner()
```

Use the `BenchmarkRunner.collect()` method to collect benchmarks from files or directories.
Use the `nnbench.collect()` method to collect benchmarks from files or directories.
Assume we have the following benchmark setup:
```python
# dir_a/bm1.py
@@ -46,26 +37,29 @@ def the_last_benchmark(d: int) -> int:
Now we can collect benchmarks from files:

```python
runner.collect('dir_a/bm1.py')
import nnbench


benchmarks = nnbench.collect('dir_a/bm1.py')
```
Or directories:

```python
runner.collect('dir_b')
benchmarks = nnbench.collect('dir_b')
```

This collection can happen iteratively. So, after executing the two collections our runner has all four benchmarks ready for execution.

To remove the collected benchmarks again, use the `BenchmarkRunner.clear()` method.
You can also supply tags to the runner to selectively collect only benchmarks with the appropriate tag.
For example, after clearing the runner again, you can collect all benchmarks with the `"tag"` tag as such:

```python
runner.collect('dir_b', tags=("tag",))
import nnbench


tagged_benchmarks = nnbench.collect('dir_b', tags=("tag",))
```

To run the benchmarks, call the `BenchmarkRunner.run()` method and supply the necessary parameters required by the collected benchmarks.
To run the benchmarks, call the `nnbench.run()` method and supply the necessary parameters required by the collected benchmarks.

```python
runner.run("dir_b", params={"b": 1, "c": 2, "d": 3})
result = nnbench.run(benchmarks, params={"b": 1, "c": 2, "d": 3})
```
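
Putting the pieces together, a hedged end-to-end sketch (directory, tag, and parameter names are illustrative; the console reporter follows the usage shown in the examples further down):

```python
import nnbench

# collect tagged benchmarks from a directory, then run them with the required parameters
benchmarks = nnbench.collect("dir_b", tags=("tag",))
result = nnbench.run(benchmarks, params={"b": 1, "c": 2, "d": 3})

# display the resulting record on the console
reporter = nnbench.ConsoleReporter()
reporter.display(result)
```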
4 changes: 2 additions & 2 deletions docs/tutorials/duckdb.md
@@ -23,8 +23,8 @@ import nnbench
from nnbench.context import GitEnvironmentInfo
from nnbench.reporter.file import FileReporter

runner = nnbench.BenchmarkRunner()
record = runner.run("benchmarks.py", params={"a": 1, "b": 1}, context=(GitEnvironmentInfo(),))
benchmarks = nnbench.collect("benchmarks.py")
record = nnbench.run(benchmarks, params={"a": 1, "b": 1}, context=(GitEnvironmentInfo(),))

file_reporter = FileReporter()
file_reporter.write(record, "record.json", driver="ndjson")
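
The newline-delimited record written above can then be queried directly; a minimal sketch, assuming DuckDB's JSON reader and the `record.json` file produced by the `FileReporter`:

```python
import duckdb

# load the benchmark record written by the FileReporter and inspect it
con = duckdb.connect()
con.sql("SELECT * FROM read_json_auto('record.json')").show()
```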
2 changes: 1 addition & 1 deletion docs/tutorials/huggingface.md
@@ -91,7 +91,7 @@ In the following `IndexLabelMapMemo` class, we store a dictionary mapping the la

!!! Info
There is no need to type-hint `TokenClassificationModelMemo`s in the corresponding benchmarks -
the benchmark runner takes care of filling in the memoized values for the memos themselves.
the `nnbench.run()` function takes care of filling in the memoized values for the memos themselves.
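
For orientation, a minimal sketch of the memo pattern used in this tutorial (the class name and returned value are illustrative, and the caching details of the real `Memo` base class are omitted; the tutorial's actual memo classes differ):

```python
from nnbench.types import Memo


class VocabSizeMemo(Memo[int]):
    """Illustrative memo wrapping an expensive-to-compute value."""

    def __call__(self) -> int:
        # a real memo would load a tokenizer or model here
        return 30522
```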

Because we implemented our memoized values as four different memo class types, this modularizes the benchmark input parameters -
we only need to reference memos when they are actually used. Considering the recall benchmarks:
4 changes: 2 additions & 2 deletions docs/tutorials/mnist.md
@@ -16,14 +16,14 @@ To properly structure our project, we avoid mixing training pipeline code and be

This definition is short and sweet, and contains a few important details (a sketch of the decorated functions follows this list):

* Both functions are given the `@nnbench.benchmark` decorator - this enables our runner to find and collect them before starting the benchmark run.
* Both functions are given the `@nnbench.benchmark` decorator - this allows us to find and collect them before starting the benchmark run.
* The `modelsize` benchmark is given a custom name (`"Model size (MB)"`), indicating that the resulting number is the combined size of the model weights in megabytes.
This is done for display purposes, to improve interpretability when reporting results.
* The `params` argument is the same in both benchmarks, both in name and type. This is important, since it ensures that both benchmarks will be run with the same model weights.
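
To make the bullet points above concrete, a hedged sketch of what such decorated definitions might look like (signatures and bodies are illustrative, and any function name other than `modelsize` is hypothetical; the tutorial's actual definitions sit just above this hunk):

```python
import nnbench


@nnbench.benchmark(name="Model size (MB)")
def modelsize(params) -> float:
    """Combined size of the model weights in megabytes."""
    ...


@nnbench.benchmark
def accuracy(params) -> float:
    """Test-set accuracy for the given model weights."""
    ...
```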

That's all - now we can shift over to our main pipeline code and see what is necessary to execute the benchmarks and visualize the results.

## Setting up a benchmark runner and parameters
## Setting up a benchmark run and parameters

After finishing the benchmark setup, we only need a few more lines to augment our pipeline.

4 changes: 2 additions & 2 deletions examples/bq/bq.py
@@ -14,8 +14,8 @@ def main():
autodetect=True, source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON
)

runner = nnbench.BenchmarkRunner()
res = runner.run("benchmarks.py", params={"a": 1, "b": 1}, context=(GitEnvironmentInfo(),))
benchmarks = nnbench.collect("benchmarks.py")
res = nnbench.run(benchmarks, params={"a": 1, "b": 1}, context=(GitEnvironmentInfo(),))

load_job = client.load_table_from_json(res.to_json(), table_id, job_config=job_config)
load_job.result()
4 changes: 2 additions & 2 deletions examples/huggingface/runner.py
@@ -2,9 +2,9 @@


def main() -> None:
runner = nnbench.BenchmarkRunner()
benchmarks = nnbench.collect("benchmark.py", tags=("per-class",))
reporter = nnbench.ConsoleReporter()
result = runner.run("benchmark.py", tags=("per-class",))
result = nnbench.run(benchmarks)
reporter.display(result)


4 changes: 2 additions & 2 deletions examples/mnist/mnist.py
@@ -216,10 +216,10 @@ def mnist_jax():
state, data = train(mnist)

# the nnbench portion.
runner = nnbench.BenchmarkRunner()
benchmarks = nnbench.collect(HERE)
reporter = nnbench.FileReporter()
params = MNISTTestParameters(params=state.params, data=data)
result = runner.run(HERE, params=params)
result = nnbench.run(benchmarks, params=params)
reporter.write(result, "result.json")


18 changes: 8 additions & 10 deletions examples/prefect/src/runner.py
@@ -31,10 +31,9 @@ async def write(
def run_metric_benchmarks(
model: base.BaseEstimator, X_test: np.ndarray, y_test: np.ndarray
) -> nnbench.types.BenchmarkRecord:
runner = nnbench.BenchmarkRunner()
results = runner.run(
os.path.join(dir_path, "benchmark.py"),
tags=("metric",),
benchmarks = nnbench.collect(os.path.join(dir_path, "benchmark.py"), tags=("metric",))
results = nnbench.run(
benchmarks,
params={"model": model, "X_test": X_test, "y_test": y_test},
)
return results
@@ -44,10 +43,9 @@ def run_metadata_benchmarks(
def run_metadata_benchmarks(
model: base.BaseEstimator, X: np.ndarray
) -> nnbench.types.BenchmarkRecord:
runner = nnbench.BenchmarkRunner()
result = runner.run(
os.path.join(dir_path, "benchmark.py"),
tags=("model-meta",),
benchmarks = nnbench.collect(os.path.join(dir_path, "benchmark.py"), tags=("model-meta",))
result = nnbench.run(
benchmarks,
params={"model": model, "X": X},
)
return result
@@ -73,7 +71,7 @@ async def train_and_benchmark(
metadata_results: types.BenchmarkRecord = run_metadata_benchmarks(model=model, X=X_test)

metadata_results.context.update(data_params)
metadata_results.context.update(context.PythonInfo())
metadata_results.context.update(context.PythonInfo()())

await reporter.write(
record=metadata_results, key="model-attributes", description="Model Attributes"
@@ -84,7 +82,7 @@ )
)

metric_results.context.update(data_params)
metric_results.context.update(context.PythonInfo())
metric_results.context.update(context.PythonInfo()())
await reporter.write(metric_results, key="model-performance", description="Model Performance")
return metadata_results, metric_results
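
The `context.PythonInfo()()` change above reflects that context providers are instances that must be called to produce their context dictionary; a minimal sketch of the distinction (the package name is illustrative):

```python
from nnbench.context import PythonInfo

provider = PythonInfo(packages=("prefect",))
# calling the provider instance yields the dictionary that gets merged into the record context
info = provider()
```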

2 changes: 1 addition & 1 deletion src/nnbench/__init__.py
@@ -2,7 +2,7 @@

from .core import benchmark, parametrize, product
from .reporter import BenchmarkReporter, ConsoleReporter, FileReporter
from .runner import BenchmarkRunner
from .runner import collect, run
from .types import Benchmark, BenchmarkRecord, Memo, Parameters

__version__ = "0.4.0"
8 changes: 4 additions & 4 deletions src/nnbench/cli.py
@@ -6,7 +6,7 @@
import sys
from typing import Any

from nnbench import BenchmarkRunner, ConsoleReporter, __version__
from nnbench import ConsoleReporter, __version__, collect, run
from nnbench.config import NNBenchConfig, parse_nnbench_config
from nnbench.reporter import FileReporter

@@ -205,9 +205,9 @@ def main() -> int:
else:
context[k] = v

record = BenchmarkRunner().run(
args.benchmarks,
tags=tuple(args.tags),
benchmarks = collect(args.benchmarks, tags=tuple(args.tags))
record = run(
benchmarks,
context=[lambda: context],
)

4 changes: 2 additions & 2 deletions src/nnbench/context.py
@@ -29,15 +29,15 @@ class PythonInfo:
Parameters
----------
packages: str
packages: Sequence[str]
Names of the requested packages under which they exist in the current environment.
For packages installed through ``pip``, this equals the PyPI package name.
"""

key = "python"

def __init__(self, packages: Sequence[str] = ()):
self.packages = packages
self.packages = tuple(packages)

def __call__(self) -> dict[str, Any]:
from importlib.metadata import PackageNotFoundError, version
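
For reference, a hedged usage sketch of the updated `PythonInfo` provider with the new collect/run API (the benchmark file, parameter values, and package name are illustrative):

```python
import nnbench
from nnbench.context import PythonInfo

benchmarks = nnbench.collect("benchmarks.py")
record = nnbench.run(
    benchmarks,
    params={"a": 1, "b": 1},
    # records the Python version plus the versions of the listed packages
    context=[PythonInfo(packages=("numpy",))],
)
```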