Skip to content

Commit

Permalink
Make arguments optional on ContextProviderDef
Browse files Browse the repository at this point in the history
This ensures that if the configured classpath is a function, the user is allowed
to not pass an `arguments` field, and omit the inline arguments table from the
TOML table.
  • Loading branch information
nicholasjng committed Dec 2, 2024
1 parent 4b0d1a7 commit c185b65
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
12 changes: 9 additions & 3 deletions src/nnbench/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any

from nnbench import BenchmarkRunner, ConsoleReporter, __version__
from nnbench.config import nnbenchConfig, parse_nnbench_config
from nnbench.config import NNBenchConfig, parse_nnbench_config
from nnbench.reporter import FileReporter

_VERSION = f"%(prog)s version {__version__}"
Expand Down Expand Up @@ -72,7 +72,7 @@ def _log_level(log_level: str) -> str:
_log_level.__name__ = "log level"


def construct_parser(config: nnbenchConfig) -> argparse.ArgumentParser:
def construct_parser(config: NNBenchConfig) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser("nnbench", formatter_class=CustomFormatter)
parser.add_argument("--version", action="version", version=_VERSION)
parser.add_argument(
Expand Down Expand Up @@ -185,7 +185,13 @@ def main() -> int:

# TODO: Catch import errors if the module does not exist
klass = getattr(importlib.import_module(modname), classname)
builtin_providers[p.name] = klass(**p.arguments)
if isinstance(klass, type):
# classes can be instantiated with arguments,
# while functions cannot.
builtin_providers[p.name] = klass(**p.arguments)
else:
builtin_providers[p.name] = klass

for val in args.context:
try:
k, v = val.split("=", 1)
Expand Down
24 changes: 13 additions & 11 deletions src/nnbench/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Utilities for parsing an nnbench config block out of a pyproject.toml file."""
"""Utilities for parsing a ``[tool.nnbench]`` config block out of a pyproject.toml file."""

import logging
import os
import sys
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

Expand All @@ -29,13 +29,15 @@ class ContextProviderDef:
"""Name under which the provider should be registered by nnbench."""
classpath: str
"""Full path to the class or callable returning the context dict."""
arguments: dict[str, Any]
"""Arguments needed to instantiate the context provider class,
given as key-value pairs in the table."""
arguments: dict[str, Any] = field(default_factory=dict)
"""
Arguments needed to instantiate the context provider class,
given as key-value pairs in the table.
If the class path points to a function, no arguments may be given."""


@dataclass(frozen=True)
class nnbenchConfig:
class NNBenchConfig:
log_level: str
"""Log level to use for the ``nnbench`` module root logger."""
context: list[ContextProviderDef]
Expand All @@ -56,7 +58,7 @@ def from_toml(cls, d: dict[str, Any]) -> Self:
----------
d: dict[str, Any]
Mapping containing the [tool.nnbench] block as obtained by
``tomllib.load``.
``tomllib.load()``.
Returns
-------
Expand Down Expand Up @@ -93,7 +95,7 @@ def locate_pyproject() -> os.PathLike[str]:
raise RuntimeError("could not locate pyproject.toml")


def parse_nnbench_config(pyproject_path: str | os.PathLike[str] | None = None) -> nnbenchConfig:
def parse_nnbench_config(pyproject_path: str | os.PathLike[str] | None = None) -> NNBenchConfig:
"""
Load an nnbench config from a given pyproject.toml file.
Expand All @@ -107,7 +109,7 @@ def parse_nnbench_config(pyproject_path: str | os.PathLike[str] | None = None) -
Returns
-------
nnbenchConfig
NNBenchConfig
The loaded config if found, or a default config.
"""
Expand All @@ -116,9 +118,9 @@ def parse_nnbench_config(pyproject_path: str | os.PathLike[str] | None = None) -
pyproject_path = locate_pyproject()
except RuntimeError:
# pyproject.toml cannot be found, so return an empty config.
return nnbenchConfig.empty()
return NNBenchConfig.empty()

with open(pyproject_path, "rb") as fp:
pyproject_cfg = tomllib.load(fp)
nnbench_cfg = pyproject_cfg.get("tool", {}).get("nnbench", {})
return nnbenchConfig.from_toml(nnbench_cfg)
return NNBenchConfig.from_toml(nnbench_cfg)

0 comments on commit c185b65

Please sign in to comment.