From 00fffdcbb53c06093c018be3775a604b5bbe097e Mon Sep 17 00:00:00 2001 From: David Hagen Date: Wed, 20 Nov 2024 20:20:19 -0500 Subject: [PATCH] Run mypy on src (#115) Many fixes to the types --- docs/utility_functions.md | 4 +- examples/expressions.py | 44 +++--- examples/json.py | 3 +- examples/positioned.py | 22 +-- noxfile.py | 9 +- poetry.lock | 60 +++++++- pyproject.toml | 9 ++ src/parsita/metaclasses.py | 88 +++++++---- src/parsita/options.py | 5 +- src/parsita/parsers/_alternative.py | 49 +++++- src/parsita/parsers/_any.py | 16 +- src/parsita/parsers/_base.py | 165 +++++++++++++++------ src/parsita/parsers/_conversion.py | 15 +- src/parsita/parsers/_debug.py | 28 +++- src/parsita/parsers/_end_of_source.py | 16 +- src/parsita/parsers/_literal.py | 44 ++++-- src/parsita/parsers/_optional.py | 24 ++- src/parsita/parsers/_predicate.py | 8 +- src/parsita/parsers/_regex.py | 29 +++- src/parsita/parsers/_repeated.py | 59 ++++++-- src/parsita/parsers/_repeated_seperated.py | 102 ++++++++----- src/parsita/parsers/_sequential.py | 22 +-- src/parsita/parsers/_success.py | 25 ++-- src/parsita/parsers/_until.py | 18 ++- src/parsita/py.typed | 0 src/parsita/state/__init__.py | 2 +- src/parsita/state/_exceptions.py | 4 +- src/parsita/state/_reader.py | 43 ++++-- src/parsita/state/_result.py | 13 +- src/parsita/state/_state.py | 17 ++- src/parsita/util.py | 26 +++- tests/test_basic.py | 24 +-- tests/test_metaclass_scopes.py | 26 ++-- tests/test_state.py | 12 +- tests/test_util.py | 8 +- 35 files changed, 709 insertions(+), 330 deletions(-) create mode 100644 src/parsita/py.typed diff --git a/docs/utility_functions.md b/docs/utility_functions.md index b6b99bb..608dd5a 100644 --- a/docs/utility_functions.md +++ b/docs/utility_functions.md @@ -24,7 +24,7 @@ assert BooleanParsers.boolean.parse('false') == Success(False) ## `splat(function)`: convert a function of many arguments to take only one list argument -The function `splat(function: Callable[Tuple[*B], A]) -> Callable[Tuple[Tuple[*B]], A]` has a complicated type signature, but does a simple thing. It takes a single function that takes multiple arguments and converts it to a function that takes only one argument, which is a list of all original arguments. It is particularly useful for passing a list of results from a sequential parser `&` to a function that takes each element as an separate argument. By applying `splat` to the function, it now takes the single list that is returned by the sequential parser. +The function `splat(function: Callable[tuple[*B], A]) -> Callable[tuple[tuple[*B]], A]` has a complicated type signature, but does a simple thing. It takes a single function that takes multiple arguments and converts it to a function that takes only one argument, which is a list of all original arguments. It is particularly useful for passing a list of results from a sequential parser `&` to a function that takes each element as an separate argument. By applying `splat` to the function, it now takes the single list that is returned by the sequential parser. ```python from collections import namedtuple @@ -44,7 +44,7 @@ assert UrlParsers.url.parse('https://drhagen.com:443/blog/') == \ ## `unsplat(function)`: convert a function of one list argument to take many arguments -The function `unsplat(function: Callable[Tuple[Tuple[*B]], A]) -> Callable[Tuple[*B], A]` does the opposite of `splat`. 
It takes a single function that takes a single argument that is a list and converts it to a function that takes multiple arguments, each of which was an element of the original list. It is not very useful for writing parsers because the conversion parser always calls its converter function with a single argument, but is included here to complement `splat`. +The function `unsplat(function: Callable[tuple[tuple[*B]], A]) -> Callable[tuple[*B], A]` does the opposite of `splat`. It takes a single function that takes a single argument that is a list and converts it to a function that takes multiple arguments, each of which was an element of the original list. It is not very useful for writing parsers because the conversion parser always calls its converter function with a single argument, but is included here to complement `splat`. ```python from parsita.util import splat, unsplat diff --git a/examples/expressions.py b/examples/expressions.py index d1cee16..8ac9f29 100644 --- a/examples/expressions.py +++ b/examples/expressions.py @@ -1,6 +1,30 @@ +from typing import Literal, Sequence + from parsita import ParserContext, lit, opt, reg, rep +def make_term(args: tuple[float, Sequence[tuple[Literal["*", "/"], float]]]) -> float: + factor1, factors = args + result = factor1 + for op, factor in factors: + if op == "*": + result = result * factor + else: + result = result / factor + return result + + +def make_expr(args: tuple[float, Sequence[tuple[Literal["+", "-"], float]]]) -> float: + term1, terms = args + result = term1 + for op, term2 in terms: + if op == "+": + result = result + term2 + else: + result = result - term2 + return result + + class ExpressionParsers(ParserContext, whitespace=r"[ ]*"): number = reg(r"[+-]?\d+(\.\d+)?(e[+-]?\d+)?") > float @@ -8,28 +32,8 @@ class ExpressionParsers(ParserContext, whitespace=r"[ ]*"): factor = base & opt("^" >> base) > (lambda x: x[0] ** x[1][0] if x[1] else x[0]) - def make_term(args): - factor1, factors = args - result = factor1 - for op, factor in factors: - if op == "*": - result = result * factor - else: - result = result / factor - return result - term = factor & rep(lit("*", "/") & factor) > make_term - def make_expr(args): - term1, terms = args - result = term1 - for op, term2 in terms: - if op == "+": - result = result + term2 - else: - result = result - term2 - return result - expr = term & rep(lit("+", "-") & term) > make_expr diff --git a/examples/json.py b/examples/json.py index e484b06..4ccf140 100644 --- a/examples/json.py +++ b/examples/json.py @@ -13,7 +13,7 @@ class JsonStringParsers(ParserContext): line_feed = lit(r"\n") > constant("\n") carriage_return = lit(r"\r") > constant("\r") tab = lit(r"\t") > constant("\t") - uni = reg(r"\\u([0-9a-fA-F]{4})") > (lambda x: chr(int(x.group(1), 16))) + uni = reg(r"\\u[0-9a-fA-F]{4}") > (lambda x: chr(int(x[2:], 16))) escaped = ( quote @@ -61,6 +61,7 @@ class JsonParsers(ParserContext, whitespace=r"[ \t\n\r]*"): "width" : 4.0 }""", '{"text" : ""}', + r'"\u2260"', ] for string in strings: diff --git a/examples/positioned.py b/examples/positioned.py index 0477506..e924912 100644 --- a/examples/positioned.py +++ b/examples/positioned.py @@ -8,7 +8,7 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Generic +from typing import Generic, Optional from parsita import Parser, ParserContext, Reader, reg from parsita.state import Continue, Input, Output, State @@ -45,7 +45,7 @@ def __init__(self, parser: Parser[Input, PositionAware[Output]]): super().__init__() 
self.parser = parser - def _consume(self, state: State, reader: Reader[Input]): + def _consume(self, state: State, reader: Reader[Input]) -> Optional[Continue[Input, Output]]: start = reader.position status = self.parser.consume(state, reader) @@ -55,11 +55,11 @@ def _consume(self, state: State, reader: Reader[Input]): else: return status - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + f"positioned({self.parser.name_or_repr()})" -def positioned(parser: Parser[Input, PositionAware[Output]]): +def positioned(parser: Parser[Input, PositionAware[Output]]) -> PositionedParser[Input, Output]: """Set the position on a PositionAware value. This parser matches ``parser`` and, if successful, calls ``set_position`` @@ -75,18 +75,18 @@ def positioned(parser: Parser[Input, PositionAware[Output]]): # Everything below here is an example use case @dataclass -class UnfinishedVariable(PositionAware): +class Variable: name: str - - def set_position(self, start: int, length: int): - return Variable(self.name, start, length) + start: int + length: int @dataclass -class Variable: +class UnfinishedVariable(PositionAware[Variable]): name: str - start: int - length: int + + def set_position(self, start: int, length: int) -> Variable: + return Variable(self.name, start, length) @dataclass diff --git a/noxfile.py b/noxfile.py index 4780123..ed5fed8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -3,7 +3,7 @@ from nox import options, parametrize from nox_poetry import Session, session -options.sessions = ["test", "coverage", "lint"] +options.sessions = ["test", "coverage", "lint", "type_check"] @session(python=["3.9", "3.10", "3.11", "3.12", "3.13"]) @@ -27,6 +27,11 @@ def lint(s: Session, command: list[str]): @session(venv_backend="none") -def format(s: Session) -> None: +def type_check(s: Session): + s.run("mypy", "src") + + +@session(venv_backend="none") +def format(s: Session): s.run("ruff", "check", ".", "--select", "I", "--fix") s.run("ruff", "format", ".") diff --git a/poetry.lock b/poetry.lock index fa2f787..2f7c965 100644 --- a/poetry.lock +++ b/poetry.lock @@ -582,6 +582,64 @@ files = [ {file = "mkdocs_material_extensions-1.3.1.tar.gz", hash = "sha256:10c9511cea88f568257f960358a467d12b970e1f7b2c0e5fb2bb48cab1928443"}, ] +[[package]] +name = "mypy" +version = "1.11.2" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.11.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d42a6dd818ffce7be66cce644f1dff482f1d97c53ca70908dff0b9ddc120b77a"}, + {file = "mypy-1.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:801780c56d1cdb896eacd5619a83e427ce436d86a3bdf9112527f24a66618fef"}, + {file = "mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41ea707d036a5307ac674ea172875f40c9d55c5394f888b168033177fce47383"}, + {file = "mypy-1.11.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e658bd2d20565ea86da7d91331b0eed6d2eee22dc031579e6297f3e12c758c8"}, + {file = "mypy-1.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:478db5f5036817fe45adb7332d927daa62417159d49783041338921dcf646fc7"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75746e06d5fa1e91bfd5432448d00d34593b52e7e91a187d981d08d1f33d4385"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a976775ab2256aadc6add633d44f100a2517d2388906ec4f13231fafbb0eccca"}, + {file = 
"mypy-1.11.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cd953f221ac1379050a8a646585a29574488974f79d8082cedef62744f0a0104"}, + {file = "mypy-1.11.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:57555a7715c0a34421013144a33d280e73c08df70f3a18a552938587ce9274f4"}, + {file = "mypy-1.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:36383a4fcbad95f2657642a07ba22ff797de26277158f1cc7bd234821468b1b6"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e8960dbbbf36906c5c0b7f4fbf2f0c7ffb20f4898e6a879fcf56a41a08b0d318"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06d26c277962f3fb50e13044674aa10553981ae514288cb7d0a738f495550b36"}, + {file = "mypy-1.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e7184632d89d677973a14d00ae4d03214c8bc301ceefcdaf5c474866814c987"}, + {file = "mypy-1.11.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a66169b92452f72117e2da3a576087025449018afc2d8e9bfe5ffab865709ca"}, + {file = "mypy-1.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:969ea3ef09617aff826885a22ece0ddef69d95852cdad2f60c8bb06bf1f71f70"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:37c7fa6121c1cdfcaac97ce3d3b5588e847aa79b580c1e922bb5d5d2902df19b"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a8a53bc3ffbd161b5b2a4fff2f0f1e23a33b0168f1c0778ec70e1a3d66deb86"}, + {file = "mypy-1.11.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ff93107f01968ed834f4256bc1fc4475e2fecf6c661260066a985b52741ddce"}, + {file = "mypy-1.11.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:edb91dded4df17eae4537668b23f0ff6baf3707683734b6a818d5b9d0c0c31a1"}, + {file = "mypy-1.11.2-cp38-cp38-win_amd64.whl", hash = "sha256:ee23de8530d99b6db0573c4ef4bd8f39a2a6f9b60655bf7a1357e585a3486f2b"}, + {file = "mypy-1.11.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:801ca29f43d5acce85f8e999b1e431fb479cb02d0e11deb7d2abb56bdaf24fd6"}, + {file = "mypy-1.11.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:af8d155170fcf87a2afb55b35dc1a0ac21df4431e7d96717621962e4b9192e70"}, + {file = "mypy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7821776e5c4286b6a13138cc935e2e9b6fde05e081bdebf5cdb2bb97c9df81d"}, + {file = "mypy-1.11.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:539c570477a96a4e6fb718b8d5c3e0c0eba1f485df13f86d2970c91f0673148d"}, + {file = "mypy-1.11.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f14cd3d386ac4d05c5a39a51b84387403dadbd936e17cb35882134d4f8f0d24"}, + {file = "mypy-1.11.2-py3-none-any.whl", hash = "sha256:b499bc07dbdcd3de92b0a8b29fdf592c111276f6a12fe29c30f6c417dd546d12"}, + {file = "mypy-1.11.2.tar.gz", hash = "sha256:7f9993ad3e0ffdc95c2a14b66dee63729f021968bff8ad911867579c65d13a79"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." 
+optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "nox" version = "2024.10.9" @@ -1178,4 +1236,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "9ba16ffbf8c1b1fd1fe55e741ea91811415528917ca208c98fb6534da5ba4222" +content-hash = "86615f580388a66f3fd4eea2fb687b1e9073bf0596a6d248e6c94622ab9475df" diff --git a/pyproject.toml b/pyproject.toml index 51afd77..6a06342 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,9 @@ pytest-timeout = "*" # Lint ruff = "^0.6" +# Type checking +mypy = "^1" + # Docs mkdocs-material = "^9" @@ -49,6 +52,7 @@ exclude_lines = [ "pragma: no cover", "raise NotImplementedError", "if TYPE_CHECKING", + "@overload", ] [tool.coverage.paths] @@ -85,6 +89,11 @@ extend-ignore = ["F821", "N805"] "__init__.py" = ["F401"] +[tool.mypy] +strict = true +implicit_reexport = true + + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/src/parsita/metaclasses.py b/src/parsita/metaclasses.py index 0f2e4e3..c3d55ff 100644 --- a/src/parsita/metaclasses.py +++ b/src/parsita/metaclasses.py @@ -1,25 +1,38 @@ +from __future__ import annotations + __all__ = ["ForwardDeclaration", "fwd", "ParserContext"] import builtins import inspect import re +from dataclasses import dataclass from re import Pattern -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Generic, NoReturn, Optional, Union, no_type_check from . import options from .parsers import LiteralParser, Parser, RegexParser -from .state import Input +from .state import Continue, Input, Output, Reader, State + +missing: Any = object() + -missing = object() +@dataclass(frozen=True) +class Options: + whitespace: Optional[Parser[Any, object]] = None -class ParsersDict(dict): - def __init__(self, old_options: dict): +class ParsersDict(dict[str, Any]): + def __init__(self, old_options: Options): super().__init__() - self.old_options = old_options # Holds state of options at start of definition - self.forward_declarations = {} # Stores forward declarations as they are discovered - def __missing__(self, key): + # Holds state of options at start of definition + self.old_options = old_options + + # Stores forward declarations as they are discovered + self.forward_declarations: dict[str, ForwardDeclaration[Any, Any]] = {} + + @no_type_check # mypy cannot handle all the frame inspection + def __missing__(self, key: str) -> ForwardDeclaration[Any, Any]: frame = inspect.currentframe() # Should be the frame of __missing__ while frame.f_code.co_name != "__missing__": # pragma: no cover # But sometimes debuggers add frames on top of the stack; @@ -43,7 +56,7 @@ def __missing__(self, key): self.forward_declarations[key] = new_forward_declaration return new_forward_declaration - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: if isinstance(value, Parser): # Protects against accidental concatenation of sequential parsers value.protected = True @@ -54,21 +67,28 @@ def __setitem__(self, key, value): super().__setitem__(key, value) -class ForwardDeclaration(Parser): - def __init__(self): - self._definition = None +class ForwardDeclaration(Generic[Input, Output], Parser[Input, Output]): + def 
__init__(self) -> None: + self._definition: Optional[Parser[Input, Output]] = None - def __getattribute__(self, member): + def __getattribute__(self, member: str) -> Any: if member != "_definition" and self._definition is not None: return getattr(self._definition, member) else: return object.__getattribute__(self, member) - def define(self, parser: Parser) -> None: + if TYPE_CHECKING: + # Type checkers don't know that `_consume` is implemented in `__getattribute__` + + def _consume( + self, state: State, reader: Reader[Input] + ) -> Optional[Continue[Input, Output]]: ... + + def define(self, parser: Parser[Input, Output]) -> None: self._definition = parser -def fwd() -> ForwardDeclaration: +def fwd() -> ForwardDeclaration[Input, Output]: """Manually create a forward declaration. Normally, forward declarations are created automatically by the contexts. @@ -79,16 +99,20 @@ def fwd() -> ForwardDeclaration: class ParserContextMeta(type): - default_whitespace: Union[Parser[Input, Any], Pattern, str, None] = None + default_whitespace: Union[Parser[Any, object], Pattern[str], str, None] = None @classmethod def __prepare__( mcs, # noqa: N804 - name, - bases, + name: str, + bases: tuple[type, ...], + /, *, - whitespace: Union[Parser[Input, Any], Pattern, str, None] = missing, - ): + whitespace: Union[Parser[Any, object], Pattern[str], str, None] = missing, + **kwargs: Any, + ) -> ParsersDict: + super().__prepare__(name, bases, **kwargs) + if whitespace is missing: whitespace = mcs.default_whitespace @@ -98,15 +122,13 @@ def __prepare__( if isinstance(whitespace, Pattern): whitespace = RegexParser(whitespace) - old_options = { - "whitespace": options.whitespace, - } + old_options = Options(whitespace=options.whitespace) # Store whitespace in global location options.whitespace = whitespace return ParsersDict(old_options) - def __init__(cls, name, bases, dct, **_): + def __init__(cls, name: str, bases: tuple[type, ...], dct: ParsersDict, /, **_: Any) -> None: old_options = dct.old_options super().__init__(name, bases, dct) @@ -119,15 +141,21 @@ def __init__(cls, name, bases, dct, **_): forward_declaration._definition = obj # Reset global variables - for key, value in old_options.items(): - setattr(options, key, value) - - def __new__(mcs, name, bases, dct, **_): # noqa: N804 + options.whitespace = old_options.whitespace + + def __new__( + mcs: type[ParserContextMeta], # noqa: N804 + name: str, + bases: tuple[type, ...], + dct: ParsersDict, + /, + whitespace: Union[Parser[Any, Any], Pattern[str], str, None] = missing, + ) -> ParserContextMeta: return super().__new__(mcs, name, bases, dct) - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: object, **kwargs: object) -> NoReturn: raise TypeError( - "Parsers cannot be instantiated. They use class bodies purely as contexts for " + "ParserContexts cannot be instantiated. They use class bodies purely as contexts for " "managing defaults and allowing forward declarations. Access the individual parsers " "as static attributes." 
) diff --git a/src/parsita/options.py b/src/parsita/options.py index caeb874..809ffce 100644 --- a/src/parsita/options.py +++ b/src/parsita/options.py @@ -1,9 +1,8 @@ __all__ = ["whitespace"] -from typing import Any +from typing import Any, Optional from .parsers import Parser -from .state import Input # Global mutable state -whitespace: Parser[Input, Any] = None +whitespace: Optional[Parser[Any, Any]] = None diff --git a/src/parsita/parsers/_alternative.py b/src/parsita/parsers/_alternative.py index 879d96c..e816d77 100644 --- a/src/parsita/parsers/_alternative.py +++ b/src/parsita/parsers/_alternative.py @@ -1,6 +1,6 @@ __all__ = ["FirstAlternativeParser", "first", "LongestAlternativeParser", "longest"] -from typing import Generic, Optional, Sequence, Union +from typing import Generic, Optional, Sequence, Union, overload from ..state import Continue, Input, Output, Reader, State from ._base import Parser, wrap_literal @@ -11,7 +11,7 @@ def __init__(self, parser: Parser[Input, Output], *parsers: Parser[Input, Output super().__init__() self.parsers = (parser, *parsers) - def _consume(self, state: State[Input], reader: Reader[Input]): + def _consume(self, state: State, reader: Reader[Input]) -> Optional[Continue[Input, Output]]: for parser in self.parsers: status = parser.consume(state, reader) if isinstance(status, Continue): @@ -19,7 +19,7 @@ def _consume(self, state: State[Input], reader: Reader[Input]): return None - def __repr__(self): + def __repr__(self) -> str: names = [] for parser in self.parsers: names.append(parser.name_or_repr()) @@ -27,10 +27,27 @@ def __repr__(self): return self.name_or_nothing() + f"first({', '.join(names)})" +@overload def first( parser: Union[Parser[Input, Output], Sequence[Input]], *parsers: Union[Parser[Input, Output], Sequence[Input]], -) -> FirstAlternativeParser[Input, Output]: +) -> FirstAlternativeParser[Input, Sequence[Input]]: + # This signature is not quite right because Python has no higher-kinded + # types to express that Output must be a subtype of Sequence[Input]. + ... + + +@overload +def first( + parser: Parser[Input, Output], + *parsers: Parser[Input, Output], +) -> FirstAlternativeParser[Input, Output]: ... + + +def first( + parser: Union[Parser[Input, Output], Sequence[Input]], + *parsers: Union[Parser[Input, Output], Sequence[Input]], +) -> FirstAlternativeParser[Input, Union[Output, Sequence[Input]]]: """Match the first of several alternative parsers. A ``AlternativeParser`` attempts to match each supplied parser. If a parser @@ -54,8 +71,8 @@ def __init__(self, parser: Parser[Input, Output], *parsers: Parser[Input, Output super().__init__() self.parsers = (parser, *parsers) - def _consume(self, state: State[Input], reader: Reader[Input]): - longest_success: Optional[Continue] = None + def _consume(self, state: State, reader: Reader[Input]) -> Optional[Continue[Input, Output]]: + longest_success: Optional[Continue[Input, Output]] = None for parser in self.parsers: status = parser.consume(state, reader) if isinstance(status, Continue): @@ -67,7 +84,7 @@ def _consume(self, state: State[Input], reader: Reader[Input]): return longest_success - def __repr__(self): + def __repr__(self) -> str: names = [] for parser in self.parsers: names.append(parser.name_or_repr()) @@ -75,10 +92,26 @@ def __repr__(self): return self.name_or_nothing() + " | ".join(names) +# This signature is not quite right because Python has no higher-kinded +# types to express that Output must be a subtype of Sequence[Input]. 
+@overload +def longest( + parser: Union[Parser[Input, Output], Sequence[Input]], + *parsers: Union[Parser[Input, Output], Sequence[Input]], +) -> LongestAlternativeParser[Input, Sequence[Input]]: ... + + +@overload +def longest( + parser: Parser[Input, Output], + *parsers: Parser[Input, Output], +) -> LongestAlternativeParser[Input, Output]: ... + + def longest( parser: Union[Parser[Input, Output], Sequence[Input]], *parsers: Union[Parser[Input, Output], Sequence[Input]], -) -> LongestAlternativeParser[Input, Output]: +) -> LongestAlternativeParser[Input, Union[Output, Sequence[Input]]]: """Match the longest of several alternative parsers. A ``LongestAlternativeParser`` attempts to match all supplied parsers. If diff --git a/src/parsita/parsers/_any.py b/src/parsita/parsers/_any.py index c6017b6..169635b 100644 --- a/src/parsita/parsers/_any.py +++ b/src/parsita/parsers/_any.py @@ -1,12 +1,12 @@ __all__ = ["AnyParser", "any1"] -from typing import Generic, Optional +from typing import Any, Generic, Optional -from ..state import Continue, Input, Reader, State +from ..state import Continue, Element, Reader, State from ._base import Parser -class AnyParser(Generic[Input], Parser[Input, Input]): +class AnyParser(Generic[Element], Parser[Element, Element]): """Match any single element. This parser matches any single element, returning it. This is useful when it @@ -15,20 +15,20 @@ class AnyParser(Generic[Input], Parser[Input, Input]): the end of the stream. """ - def __init__(self): + def __init__(self) -> None: super().__init__() def _consume( - self, state: State[Input], reader: Reader[Input] - ) -> Optional[Continue[Input, Input]]: + self, state: State, reader: Reader[Element] + ) -> Optional[Continue[Element, Element]]: if reader.finished: state.register_failure("anything", reader) return None else: return Continue(reader.rest, reader.first) - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + "any1" -any1 = AnyParser() +any1: AnyParser[Any] = AnyParser() diff --git a/src/parsita/parsers/_base.py b/src/parsita/parsers/_base.py index ec94c4f..d8efce7 100644 --- a/src/parsita/parsers/_base.py +++ b/src/parsita/parsers/_base.py @@ -3,7 +3,7 @@ __all__ = ["Parser", "wrap_literal"] from abc import abstractmethod -from typing import Any, Generic, Optional, Sequence, Union +from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, Union, overload from .. import options from ..state import ( @@ -20,11 +20,24 @@ Success, ) +OtherOutput = TypeVar("OtherOutput") + # Singleton indicating that no result is yet in the memo -missing = object() +# Use Ellipsis instead of object() to avoid mypy errors +missing = ... + + +@overload +def wrap_literal(obj: Sequence[Input]) -> Parser[Input, Sequence[Input]]: ... + + +@overload +def wrap_literal(obj: Parser[Input, Output]) -> Parser[Input, Output]: ... -def wrap_literal(obj: Any) -> Parser: +def wrap_literal( + obj: Union[Parser[Input, Output], Sequence[Input]], +) -> Union[Parser[Input, Output], Parser[Input, Sequence[Input]]]: from ._literal import LiteralParser if isinstance(obj, Parser): @@ -65,9 +78,7 @@ class Parser(Generic[Input, Output]): name. """ - def consume( - self, state: State[Input], reader: Reader[Input] - ) -> Optional[Continue[Input, Output]]: + def consume(self, state: State, reader: Reader[Input]) -> Optional[Continue[Input, Output]]: """Match this parser at the given location. This is a concrete wrapper around ``consume``. 
This method implements @@ -107,9 +118,7 @@ def consume( return result @abstractmethod - def _consume( - self, state: State[Input], reader: Reader[Input] - ) -> Optional[Continue[Input, Output]]: + def _consume(self, state: State, reader: Reader[Input]) -> Optional[Continue[Input, Output]]: """Abstract method for matching this parser at the given location. This is the central method of every parser combinator. @@ -124,7 +133,7 @@ def _consume( """ raise NotImplementedError() - def parse(self, source: Union[Sequence[Input], Reader]) -> Result[Output]: + def parse(self, source: Union[Sequence[Input], Reader[Input]]) -> Result[Output]: """Completely parse a source. Args: @@ -134,13 +143,13 @@ def parse(self, source: Union[Sequence[Input], Reader]) -> Result[Output]: If the parser succeeded in matching and consumed the entire output, the value from ``Continue`` is copied to make a ``Success``. If the parser failed in matching, the expected patterns at the farthest - point in the source are used to construct a ``ParseError`, which is + point in the source are used to construct a ``ParseError``, which is then used to contruct a ``Failure``. If the parser succeeded but the source was not completely consumed, it returns a ``Failure`` with a - ``ParseError` indicating this. + ``ParseError`` indicating this. If a ``Reader`` is passed in, it is used directly. Otherwise, the source - is converted to an appropriate ``Reader``. If the source is ``str`, a + is converted to an appropriate ``Reader``. If the source is ``str``, a ``StringReader`` is used. Otherwise, a ``SequenceReader`` is used. """ from ._end_of_source import eof @@ -148,11 +157,11 @@ def parse(self, source: Union[Sequence[Input], Reader]) -> Result[Output]: if isinstance(source, Reader): reader = source elif isinstance(source, str): - reader = StringReader(source, 0) + reader = StringReader(source, 0) # type: ignore else: reader = SequenceReader(source) - state: State[Input] = State() + state: State = State() status = (self << eof).consume(state, reader) @@ -166,7 +175,9 @@ def parse(self, source: Union[Sequence[Input], Reader]) -> Result[Output]: used.add(expected) unique_expected.append(expected) - return Failure(ParseError(state.farthest, unique_expected)) + # mypy does not understand that state.farthest cannot be None when there is a failure + parse_error = ParseError(state.farthest, unique_expected) # type: ignore + return Failure(parse_error) name: Optional[str] = None @@ -178,70 +189,128 @@ def name_or_repr(self) -> str: else: return self.name - def name_or_nothing(self) -> Optional[str]: + def name_or_nothing(self) -> str: if self.name is None: return "" else: return self.name + " = " - def __or__(self, other) -> Parser: + @overload + def __or__(self, other: Sequence[Input]) -> Parser[Input, Union[Output, Sequence[Input]]]: ... + + @overload + def __or__( + self, other: Parser[Input, OtherOutput] + ) -> Parser[Input, Union[Output, OtherOutput]]: ... 
+ + def __or__( + self, other: Union[Sequence[Input], Parser[Input, OtherOutput]] + ) -> Parser[Input, object]: from ._alternative import LongestAlternativeParser - other = wrap_literal(other) - parsers: list[Parser] = [] + narrowed_other = wrap_literal(other) + parsers: list[Parser[Input, object]] = [] if isinstance(self, LongestAlternativeParser) and not self.protected: parsers.extend(self.parsers) else: parsers.append(self) - if isinstance(other, LongestAlternativeParser) and not other.protected: - parsers.extend(other.parsers) + if isinstance(narrowed_other, LongestAlternativeParser) and not narrowed_other.protected: + parsers.extend(narrowed_other.parsers) else: - parsers.append(other) + parsers.append(narrowed_other) return LongestAlternativeParser(*parsers) - def __ror__(self, other) -> Parser: - other = wrap_literal(other) - return other.__or__(self) + @overload + def __ror__(self, other: Sequence[Input]) -> Parser[Input, Union[Sequence[Input], Output]]: ... + + @overload + def __ror__( + self, other: Parser[Input, OtherOutput] + ) -> Parser[Input, Union[OtherOutput, Output]]: ... + + def __ror__( + self, other: Union[Sequence[Input], Parser[Input, OtherOutput]] + ) -> Parser[Input, object]: + narrowed_other = wrap_literal(other) + return narrowed_other.__or__(self) - def __and__(self, other) -> Parser: + @overload + def __and__(self, other: Sequence[Input]) -> Parser[Input, Sequence[Any]]: ... + + @overload + def __and__(self, other: Parser[Input, OtherOutput]) -> Parser[Input, Sequence[Any]]: ... + + def __and__( + self, other: Union[Sequence[Input], Parser[Input, OtherOutput]] + ) -> Parser[Input, Sequence[Any]]: from ._sequential import SequentialParser - other = wrap_literal(other) - if isinstance(self, SequentialParser) and not self.protected: - return SequentialParser(*self.parsers, other) + narrowed_other = wrap_literal(other) + if isinstance(self, SequentialParser) and not self.protected: # type: ignore + return SequentialParser(*self.parsers, narrowed_other) # type: ignore else: - return SequentialParser(self, other) + return SequentialParser(self, narrowed_other) + + @overload + def __rand__(self, other: Sequence[Input]) -> Parser[Input, Sequence[Any]]: ... + + @overload + def __rand__(self, other: Parser[Input, OtherOutput]) -> Parser[Input, Sequence[Any]]: ... - def __rand__(self, other) -> Parser: - other = wrap_literal(other) - return other.__and__(self) + def __rand__( + self, other: Union[Sequence[Input], Parser[Input, OtherOutput]] + ) -> Parser[Input, Sequence[Any]]: + narrowed_other = wrap_literal(other) + return narrowed_other.__and__(self) - def __rshift__(self, other) -> Parser: + @overload + def __rshift__(self, other: Sequence[Input]) -> Parser[Input, Sequence[Input]]: ... + + @overload + def __rshift__(self, other: Parser[Input, OtherOutput]) -> Parser[Input, OtherOutput]: ... 
+ + def __rshift__( + self, other: Union[Sequence[Input], Parser[Input, OtherOutput]] + ) -> Parser[Input, object]: from ._sequential import DiscardLeftParser - other = wrap_literal(other) - return DiscardLeftParser(self, other) + narrowed_other = wrap_literal(other) + return DiscardLeftParser(self, narrowed_other) - def __rrshift__(self, other) -> Parser: - other = wrap_literal(other) - return other.__rshift__(self) + def __rrshift__( + self, other: Union[Sequence[Input], Parser[Input, OtherOutput]] + ) -> Parser[Input, Output]: + narrowed_other = wrap_literal(other) + return narrowed_other.__rshift__(self) - def __lshift__(self, other) -> Parser: + def __lshift__( + self, other: Union[Sequence[Input], Parser[Input, OtherOutput]] + ) -> Parser[Input, Output]: from ._sequential import DiscardRightParser - other = wrap_literal(other) - return DiscardRightParser(self, other) + narrowed_other = wrap_literal(other) + return DiscardRightParser(self, narrowed_other) + + @overload + def __rlshift__(self, other: Sequence[Input]) -> Parser[Input, Sequence[Input]]: ... + + @overload + def __rlshift__(self, other: Parser[Input, OtherOutput]) -> Parser[Input, OtherOutput]: ... - def __rlshift__(self, other) -> Parser: - other = wrap_literal(other) - return other.__lshift__(self) + def __rlshift__( + self, other: Union[Sequence[Input], Parser[Input, OtherOutput]] + ) -> Parser[Input, object]: + narrowed_other = wrap_literal(other) + return narrowed_other.__lshift__(self) - def __gt__(self, other) -> Parser: + def __gt__(self, other: Callable[[Output], OtherOutput]) -> Parser[Input, OtherOutput]: from ._conversion import ConversionParser return ConversionParser(self, other) - def __ge__(self, other) -> Parser: + def __ge__( + self, other: Callable[[Output], Parser[Input, OtherOutput]] + ) -> Parser[Input, OtherOutput]: from ._conversion import TransformationParser return TransformationParser(self, other) diff --git a/src/parsita/parsers/_conversion.py b/src/parsita/parsers/_conversion.py index d3ee8b2..b65c3b1 100644 --- a/src/parsita/parsers/_conversion.py +++ b/src/parsita/parsers/_conversion.py @@ -2,10 +2,11 @@ from typing import Callable, Generic, Optional, TypeVar -from ..state import Continue, Input, Output, Reader, State +from ..state import Continue, Input, Reader, State from ._base import Parser -Convert = TypeVar("Convert") +Output = TypeVar("Output") +Convert = TypeVar("Convert", covariant=True) class ConversionParser(Generic[Input, Output, Convert], Parser[Input, Convert]): @@ -14,9 +15,7 @@ def __init__(self, parser: Parser[Input, Output], converter: Callable[[Output], self.parser = parser self.converter = converter - def _consume( - self, state: State[Input], reader: Reader[Input] - ) -> Optional[Continue[Input, Convert]]: + def _consume(self, state: State, reader: Reader[Input]) -> Optional[Continue[Input, Convert]]: status = self.parser.consume(state, reader) if isinstance(status, Continue): @@ -24,7 +23,7 @@ def _consume( else: return None - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + f"{self.parser!r} > {self.converter.__name__}" @@ -38,9 +37,7 @@ def __init__( self.parser = parser self.transformer = transformer - def _consume( - self, state: State[Input], reader: Reader[Input] - ) -> Optional[Continue[Input, Convert]]: + def _consume(self, state: State, reader: Reader[Input]) -> Optional[Continue[Input, Convert]]: status = self.parser.consume(state, reader) if isinstance(status, Continue): diff --git a/src/parsita/parsers/_debug.py 
b/src/parsita/parsers/_debug.py index e82094b..4af1148 100644 --- a/src/parsita/parsers/_debug.py +++ b/src/parsita/parsers/_debug.py @@ -1,8 +1,8 @@ __all__ = ["DebugParser", "debug"] -from typing import Callable, Generic, Optional +from typing import Any, Callable, Generic, Optional, Sequence, Union, overload -from ..state import Input, Output, Reader, State +from ..state import Continue, Input, Output, Reader, State from ._base import Parser, wrap_literal @@ -19,7 +19,7 @@ def __init__( self.callback = callback self._parser_string = repr(parser) - def _consume(self, state: State[Input], reader: Reader[Input]): + def _consume(self, state: State, reader: Reader[Input]) -> Optional[Continue[Input, Output]]: if self.verbose: print(f"""Evaluating token {reader.next_token()} using parser {self._parser_string}""") @@ -33,16 +33,34 @@ def _consume(self, state: State[Input], reader: Reader[Input]): return result - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + f"debug({self.parser.name_or_repr()})" +@overload +def debug( + parser: Sequence[Input], + *, + verbose: bool = False, + callback: Optional[Callable[[Parser[Input, Sequence[Input]], Reader[Input]], None]] = None, +) -> DebugParser[Input, Sequence[Input]]: ... + + +@overload def debug( parser: Parser[Input, Output], *, verbose: bool = False, callback: Optional[Callable[[Parser[Input, Output], Reader[Input]], None]] = None, -) -> DebugParser: +) -> DebugParser[Input, Output]: ... + + +def debug( + parser: Union[Parser[Input, Output], Sequence[Input]], + *, + verbose: bool = False, + callback: Optional[Callable[[Parser[Input, Output], Reader[Input]], None]] = None, +) -> DebugParser[Input, Any]: """Execute debugging hooks before a parser. This parser is used purely for debugging purposes. From a parsing diff --git a/src/parsita/parsers/_end_of_source.py b/src/parsita/parsers/_end_of_source.py index 21baf7a..a2c8654 100644 --- a/src/parsita/parsers/_end_of_source.py +++ b/src/parsita/parsers/_end_of_source.py @@ -1,25 +1,27 @@ __all__ = ["EndOfSourceParser", "eof"] -from typing import Generic, Optional +from typing import Optional, TypeVar -from ..state import Continue, Input, Reader, State +from ..state import Continue, Reader, State from ._base import Parser +FunctionInput = TypeVar("FunctionInput") -class EndOfSourceParser(Generic[Input], Parser[Input, None]): - def __init__(self): + +class EndOfSourceParser(Parser[object, None]): + def __init__(self) -> None: super().__init__() def _consume( - self, state: State[Input], reader: Reader[Input] - ) -> Optional[Continue[Input, None]]: + self, state: State, reader: Reader[FunctionInput] + ) -> Optional[Continue[FunctionInput, None]]: if reader.finished: return Continue(reader, None) else: state.register_failure("end of source", reader) return None - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + "eof" diff --git a/src/parsita/parsers/_literal.py b/src/parsita/parsers/_literal.py index e638dee..860d86f 100644 --- a/src/parsita/parsers/_literal.py +++ b/src/parsita/parsers/_literal.py @@ -1,26 +1,31 @@ __all__ = ["LiteralParser", "lit"] -from typing import Any, Optional, Sequence +from typing import Any, Generic, Optional, Sequence, TypeVar, Union, overload from .. import options -from ..state import Continue, Input, Reader, State, StringReader +from ..state import Continue, Element, Reader, State, StringReader from ._base import Parser +# The bound should be Sequence[Element], but mypy doesn't support higher-kinded types. 
+Literal = TypeVar("Literal", bound=Sequence[Any], covariant=True) -class LiteralParser(Parser[Input, Input]): - def __init__(self, pattern: Sequence[Input], whitespace: Optional[Parser[Input, Any]] = None): + +class LiteralParser(Generic[Element, Literal], Parser[Element, Literal]): + def __init__(self, pattern: Literal, whitespace: Optional[Parser[Element, object]] = None): super().__init__() self.pattern = pattern self.whitespace = whitespace - def _consume(self, state: State[Input], reader: Reader[Input]): + def _consume( + self, state: State, reader: Reader[Element] + ) -> Optional[Continue[Element, Literal]]: if self.whitespace is not None: status = self.whitespace.consume(state, reader) - reader = status.remainder + reader = status.remainder # type: ignore # whitespace is infallible if isinstance(reader, StringReader): - if reader.source.startswith(self.pattern, reader.position): - reader = reader.drop(len(self.pattern)) + if reader.source.startswith(self.pattern, reader.position): # type: ignore + reader = reader.drop(len(self.pattern)) # type: ignore else: state.register_failure(repr(self.pattern), reader) return None @@ -37,15 +42,32 @@ def _consume(self, state: State[Input], reader: Reader[Input]): if self.whitespace is not None: status = self.whitespace.consume(state, reader) - reader = status.remainder + reader = status.remainder # type: ignore # whitespace is infallible return Continue(reader, self.pattern) - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + repr(self.pattern) -def lit(literal: Sequence[Input], *literals: Sequence[Input]) -> Parser[Input, Input]: +FunctionLiteral = TypeVar("FunctionLiteral", bound=Sequence[Any]) + + +@overload +def lit(literal: str, *literals: str) -> Parser[str, str]: ... + + +@overload +def lit(literal: bytes, *literals: bytes) -> Parser[int, bytes]: ... + + +@overload +def lit(literal: FunctionLiteral, *literals: FunctionLiteral) -> Parser[Any, FunctionLiteral]: ... + + +def lit( + literal: Union[FunctionLiteral, str, bytes], *literals: Union[FunctionLiteral, str, bytes] +) -> Parser[Element, object]: """Match a literal sequence. 
This parser returns successfully if the subsequence of the parsing input diff --git a/src/parsita/parsers/_optional.py b/src/parsita/parsers/_optional.py index 2ecf768..3081561 100644 --- a/src/parsita/parsers/_optional.py +++ b/src/parsita/parsers/_optional.py @@ -1,17 +1,17 @@ __all__ = ["OptionalParser", "opt"] -from typing import Generic, List, Sequence, Union +from typing import Generic, Sequence, Union, overload from ..state import Continue, Input, Output, Reader, State from ._base import Parser, wrap_literal -class OptionalParser(Generic[Input, Output], Parser[Input, List[Output]]): +class OptionalParser(Generic[Input, Output], Parser[Input, Sequence[Output]]): def __init__(self, parser: Parser[Input, Output]): super().__init__() self.parser = parser - def _consume(self, state: State[Input], reader: Reader[Input]): + def _consume(self, state: State, reader: Reader[Input]) -> Continue[Input, Sequence[Output]]: status = self.parser.consume(state, reader) if isinstance(status, Continue): @@ -19,11 +19,25 @@ def _consume(self, state: State[Input], reader: Reader[Input]): else: return Continue(reader, []) - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + f"opt({self.parser.name_or_repr()})" -def opt(parser: Union[Parser[Input, Output], Sequence[Input]]) -> OptionalParser[Input, Output]: +@overload +def opt( + parser: Sequence[Input], +) -> OptionalParser[Input, Sequence[Input]]: ... + + +@overload +def opt( + parser: Parser[Input, Output], +) -> OptionalParser[Input, Output]: ... + + +def opt( + parser: Union[Parser[Input, Output], Sequence[Input]], +) -> OptionalParser[Input, object]: """Optionally match a parser. An ``OptionalParser`` attempts to match ``parser``. If it succeeds, it diff --git a/src/parsita/parsers/_predicate.py b/src/parsita/parsers/_predicate.py index 43f5d6a..5754f22 100644 --- a/src/parsita/parsers/_predicate.py +++ b/src/parsita/parsers/_predicate.py @@ -1,12 +1,12 @@ __all__ = ["PredicateParser", "pred"] -from typing import Callable, Generic +from typing import Callable, Generic, Optional from ..state import Continue, Input, Output, Reader, State from ._base import Parser, wrap_literal -class PredicateParser(Generic[Input, Output], Parser[Input, Input]): +class PredicateParser(Generic[Input, Output], Parser[Input, Output]): def __init__( self, parser: Parser[Input, Output], predicate: Callable[[Output], bool], description: str ): @@ -15,7 +15,7 @@ def __init__( self.predicate = predicate self.description = description - def _consume(self, state: State[Input], reader: Reader[Input]): + def _consume(self, state: State, reader: Reader[Input]) -> Optional[Continue[Input, Output]]: status = self.parser.consume(state, reader) if isinstance(status, Continue): if self.predicate(status.value): @@ -26,7 +26,7 @@ def _consume(self, state: State[Input], reader: Reader[Input]): else: return status - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + f"pred({self.parser.name_or_repr()}, {self.description})" diff --git a/src/parsita/parsers/_regex.py b/src/parsita/parsers/_regex.py index 4af22cc..d27b422 100644 --- a/src/parsita/parsers/_regex.py +++ b/src/parsita/parsers/_regex.py @@ -1,22 +1,35 @@ __all__ = ["RegexParser", "reg"] import re -from typing import Any, Generic, Optional, TypeVar, Union +from typing import Any, Generic, Optional, TypeVar, Union, no_type_check from .. 
import options -from ..state import Continue, Reader, State +from ..state import Continue, State, StringReader from ._base import Parser StringType = TypeVar("StringType", str, bytes) -class RegexParser(Generic[StringType], Parser[StringType, StringType]): - def __init__(self, pattern: re.Pattern, whitespace: Optional[Parser[StringType, Any]] = None): +# The Element type is str for str and int for bytes, but there is no way to +# express that in Python. +class RegexParser(Generic[StringType], Parser[Any, StringType]): + def __init__( + self, + pattern: re.Pattern[StringType], + whitespace: Optional[Parser[StringType, object]] = None, + ): super().__init__() - self.pattern = pattern - self.whitespace = whitespace + self.pattern: re.Pattern[StringType] = pattern + self.whitespace: Optional[Parser[StringType, object]] = whitespace - def _consume(self, state: State[StringType], reader: Reader[StringType]): + # RegexParser is special in that is assumes StringReader is the only + # possible reader for strings and bytes. This is technically unsound. + @no_type_check + def _consume( + self, + state: State, + reader: StringReader, + ) -> Optional[Continue[StringType, StringType]]: if self.whitespace is not None: status = self.whitespace.consume(state, reader) reader = status.remainder @@ -40,7 +53,7 @@ def __repr__(self) -> str: return self.name_or_nothing() + f"reg({self.pattern.pattern!r})" -def reg(pattern: Union[re.Pattern, StringType]) -> RegexParser[StringType]: +def reg(pattern: Union[re.Pattern[StringType], StringType]) -> RegexParser[StringType]: """Match with a regular expression. This matches the text with a regular expression. The regular expressions is diff --git a/src/parsita/parsers/_repeated.py b/src/parsita/parsers/_repeated.py index c368b66..0079ad4 100644 --- a/src/parsita/parsers/_repeated.py +++ b/src/parsita/parsers/_repeated.py @@ -1,6 +1,6 @@ __all__ = ["RepeatedOnceParser", "rep1", "RepeatedParser", "rep"] -from typing import Generic, List, Optional, Sequence, Union +from typing import Generic, Optional, Sequence, Union, overload from ..state import Continue, Input, Output, Reader, RecursionError, State from ._base import Parser, wrap_literal @@ -11,14 +11,16 @@ def __init__(self, parser: Parser[Input, Output]): super().__init__() self.parser = parser - def _consume(self, state: State[Input], reader: Reader[Input]): - status = self.parser.consume(state, reader) + def _consume( + self, state: State, reader: Reader[Input] + ) -> Optional[Continue[Input, Sequence[Output]]]: + initial_status = self.parser.consume(state, reader) - if status is None: + if initial_status is None: return None else: - output = [status.value] - remainder = status.remainder + output = [initial_status.value] + remainder = initial_status.remainder while True: status = self.parser.consume(state, remainder) if isinstance(status, Continue): @@ -30,13 +32,21 @@ def _consume(self, state: State[Input], reader: Reader[Input]): else: return Continue(remainder, output) - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + f"rep1({self.parser.name_or_repr()})" +@overload +def rep1(parser: Sequence[Input]) -> RepeatedOnceParser[Input, Sequence[Input]]: ... + + +@overload +def rep1(parser: Parser[Input, Output]) -> RepeatedOnceParser[Input, Output]: ... + + def rep1( parser: Union[Parser[Input, Output], Sequence[Input]], -) -> RepeatedOnceParser[Input, Output]: +) -> RepeatedOnceParser[Input, object]: """Match a parser one or more times repeatedly. 
This matches ``parser`` multiple times in a row. If it matches as least @@ -57,8 +67,10 @@ def __init__(self, parser: Parser[Input, Output], *, min: int = 0, max: Optional self.min = min self.max = max - def _consume(self, state: State[Input], reader: Reader[Input]): - output: List[Output] = [] + def _consume( + self, state: State, reader: Reader[Input] + ) -> Optional[Continue[Input, Sequence[Output]]]: + output: list[Output] = [] remainder = reader while self.max is None or len(output) < self.max: @@ -77,16 +89,37 @@ def _consume(self, state: State[Input], reader: Reader[Input]): else: return None - def __repr__(self): + def __repr__(self) -> str: min_string = f", min={self.min}" if self.min > 0 else "" max_string = f", max={self.max}" if self.max is not None else "" string = f"rep({self.parser.name_or_repr()}{min_string}{max_string})" return self.name_or_nothing() + string +@overload def rep( - parser: Union[Parser, Sequence[Input]], *, min: int = 0, max: Optional[int] = None -) -> RepeatedParser[Input, Output]: + parser: Sequence[Input], + *, + min: int = 0, + max: Optional[int] = None, +) -> RepeatedParser[Input, Sequence[Input]]: ... + + +@overload +def rep( + parser: Parser[Input, Output], + *, + min: int = 0, + max: Optional[int] = None, +) -> RepeatedParser[Input, Output]: ... + + +def rep( + parser: Union[Parser[Input, Output], Sequence[Input]], + *, + min: int = 0, + max: Optional[int] = None, +) -> RepeatedParser[Input, object]: """Match a parser zero or more times repeatedly. This matches ``parser`` multiple times in a row. A list is returned diff --git a/src/parsita/parsers/_repeated_seperated.py b/src/parsita/parsers/_repeated_seperated.py index b299d4f..0a59965 100644 --- a/src/parsita/parsers/_repeated_seperated.py +++ b/src/parsita/parsers/_repeated_seperated.py @@ -1,6 +1,6 @@ __all__ = ["RepeatedSeparatedParser", "repsep", "RepeatedOnceSeparatedParser", "rep1sep"] -from typing import Any, Generic, Optional, Sequence, Union +from typing import Generic, Optional, Sequence, Union, overload from ..state import Continue, Input, Output, Reader, RecursionError, State from ._base import Parser, wrap_literal @@ -10,7 +10,7 @@ class RepeatedSeparatedParser(Generic[Input, Output], Parser[Input, Sequence[Out def __init__( self, parser: Parser[Input, Output], - separator: Parser[Input, Any], + separator: Parser[Input, object], *, min: int = 0, max: Optional[int] = None, @@ -21,30 +21,32 @@ def __init__( self.min = min self.max = max - def _consume(self, state: State[Input], reader: Reader[Input]): - status = self.parser.consume(state, reader) + def _consume( + self, state: State, reader: Reader[Input] + ) -> Optional[Continue[Input, Sequence[Output]]]: + initial_status = self.parser.consume(state, reader) - if not isinstance(status, Continue): + if not isinstance(initial_status, Continue): output = [] remainder = reader else: - output = [status.value] - remainder = status.remainder + output = [initial_status.value] + remainder = initial_status.remainder while self.max is None or len(output) < self.max: # If the separator matches, but the parser does not, the # remainder from the last successful parser step must be used, # not the remainder from any separator. That is why the parser # starts from the remainder on the status, but remainder is not # updated until after the parser succeeds. 
- status = self.separator.consume(state, remainder) - if isinstance(status, Continue): - status = self.parser.consume(state, status.remainder) - if isinstance(status, Continue): - if remainder.position == status.remainder.position: + separator_status = self.separator.consume(state, remainder) + if isinstance(separator_status, Continue): + parser_status = self.parser.consume(state, separator_status.remainder) + if isinstance(parser_status, Continue): + if remainder.position == parser_status.remainder.position: raise RecursionError(self, remainder) - remainder = status.remainder - output.append(status.value) + remainder = parser_status.remainder + output.append(parser_status.value) else: break else: @@ -55,7 +57,7 @@ def _consume(self, state: State[Input], reader: Reader[Input]): else: return None - def __repr__(self): + def __repr__(self) -> str: rep_string = self.parser.name_or_repr() sep_string = self.separator.name_or_repr() min_string = f", min={self.min}" if self.min > 0 else "" @@ -64,13 +66,33 @@ def __repr__(self): return self.name_or_nothing() + string +@overload +def repsep( + parser: Sequence[Input], + separator: Union[Parser[Input, object], Sequence[Input]], + *, + min: int = 0, + max: Optional[int] = None, +) -> RepeatedSeparatedParser[Input, Sequence[Input]]: ... + + +@overload +def repsep( + parser: Parser[Input, Output], + separator: Union[Parser[Input, object], Sequence[Input]], + *, + min: int = 0, + max: Optional[int] = None, +) -> RepeatedSeparatedParser[Input, Output]: ... + + def repsep( parser: Union[Parser[Input, Output], Sequence[Input]], - separator: Union[Parser[Input, Any], Sequence[Input]], + separator: Union[Parser[Input, object], Sequence[Input]], *, min: int = 0, max: Optional[int] = None, -) -> RepeatedSeparatedParser[Input, Output]: +) -> RepeatedSeparatedParser[Input, object]: """Match a parser zero or more times separated by another parser. This matches repeated sequences of ``parser`` separated by ``separator``. A @@ -90,48 +112,62 @@ def repsep( class RepeatedOnceSeparatedParser(Generic[Input, Output], Parser[Input, Sequence[Output]]): - def __init__(self, parser: Parser[Input, Output], separator: Parser[Input, Any]): + def __init__(self, parser: Parser[Input, Output], separator: Parser[Input, object]): super().__init__() self.parser = parser self.separator = separator - def _consume(self, state: State[Input], reader: Reader[Input]): - status = self.parser.consume(state, reader) + def _consume( + self, state: State, reader: Reader[Input] + ) -> Optional[Continue[Input, Sequence[Output]]]: + initial_status = self.parser.consume(state, reader) - if status is None: + if initial_status is None: return None else: - output = [status.value] - remainder = status.remainder + output = [initial_status.value] + remainder = initial_status.remainder while True: # If the separator matches, but the parser does not, the # remainder from the last successful parser step must be used, # not the remainder from any separator. That is why the parser # starts from the remainder on the status, but remainder is not # updated until after the parser succeeds. 
- status = self.separator.consume(state, remainder) - if isinstance(status, Continue): - status = self.parser.consume(state, status.remainder) - if isinstance(status, Continue): - if remainder.position == status.remainder.position: + separator_status = self.separator.consume(state, remainder) + if isinstance(separator_status, Continue): + parser_status = self.parser.consume(state, separator_status.remainder) + if isinstance(parser_status, Continue): + if remainder.position == parser_status.remainder.position: raise RecursionError(self, remainder) - remainder = status.remainder - output.append(status.value) + remainder = parser_status.remainder + output.append(parser_status.value) else: return Continue(remainder, output) else: return Continue(remainder, output) - def __repr__(self): + def __repr__(self) -> str: string = f"rep1sep({self.parser.name_or_repr()}, {self.separator.name_or_repr()})" return self.name_or_nothing() + string +@overload +def rep1sep( + parser: Sequence[Input], separator: Union[Parser[Input, object], Sequence[Input]] +) -> RepeatedOnceSeparatedParser[Input, Sequence[Input]]: ... + + +@overload +def rep1sep( + parser: Parser[Input, Output], separator: Union[Parser[Input, object], Sequence[Input]] +) -> RepeatedOnceSeparatedParser[Input, Output]: ... + + def rep1sep( parser: Union[Parser[Input, Output], Sequence[Input]], - separator: Union[Parser[Input, Any], Sequence[Input]], -) -> RepeatedOnceSeparatedParser[Input, Output]: + separator: Union[Parser[Input, object], Sequence[Input]], +) -> RepeatedOnceSeparatedParser[Input, object]: """Match a parser one or more times separated by another parser. This matches repeated sequences of ``parser`` separated by ``separator``. diff --git a/src/parsita/parsers/_sequential.py b/src/parsita/parsers/_sequential.py index 7ec04f6..b9a3dcd 100644 --- a/src/parsita/parsers/_sequential.py +++ b/src/parsita/parsers/_sequential.py @@ -1,18 +1,20 @@ __all__ = ["SequentialParser", "DiscardLeftParser", "DiscardRightParser"] -from typing import Any, Generic, List +from typing import Any, Generic, Optional, Sequence from ..state import Continue, Input, Output, Reader, State from ._base import Parser # Type of this class is inexpressible -class SequentialParser(Generic[Input], Parser[Input, List[Any]]): +class SequentialParser(Generic[Input], Parser[Input, Sequence[Any]]): def __init__(self, parser: Parser[Input, Any], *parsers: Parser[Input, Any]): super().__init__() self.parsers = (parser, *parsers) - def _consume(self, state: State[Input], reader: Reader[Input]): + def _consume( + self, state: State, reader: Reader[Input] + ) -> Optional[Continue[Input, Sequence[Any]]]: output = [] remainder = reader @@ -26,7 +28,7 @@ def _consume(self, state: State[Input], reader: Reader[Input]): return Continue(remainder, output) - def __repr__(self): + def __repr__(self) -> str: names = [] for parser in self.parsers: names.append(parser.name_or_repr()) @@ -35,30 +37,30 @@ def __repr__(self): class DiscardLeftParser(Generic[Input, Output], Parser[Input, Output]): - def __init__(self, left: Parser[Input, Any], right: Parser[Input, Output]): + def __init__(self, left: Parser[Input, object], right: Parser[Input, Output]): super().__init__() self.left = left self.right = right - def _consume(self, state: State[Input], reader: Reader[Input]): + def _consume(self, state: State, reader: Reader[Input]) -> Optional[Continue[Input, Output]]: status = self.left.consume(state, reader) if isinstance(status, Continue): return self.right.consume(state, status.remainder) 
else: return None - def __repr__(self): + def __repr__(self) -> str: string = f"{self.left.name_or_repr()} >> {self.right.name_or_repr()}" return self.name_or_nothing() + string class DiscardRightParser(Generic[Input, Output], Parser[Input, Output]): - def __init__(self, left: Parser[Input, Output], right: Parser[Input, Any]): + def __init__(self, left: Parser[Input, Output], right: Parser[Input, object]): super().__init__() self.left = left self.right = right - def _consume(self, state: State[Input], reader: Reader[Input]): + def _consume(self, state: State, reader: Reader[Input]) -> Optional[Continue[Input, Output]]: status1 = self.left.consume(state, reader) if isinstance(status1, Continue): status2 = self.right.consume(state, status1.remainder) @@ -69,6 +71,6 @@ def _consume(self, state: State[Input], reader: Reader[Input]): else: return None - def __repr__(self): + def __repr__(self) -> str: string = f"{self.left.name_or_repr()} << {self.right.name_or_repr()}" return self.name_or_nothing() + string diff --git a/src/parsita/parsers/_success.py b/src/parsita/parsers/_success.py index e84d558..c349956 100644 --- a/src/parsita/parsers/_success.py +++ b/src/parsita/parsers/_success.py @@ -1,24 +1,29 @@ __all__ = ["SuccessParser", "success", "FailureParser", "failure"] -from typing import Any, Generic, NoReturn +from typing import Generic, NoReturn, TypeVar -from ..state import Continue, Input, Output, Reader, State +from ..state import Continue, Output, Reader, State from ._base import Parser +FunctionInput = TypeVar("FunctionInput") +FunctionOutput = TypeVar("FunctionOutput") -class SuccessParser(Generic[Input, Output], Parser[Any, Output]): + +class SuccessParser(Generic[Output], Parser[object, Output]): def __init__(self, value: Output): super().__init__() self.value = value - def _consume(self, state: State[Input], reader: Reader[Input]) -> Continue[Input, Output]: + def _consume( + self, state: State, reader: Reader[FunctionInput] + ) -> Continue[FunctionInput, Output]: return Continue(reader, self.value) - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + f"success({self.value!r})" -def success(value: Output) -> SuccessParser[Input, Output]: +def success(value: FunctionOutput) -> SuccessParser[FunctionOutput]: """Always succeed in matching and return a value. This parser always succeeds and returns the given ``value``. No input is @@ -31,20 +36,20 @@ def success(value: Output) -> SuccessParser[Input, Output]: return SuccessParser(value) -class FailureParser(Generic[Input], Parser[Input, NoReturn]): +class FailureParser(Parser[object, NoReturn]): def __init__(self, expected: str): super().__init__() self.expected = expected - def _consume(self, state: State[Input], reader: Reader[Input]) -> None: + def _consume(self, state: State, reader: Reader[object]) -> None: state.register_failure(self.expected, reader) return None - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + f"failure({self.expected!r})" -def failure(expected: str = "") -> FailureParser[Input]: +def failure(expected: str = "") -> FailureParser: """Always fail in matching with a given message. 
This parser always backtracks with a message that it is expecting the diff --git a/src/parsita/parsers/_until.py b/src/parsita/parsers/_until.py index 9633f63..3aaf796 100644 --- a/src/parsita/parsers/_until.py +++ b/src/parsita/parsers/_until.py @@ -1,17 +1,19 @@ __all__ = ["UntilParser", "until"] -from typing import Any, Generic +from typing import Any, Generic, Optional, Sequence, Union -from ..state import Continue, Input, Output, Reader, State +from ..state import Continue, Element, Reader, State from ._base import Parser, wrap_literal -class UntilParser(Generic[Input], Parser[Input, Input]): - def __init__(self, parser: Parser[Input, Any]): +class UntilParser(Generic[Element], Parser[Element, Sequence[Element]]): + def __init__(self, parser: Parser[Element, Any]): super().__init__() self.parser = parser - def _consume(self, state: State[Input], reader: Reader[Input]): + def _consume( + self, state: State, reader: Reader[Element] + ) -> Optional[Continue[Element, Sequence[Element]]]: start_position = reader.position while True: status = self.parser.consume(state, reader) @@ -25,14 +27,14 @@ def _consume(self, state: State[Input], reader: Reader[Input]): return Continue(reader, reader.source[start_position : reader.position]) - def __repr__(self): + def __repr__(self) -> str: return self.name_or_nothing() + f"until({self.parser.name_or_repr()})" -def until(parser: Parser[Input, Output]) -> UntilParser: +def until(parser: Union[Parser[Element, object], Sequence[Element]]) -> UntilParser[Element]: """Match everything until it matches the provided parser. - This parser matches all input until it encounters a position in the input + This parser matches all Element until it encounters a position in the Element where the given ``parser`` succeeds. Args: diff --git a/src/parsita/py.typed b/src/parsita/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/parsita/state/__init__.py b/src/parsita/state/__init__.py index 1897318..69587f1 100644 --- a/src/parsita/state/__init__.py +++ b/src/parsita/state/__init__.py @@ -1,4 +1,4 @@ from ._exceptions import ParseError, RecursionError from ._reader import Reader, SequenceReader, StringReader from ._result import Failure, Result, Success -from ._state import Continue, Input, Output, State +from ._state import Continue, Element, Input, Output, State diff --git a/src/parsita/state/_exceptions.py b/src/parsita/state/_exceptions.py index 9bab13f..27355c9 100644 --- a/src/parsita/state/_exceptions.py +++ b/src/parsita/state/_exceptions.py @@ -21,7 +21,7 @@ class ParseError(Exception): farthest: Reader[Any] expected: list[str] - def __str__(self): + def __str__(self) -> str: return self.farthest.expected_error(self.expected) @@ -35,5 +35,5 @@ class RecursionError(Exception): parser: Parser[Any, Any] context: Reader[Any] - def __str__(self): + def __str__(self) -> str: return self.context.recursion_error(repr(self.parser)) diff --git a/src/parsita/state/_reader.py b/src/parsita/state/_reader.py index 1056439..c7c2ab5 100644 --- a/src/parsita/state/_reader.py +++ b/src/parsita/state/_reader.py @@ -5,9 +5,9 @@ import re from dataclasses import dataclass from io import StringIO -from typing import Generic, Sequence, TypeVar +from typing import TYPE_CHECKING, Generic, Sequence, TypeVar -Input = TypeVar("Input") +Input = TypeVar("Input", covariant=True) class Reader(Generic[Input]): @@ -25,16 +25,27 @@ class Reader(Generic[Input]): source (Sequence[Input]): The full source being read. 
""" - # Despite what mypy says, these cannot be converted to properties because - # they will break the dataclass attributes of the subclasses. - first: Input - rest: Reader[Input] - position: int - finished: bool - source: Sequence[Input] + if TYPE_CHECKING: + # These abstract properties cannot exist at runtime or they will break the + # dataclass subclasses + + @property + def first(self) -> Input: ... + + @property + def rest(self) -> Reader[Input]: ... + + @property + def position(self) -> int: ... + + @property + def finished(self) -> bool: ... + + @property + def source(self) -> Sequence[Input]: ... def drop(self, count: int) -> Reader[Input]: - """Advance the reader by a ``count`` elements. + """Advance the reader by ``count`` elements. Both ``SequenceReader`` and ``StringReader`` override this method with a more efficient implementation. @@ -44,7 +55,7 @@ def drop(self, count: int) -> Reader[Input]: rest = rest.rest return rest - def next_token(self): + def next_token(self) -> Input: return self.first def expected_error(self, expected: Sequence[str]) -> str: @@ -71,7 +82,7 @@ def expected_error(self, expected: Sequence[str]) -> str: f"at index {self.position}" ) - def recursion_error(self, repeated_parser: str): + def recursion_error(self, repeated_parser: str) -> str: """Generate an error to indicate that infinite recursion was encountered. A parser can supply a representation of itself to this method and the @@ -98,7 +109,7 @@ def recursion_error(self, repeated_parser: str): @dataclass(frozen=True) -class SequenceReader(Reader[Input]): +class SequenceReader(Generic[Input], Reader[Input]): """A reader for sequences that should not be sliced. Python makes a copy when a sequence is sliced. This reader avoids making @@ -173,7 +184,7 @@ def next_token(self) -> str: else: return self.source[match.start() : match.end()] - def _current_line(self): + def _current_line(self) -> tuple[int, int, str, str]: # StringIO is not consistent in how it treats empty strings # and other strings not ending in newlines. Ensure that the # source always ends in a newline. @@ -203,7 +214,7 @@ def _current_line(self): # Add one to indexes to account for 0-indexes return line_index + 1, character_index + 1, line, pointer - def expected_error(self, expected: str) -> str: + def expected_error(self, expected: Sequence[str]) -> str: """Generate a basic error to include the current state. A parser can supply only a representation of what it is expecting to @@ -230,7 +241,7 @@ def expected_error(self, expected: str) -> str: f"Line {line_index}, character {character_index}\n\n{line}{pointer}" ) - def recursion_error(self, repeated_parser: str): + def recursion_error(self, repeated_parser: str) -> str: """Generate an error to indicate that infinite recursion was encountered. 
A parser can supply a representation of itself to this method and the diff --git a/src/parsita/state/_result.py b/src/parsita/state/_result.py index 2397b82..fe8a65a 100644 --- a/src/parsita/state/_result.py +++ b/src/parsita/state/_result.py @@ -1,6 +1,6 @@ __all__ = ["Result", "Success", "Failure"] -from typing import TYPE_CHECKING, TypeVar +from typing import TypeVar from returns import result @@ -9,11 +9,10 @@ Output = TypeVar("Output") # Reexport Returns Result types +# Failure and Result fail in isinstance +# Failure is replaced by plain Failure, which works at runtime +# Result is left as is because cannot be fixed without breaking eager type annotations Result = result.Result[Output, ParseError] Success = result.Success -if TYPE_CHECKING: - # This object fails in isinstance - # Result does too, but that cannot be fixed without breaking eager type annotations - Failure = result.Failure[ParseError] -else: - Failure = result.Failure +Failure: type[result.Failure[ParseError]] = result.Failure[ParseError] +Failure = result.Failure diff --git a/src/parsita/state/_state.py b/src/parsita/state/_state.py index e775940..c0cb852 100644 --- a/src/parsita/state/_state.py +++ b/src/parsita/state/_state.py @@ -1,6 +1,6 @@ from __future__ import annotations -__all__ = ["State", "Continue", "Input", "Output"] +__all__ = ["State", "Continue", "Input", "Output", "Element"] from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar @@ -10,17 +10,18 @@ if TYPE_CHECKING: from ..parsers import Parser -Input = TypeVar("Input") -Output = TypeVar("Output") +Input = TypeVar("Input", contravariant=True) +Output = TypeVar("Output", covariant=True) +Element = TypeVar("Element") -class State(Generic[Input]): - def __init__(self): - self.farthest: Optional[Reader[Any]] = None +class State: + def __init__(self) -> None: + self.farthest: Optional[Reader[object]] = None self.expected: list[str] = [] - self.memo: dict[tuple[Parser[Input, Any], int], Optional[Continue[Input, Any]]] = {} + self.memo: dict[tuple[Parser[Any, Any], int], Optional[Continue[Any, Any]]] = {} - def register_failure(self, expected: str, reader: Reader[Any]): + def register_failure(self, expected: str, reader: Reader[object]) -> None: if self.farthest is None or self.farthest.position < reader.position: self.expected.clear() self.expected.append(expected) diff --git a/src/parsita/util.py b/src/parsita/util.py index ed15c88..f6c69f8 100644 --- a/src/parsita/util.py +++ b/src/parsita/util.py @@ -1,11 +1,20 @@ +from __future__ import annotations + __all__ = ["constant", "splat", "unsplat"] -from typing import Callable, Iterable, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Sequence + +if TYPE_CHECKING: + # ParamSpec was introduced in Python 3.10 + # TypeVarTuple and Unpack were introduced in Python 3.11 + from typing import ParamSpec, TypeVar, TypeVarTuple, Unpack -A = TypeVar("A") + A = TypeVar("A") + P = ParamSpec("P") + Ts = TypeVarTuple("Ts") -def constant(x: A) -> Callable[..., A]: +def constant(x: A) -> Callable[P, A]: """Produce a function that always returns a supplied value. Args: @@ -16,13 +25,14 @@ def constant(x: A) -> Callable[..., A]: discards them, and returns ``x``. 
""" - def constanted(*args, **kwargs): + def constanted(*args: P.args, **kwargs: P.kwargs) -> A: return x return constanted -def splat(f: Callable[..., A]) -> Callable[[Iterable], A]: +# This signature cannot be expressed narrowly because SequenceParser does not return a tuple +def splat(f: Callable[[Unpack[Ts]], A], /) -> Callable[[Sequence[Any]], A]: """Convert a function of multiple arguments into a function of a single iterable argument. Args: @@ -41,13 +51,13 @@ def splat(f: Callable[..., A]) -> Callable[[Iterable], A]: $ g([1, 2, 3]) # 6 """ - def splatted(args): + def splatted(args: Sequence[Any], /) -> A: return f(*args) return splatted -def unsplat(f: Callable[[Iterable], A]) -> Callable[..., A]: +def unsplat(f: Callable[[tuple[Unpack[Ts]]], A]) -> Callable[..., A]: """Convert a function of a single iterable argument into a function of multiple arguments. Args: @@ -66,7 +76,7 @@ def unsplat(f: Callable[[Iterable], A]) -> Callable[..., A]: $ g(1, 2, 3) # 6 """ - def unsplatted(*args): + def unsplatted(*args: Unpack[Ts]) -> A: return f(args) return unsplatted diff --git a/tests/test_basic.py b/tests/test_basic.py index 80d2edb..d7fb37d 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,3 +1,5 @@ +from typing import Union + import pytest from parsita import ( @@ -424,14 +426,14 @@ class TestParsers(ParserContext): def test_conversion(): + def make_twentyone(x: tuple[float, float]) -> float: + return x[0] * 10 + x[1] + class TestParsers(ParserContext): one = lit("1") > int two = lit("2") > int twelve = one & two > (lambda x: x[0] * 10 + x[1]) - def make_twentyone(x): - return x[0] * 10 + x[1] - twentyone = two & one > make_twentyone assert TestParsers.one.parse("1") == Success(1) @@ -442,6 +444,14 @@ def make_twentyone(x): def test_recursion(): + def make_expr(x: tuple[float, Union[tuple[()], tuple[float]]]) -> float: + digits1, maybe_expr = x + if maybe_expr: + digits2 = maybe_expr[0] + return digits1 + digits2 + else: + return digits1 + class TestParsers(ParserContext): one = lit("1") > float six = lit("6") > float @@ -449,14 +459,6 @@ class TestParsers(ParserContext): numbers = eleven | one | six - def make_expr(x): - digits1, maybe_expr = x - if maybe_expr: - digits2 = maybe_expr[0] - return digits1 + digits2 - else: - return digits1 - expr = numbers & opt("+" >> expr) > make_expr assert TestParsers.expr.parse("11") == Success(11) diff --git a/tests/test_metaclass_scopes.py b/tests/test_metaclass_scopes.py index 0ba0185..0a260d8 100644 --- a/tests/test_metaclass_scopes.py +++ b/tests/test_metaclass_scopes.py @@ -3,7 +3,7 @@ from parsita import ParserContext, Success, lit -def convert(value: str): +def convert(value: str) -> str: return "global" @@ -20,7 +20,7 @@ def test_global_class_global_function(): class GlobalLocal(ParserContext): - def convert(value: str): + def convert(value: str) -> str: return "local" x = lit("x") > convert @@ -35,7 +35,7 @@ def test_global_class_local_function(): class GlobalInner(ParserContext): - def convert(value: str): + def convert(value: str) -> str: return "local" x = lit("x") > convert @@ -65,7 +65,7 @@ class Inner(ParserContext): def test_local_class_local_function(): class LocalLocal(ParserContext): - def convert(value: str): + def convert(value: str) -> str: return "local" x = lit("x") > convert @@ -79,13 +79,13 @@ class Inner(ParserContext): def test_inner_class_inner_function(): class LocalLocal(ParserContext): - def convert(value: str): + def convert(value: str) -> str: return "local" x = lit("x") > convert class 
Inner(ParserContext): - def convert(value: str): + def convert(value: str) -> str: return "nested" x = lit("x") > convert @@ -95,7 +95,7 @@ def convert(value: str): def test_nested_class_global_function(): - def nested(): + def nested() -> type[ParserContext]: class LocalLocal(ParserContext): x = lit("x") > convert @@ -110,7 +110,7 @@ class Inner(ParserContext): def factory(): - def convert(value: str): + def convert(value: str) -> str: return "local" class LocalLocal(ParserContext): @@ -123,7 +123,7 @@ class Inner(ParserContext): def test_factory_class_local_function(): - def convert(value: str): + def convert(value: str) -> str: return "caller" returned_class = factory() @@ -133,10 +133,10 @@ def convert(value: str): def test_nested_class_nonlocal_function(): - def convert(value: str): + def convert(value: str) -> str: return "nonlocal" - def nested(): + def nested() -> type[ParserContext]: class LocalLocal(ParserContext): x = lit("x") > convert @@ -152,10 +152,10 @@ class Inner(ParserContext): def test_nested_class_local_function(): - def convert(value: str): + def convert(value: str) -> str: return "nonlocal" - def nested(): + def nested() -> type[ParserContext]: def convert(value: str): return "local" diff --git a/tests/test_state.py b/tests/test_state.py index e3419c8..01ecb21 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -9,11 +9,13 @@ from parsita.state import Continue, State -def test_state_creation(): +def test_sequence_reader_creation(): read = SequenceReader([1, 2, 3]) assert read.first == 1 assert read.rest.first == 2 + +def test_string_reader_creation(): read = StringReader("a b") assert read.first == "a" assert read.rest.first == " " @@ -38,7 +40,7 @@ def test_parse_error_str_string_reader(): @pytest.mark.parametrize("source", ["a a", "a a\n"]) -def test_parse_error_str_string_reader_end_of_source(source): +def test_parse_error_str_string_reader_end_of_source(source: str): err = ParseError(StringReader(source, 4), ["'b'"]) assert str(err) == "Expected 'b' but found end of source\nLine 1, character 4\n\na a\n ^" @@ -52,6 +54,7 @@ def test_register_failure_first(): state = State() state.register_failure("foo", StringReader("bar baz", 0)) assert state.expected == ["foo"] + assert isinstance(state.farthest, Reader) assert state.farthest.position == 0 @@ -59,6 +62,7 @@ def test_register_failure_at_middle(): state = State() state.register_failure("foo", StringReader("bar baz", 4)) assert state.expected == ["foo"] + assert isinstance(state.farthest, Reader) assert state.farthest.position == 4 @@ -67,6 +71,7 @@ def test_register_failure_latest(): state.register_failure("foo", StringReader("bar baz", 0)) state.register_failure("egg", StringReader("bar baz", 4)) assert state.expected == ["egg"] + assert isinstance(state.farthest, Reader) assert state.farthest.position == 4 @@ -75,6 +80,7 @@ def test_register_failure_tied(): state.register_failure("foo", StringReader("bar baz", 4)) state.register_failure("egg", StringReader("bar baz", 4)) assert state.expected == ["foo", "egg"] + assert isinstance(state.farthest, Reader) assert state.farthest.position == 4 @@ -120,7 +126,7 @@ def test_reader_drop(): # drop, so this test is for that in case someone extends Reader. 
@dataclass(frozen=True) - class BytesReader(Reader): + class BytesReader(Reader[int]): source: bytes position: int = 0 diff --git a/tests/test_util.py b/tests/test_util.py index 9c3c536..bbb3d17 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -10,21 +10,21 @@ def test_constant(): def test_splat(): - def f(a, b, c): + def f(a: int, b: int, c: int) -> int: return a + b + c assert f(1, 2, 3) == 6 g = splat(f) - args = [1, 2, 3] + args = (1, 2, 3) assert g(args) == 6 def test_unsplat(): - def f(a): + def f(a: tuple[int, int, int]) -> int: return a[0] + a[1] + a[2] - args = [1, 2, 3] + args = (1, 2, 3) assert f(args) == 6 g = unsplat(f)
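The examples below are not part of the patch; they are minimal sketches of how the retyped public API is meant to be consumed, using invented grammars and parser names. First, the new `repsep`/`rep1sep` overloads: the element parser's output type should now flow through to the resulting `Sequence`, so a type checker should infer something like `Result[Sequence[int]]` for the parses here.

```python
from parsita import ParserContext, Success, reg, rep1sep, repsep


class ListParsers(ParserContext, whitespace=r"[ ]*"):
    integer = reg(r"[+-]?[0-9]+") > int
    # repsep allows zero or more elements separated by the separator
    integers = "[" >> repsep(integer, ",") << "]"
    # rep1sep requires at least one element
    nonempty = "[" >> rep1sep(integer, ",") << "]"


assert ListParsers.integers.parse("[1, 2, 3]") == Success([1, 2, 3])
assert ListParsers.integers.parse("[]") == Success([])
assert ListParsers.nonempty.parse("[4]") == Success([4])
```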
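`success` and `failure` also changed shape (`SuccessParser` is now generic only in its output, and `FailureParser` is no longer generic), but their runtime behavior is documented as unchanged. A sketch of the usual pattern with an invented version grammar: `success` supplies a default without consuming input, and `failure` always backtracks with its message.

```python
from parsita import ParserContext, Success, failure, lit, success


class VersionParsers(ParserContext):
    # success() consumes no input; here it supplies a default minor version
    minor = "." >> lit("0", "1") | success("0")
    version = "v" >> lit("1", "2") & minor
    # failure() always backtracks, reporting the given expectation
    legacy = "v0" >> failure("a supported major version")


assert VersionParsers.version.parse("v2.1") == Success(["2", "1"])
assert VersionParsers.version.parse("v2") == Success(["2", "0"])
```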
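`until` now advertises a `Sequence[Element]` result rather than a single `Input`, which for string input is simply the matched substring. A small sketch with an invented comment grammar:

```python
from parsita import ParserContext, Success, until


class CommentParsers(ParserContext):
    # until() matches everything up to, but not including, the stop parser
    block_comment = "/*" >> until("*/") << "*/"


assert CommentParsers.block_comment.parse("/* hi */") == Success(" hi ")
```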
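`constant` is now annotated with a `ParamSpec` so that the returned callable accepts any arguments while preserving the return type. Its typical use with the conversion operator is unchanged; this mirrors the existing JSON example:

```python
from parsita import ParserContext, Success, lit
from parsita.util import constant


class KeywordParsers(ParserContext):
    # constant(None) discards the matched text and always returns None
    null = lit("null") > constant(None)


assert KeywordParsers.null.parse("null") == Success(None)
```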
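Finally, `splat` and `unsplat` pick up `TypeVarTuple`-based signatures (imported under `TYPE_CHECKING` because `Unpack` only exists on Python 3.11+). Their runtime behavior is unchanged; a quick round trip, assuming only the documented semantics:

```python
from parsita.util import splat, unsplat


def volume(x: int, y: int, z: int) -> int:
    return x * y * z


# splat: one sequence argument in, the original return type out
packed = splat(volume)
assert packed((2, 3, 4)) == 24

# unsplat: back to separate positional arguments
unpacked = unsplat(packed)
assert unpacked(2, 3, 4) == 24
```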