diff --git a/rdfproxy/checks/checkers.py b/rdfproxy/checks/checkers.py index a0d881d..b9ee080 100644 --- a/rdfproxy/checks/checkers.py +++ b/rdfproxy/checks/checkers.py @@ -1,6 +1,6 @@ """RDFProxy check runners.""" -from collections.abc import Callable, Iterable +from collections.abc import Callable from typing import Annotated, NoReturn, TypeVar from rdfproxy.checks.query_checks import ( @@ -8,27 +8,17 @@ check_select_query, check_solution_modifiers, ) +from toolz import compose_left T = TypeVar("T") +_TCheck = Callable[[T], T | NoReturn] -_TChecks = Iterable[Callable[[T], None | NoReturn]] - -def _check_factory(checks: _TChecks[T]) -> Callable[[T], T | NoReturn]: - """Produce an identity function that runs checks on its argument before returning.""" - - def _check(obj: T) -> T | NoReturn: - for check in checks: - check(obj) - return obj - - return _check +def compose_checker(*checkers: _TCheck) -> _TCheck: + return compose_left(*checkers) check_query: Annotated[ - Callable[[str], str | NoReturn], - "Run query checks and return the query unless an exception is raised.", -] = _check_factory( - checks=(check_parse_query, check_solution_modifiers, check_select_query) -) + _TCheck, "Run a series of checks on a query and return the query." +] = compose_checker(check_parse_query, check_select_query, check_solution_modifiers) diff --git a/rdfproxy/checks/query_checks.py b/rdfproxy/checks/query_checks.py index 2a65c5d..49003a0 100644 --- a/rdfproxy/checks/query_checks.py +++ b/rdfproxy/checks/query_checks.py @@ -1,6 +1,6 @@ """Query checks definitions.""" -from typing import NoReturn +from typing import NoReturn, TypeVar from rdflib.plugins.sparql.parser import parseQuery from rdfproxy.utils._exceptions import UnsupportedQueryException @@ -9,16 +9,21 @@ query_is_select_query, ) +TQuery = TypeVar("TQuery", bound=str) -def check_parse_query(query: str) -> None | NoReturn: + +def check_parse_query(query: TQuery) -> TQuery | NoReturn: parseQuery(query) + return query -def check_select_query(query: str) -> None | NoReturn: +def check_select_query(query: TQuery) -> TQuery | NoReturn: if not query_is_select_query(query): raise UnsupportedQueryException("Only SELECT queries are applicable.") + return query -def check_solution_modifiers(query: str) -> None | NoReturn: +def check_solution_modifiers(query: TQuery) -> TQuery | NoReturn: if query_has_solution_modifiers(query): raise UnsupportedQueryException("SPARQL solution modifieres are not supported.") + return query