Skip to content

Commit

Permalink
Fix types
Browse files Browse the repository at this point in the history
  • Loading branch information
kdeldycke committed Apr 4, 2024
1 parent 5382bf3 commit d15fc8f
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 45 deletions.
2 changes: 1 addition & 1 deletion click_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def cache(user_function):
ExtraGroup,
)
from .config import ConfigOption # noqa: E402
from .decorators import ( # type: ignore[no-redef,has-type] # noqa: E402
from .decorators import ( # type: ignore[no-redef] # noqa: E402
color_option,
command,
config_option,
Expand Down
22 changes: 14 additions & 8 deletions click_extra/colorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
import dataclasses
import os
import re
from collections.abc import Iterable
from configparser import RawConfigParser
from dataclasses import dataclass
from gettext import gettext as _
from operator import getitem
from typing import Sequence, cast
from typing import Callable, Sequence, cast

import click
import cloup
Expand Down Expand Up @@ -623,7 +624,7 @@ def colorize(self, match: re.Match) -> str:

return txt

def highlight_extra_keywords(self, help_text):
def highlight_extra_keywords(self, help_text: str) -> str:
"""Highlight extra keywords in help screens based on the theme.
It is based on regular expressions. While this is not a bullet-proof method, it
Expand Down Expand Up @@ -796,14 +797,19 @@ def highlight_extra_keywords(self, help_text):

return help_text

def getvalue(self):
def getvalue(self) -> str:
"""Wrap original `Click.HelpFormatter.getvalue()` to force extra-colorization on
rendering."""
help_text = super().getvalue()
return self.highlight_extra_keywords(help_text)


def highlight(string, substrings, styling_method, ignore_case=False):
def highlight(
string: str,
substrings: Iterable[str],
styling_method: Callable,
ignore_case: bool = False,
) -> str:
"""Highlights parts of the ``string`` that matches ``substrings``.
Takes care of overlapping parts within the ``string``.
Expand All @@ -820,10 +826,10 @@ def highlight(string, substrings, styling_method, ignore_case=False):
}

# Reduce ranges, compute complement ranges, transform them to list of integers.
ranges = ",".join(ranges)
highlight_ranges = int_ranges_from_int_list(ranges)
range_arg = ",".join(ranges)
highlight_ranges = int_ranges_from_int_list(range_arg)
untouched_ranges = int_ranges_from_int_list(
complement_int_list(ranges, range_end=len(string)),
complement_int_list(range_arg, range_end=len(string)),
)

# Apply style to range of characters flagged as matching.
Expand All @@ -832,6 +838,6 @@ def highlight(string, substrings, styling_method, ignore_case=False):
segment = getitem(string, slice(i, j + 1))
if (i, j) in highlight_ranges:
segment = styling_method(segment)
styled_str += segment
styled_str += str(segment)

return styled_str
4 changes: 2 additions & 2 deletions click_extra/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import click
import cloup

from . import Command, Group
from . import Command, Group, Option
from .colorize import ColorOption, ExtraHelpColorsMixin, HelpExtraFormatter, HelpOption
from .config import ConfigOption
from .logging import VerbosityOption
Expand Down Expand Up @@ -124,7 +124,7 @@ def color(self) -> None:
self._color = None


def default_extra_params():
def default_extra_params() -> list[Option]:
"""Default additional options added to ``extra_command`` and ``extra_group``.
.. caution::
Expand Down
16 changes: 9 additions & 7 deletions click_extra/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,11 @@ def default_pattern(self) -> str:
ext_pattern = f"{{{','.join(extensions)}}}"
return f"{app_dir}{os.path.sep}*.{ext_pattern}"

def get_help_record(self, ctx):
def get_help_record(self, ctx: Context) -> tuple[str, str] | None:
"""Replaces the default value by the pretty version of the configuration
matching pattern."""
# Pre-compute pretty_path to bypass infinite recursive loop on get_default.
pretty_path = shrinkuser(Path(self.get_default(ctx)))
pretty_path = shrinkuser(Path(self.get_default(ctx))) # type: ignore[arg-type]
with patch.object(ConfigOption, "get_default") as mock_method:
mock_method.return_value = pretty_path
return super().get_help_record(ctx)
Expand Down Expand Up @@ -239,7 +239,7 @@ def search_and_read_conf(self, pattern: str) -> Iterable[tuple[Path | URL, str]]
logger.debug(f"Configuration file found at {file_path}")
yield file_path, file_path.read_text()

def parse_conf(self, conf_text: str) -> dict | None:
def parse_conf(self, conf_text: str) -> dict[str, Any] | None:
"""Try to parse the provided content with each format in the order provided by
the user.
Expand Down Expand Up @@ -295,7 +295,7 @@ def read_and_parse_conf(
return conf_path, user_conf
return None, None

def load_ini_config(self, content):
def load_ini_config(self, content: str) -> dict[str, Any]:
"""Utility method to parse INI configuration file.
Internal convention is to use a dot (``.``, as set by ``self.SEP``) in
Expand All @@ -307,7 +307,7 @@ def load_ini_config(self, content):
ini_config = ConfigParser(interpolation=ExtendedInterpolation())
ini_config.read_string(content)

conf = {}
conf: dict[str, Any] = {}
for section_id in ini_config.sections():
# Extract all options of the section.
sub_conf = {}
Expand All @@ -318,6 +318,8 @@ def load_ini_config(self, content):
option_id,
)

value: Any

if target_type in (None, str):
value = ini_config.get(section_id, option_id)

Expand Down Expand Up @@ -349,7 +351,7 @@ def load_ini_config(self, content):

return conf

def recursive_update(self, a, b):
def recursive_update(self, a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
"""Like standard ``dict.update()``, but recursive so sub-dict gets updated.
Ignore elements present in ``b`` but not in ``a``.
Expand All @@ -365,7 +367,7 @@ def recursive_update(self, a, b):
raise ValueError(msg)
return a

def merge_default_map(self, ctx, user_conf):
def merge_default_map(self, ctx: Context, user_conf: dict) -> None:
"""Save the user configuration into the context's ``default_map``.
Merge the user configuration into the pre-computed template structure, which
Expand Down
2 changes: 1 addition & 1 deletion click_extra/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def all_loggers(self) -> Generator[Logger, None, None]:
for name in ("click_extra", self.logger_name):
yield logging.getLogger(name)

def reset_loggers(self):
def reset_loggers(self) -> None:
"""Forces all loggers managed by the option to be reset to the default level.
Reset loggers in reverse order to ensure the internal logger is reset last.
Expand Down
30 changes: 21 additions & 9 deletions click_extra/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Any,
Callable,
ContextManager,
Iterator,
cast,
)
from unittest.mock import patch
Expand All @@ -42,7 +43,9 @@
from tabulate import tabulate

from . import (
Command,
Option,
Parameter,
ParamType,
Style,
echo,
Expand Down Expand Up @@ -303,7 +306,7 @@ def __init__(
super().__init__(*args, **kwargs)

@staticmethod
def init_tree_dict(*path: str, leaf: Any = None):
def init_tree_dict(*path: str, leaf: Any = None) -> Any:
"""Utility method to recursively create a nested dict structure whose keys are
provided by ``path`` list and at the end is populated by a copy of ``leaf``."""

Expand All @@ -322,7 +325,9 @@ def get_tree_value(tree_dict: dict[str, Any], *path: str) -> Any | None:
except KeyError:
return None

def _flatten_tree_dict_gen(self, tree_dict, parent_key):
def _flatten_tree_dict_gen(
self, tree_dict: MutableMapping, parent_key: str | None = None
) -> Iterable[tuple[str, Any]]:
"""`Source of this snippet
<https://www.freecodecamp.org/news/how-to-flatten-a-dictionary-in-python-in-4-different-ways/>`_.
"""
Expand All @@ -342,7 +347,12 @@ def flatten_tree_dict(
keys are path and values are the leaf's content."""
return dict(self._flatten_tree_dict_gen(tree_dict, parent_key))

def _recurse_cmd(self, cmd, top_level_params, parent_keys):
def _recurse_cmd(
self,
cmd: Command,
top_level_params: Iterable[str],
parent_keys: tuple[str, ...],
) -> Iterator[tuple[tuple[str, ...], Parameter]]:
"""Recursive generator to walk through all subcommands and their parameters."""
if hasattr(cmd, "commands"):
for subcmd_id, subcmd in cmd.commands.items():
Expand All @@ -362,7 +372,7 @@ def _recurse_cmd(self, cmd, top_level_params, parent_keys):
((*parent_keys, subcmd.name)),
)

def walk_params(self):
def walk_params(self) -> Iterator[tuple[tuple[str, ...], Parameter]]:
"""Generates an unfiltered list of all CLI parameters.
Everything is included, from top-level groups to subcommands, and from options
Expand All @@ -374,13 +384,15 @@ def walk_params(self):
"""
ctx = get_current_context()
cli = ctx.find_root().command
assert cli.name is not None

# Keep track of top-level CLI parameter IDs to check conflict with command
# IDs later.
top_level_params = set()

# Global, top-level options shared by all subcommands.
for p in cli.params:
assert p.name is not None
top_level_params.add(p.name)
yield (cli.name, p.name), p

Expand Down Expand Up @@ -410,7 +422,7 @@ def walk_params(self):
This mapping can be seen as a reverse of the ``click.types.convert_type()`` method.
"""

def get_param_type(self, param):
def get_param_type(self, param: Parameter) -> type[str | int | float | bool | list]:
"""Get the Python type of a Click parameter.
See the list of
Expand Down Expand Up @@ -447,7 +459,7 @@ def get_param_type(self, param):
if isinstance(param.type, ParamType):
return str

msg = f"Can't guess the appropriate Python type of {param!r} parameter."
msg = f"Can't guess the appropriate Python type of {param!r} parameter." # type:ignore[unreachable]
raise ValueError(msg)

@cached_property
Expand Down Expand Up @@ -491,7 +503,7 @@ def build_param_trees(self) -> None:
self.params_objects = objects

@cached_property
def params_template(self):
def params_template(self) -> dict[str, Any]:
"""Returns a tree-like dictionary whose keys shadows the CLI options and
subcommands and values are ``None``.
Expand All @@ -501,7 +513,7 @@ def params_template(self):
return self.params_template

@cached_property
def params_types(self):
def params_types(self) -> dict[str, Any]:
"""Returns a tree-like dictionary whose keys shadows the CLI options and
subcommands and values are their expected Python type.
Expand All @@ -511,7 +523,7 @@ def params_types(self):
return self.params_types

@cached_property
def params_objects(self):
def params_objects(self) -> dict[str, Any]:
"""Returns a tree-like dictionary whose keys shadows the CLI options and
subcommands and values are parameter objects.
Expand Down
4 changes: 2 additions & 2 deletions click_extra/platforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import sys
from dataclasses import dataclass, field
from itertools import combinations
from typing import Iterable
from typing import Iterable, Iterator

from . import cache

Expand Down Expand Up @@ -248,7 +248,7 @@ def __post_init__(self):
# Double-check there is no duplicate platforms.
assert len(self.platforms) == len(self.platform_ids)

def __iter__(self):
def __iter__(self) -> Iterator[Platform]:
"""Iterate over the platforms of the group."""
yield from self.platforms

Expand Down
10 changes: 7 additions & 3 deletions click_extra/pygments.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from __future__ import annotations

from typing import Iterable, Iterator

from pygments import lexers
from pygments.filter import Filter
from pygments.filters import TokenMergeFilter
Expand All @@ -36,7 +38,7 @@
from pygments.lexers.special import OutputLexer
from pygments.lexers.sql import PostgresConsoleLexer, SqliteConsoleLexer
from pygments.style import StyleMeta
from pygments.token import Generic, string_to_tokentype
from pygments.token import Generic, _TokenType, string_to_tokentype
from pygments_ansi_color import (
AnsiColorLexer,
ExtendedColorHtmlFormatterMixin,
Expand Down Expand Up @@ -70,7 +72,9 @@ def __init__(self, **options) -> None:
options.get("token_type", DEFAULT_TOKEN_TYPE),
)

def filter(self, lexer, stream):
def filter(
self, lexer: Lexer, stream: Iterable[tuple[_TokenType, str]]
) -> Iterator[tuple[_TokenType, str]]:
"""Transform each token of ``token_type`` type into a stream of ANSI tokens."""
for ttype, value in stream:
if ttype == self.token_type:
Expand Down Expand Up @@ -116,7 +120,7 @@ def __init__(self, *args, **kwargs) -> None:
self.filters.append(AnsiFilter())


def collect_session_lexers():
def collect_session_lexers() -> Iterator[type[Lexer]]:
"""Retrieve all lexers producing shell-like sessions in Pygments.
This function contain a manually-maintained list of lexers, to which we dynamiccaly
Expand Down
11 changes: 7 additions & 4 deletions click_extra/sphinx.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def hello_world(name):

from __future__ import annotations

from typing import Any

from docutils.statemachine import ViewList
from sphinx.highlighting import PygmentsBridge

Expand All @@ -75,20 +77,21 @@ class PatchedViewList(ViewList):
<https://github.com/pallets/pallets-sphinx-themes/pull/62>`_.
"""

def append(self, *args, **kwargs):
def append(self, *args, **kwargs) -> None:
"""Search the default code block and replace it with our own version."""
default_code_block = ".. sourcecode:: shell-session"
new_code_block = ".. code-block:: ansi-shell-session"

if default_code_block in args:
args = list(args)
new_args = list(args)
index = args.index(default_code_block)
args[index] = new_code_block
new_args[index] = new_code_block
args = tuple(new_args)

return super().append(*args, **kwargs)


def setup(app):
def setup(app: Any) -> None:
"""Register new directives, augmented with ANSI coloring.
New directives:
Expand Down
Loading

0 comments on commit d15fc8f

Please sign in to comment.