From 1ede4bd98acb314d41154a2f31a53c0e52acf26a Mon Sep 17 00:00:00 2001 From: Oleh Chyhyryn Date: Tue, 14 Nov 2023 15:11:31 +0200 Subject: [PATCH] feat: support for injector library --- README.md | 14 ++- flake8_type_checking/checker.py | 53 +++++++++++- flake8_type_checking/plugin.py | 7 ++ tests/conftest.py | 1 + tests/test_import_visitors.py | 1 + tests/test_injector.py | 149 ++++++++++++++++++++++++++++++++ tests/test_name_visitor.py | 1 + 7 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 tests/test_injector.py diff --git a/README.md b/README.md index 0db6316..d334554 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,8 @@ More examples can be found in the [examples](#examples) section.
If you're using [pydantic](https://pydantic-docs.helpmanual.io/), -[fastapi](https://fastapi.tiangolo.com/), or [cattrs](https://github.com/python-attrs/cattrs) +[fastapi](https://fastapi.tiangolo.com/), [cattrs](https://github.com/python-attrs/cattrs), +or [injector](https://github.com/python-injector/injector) see the [configuration](#configuration) for how to enable support. ## Primary features @@ -244,6 +245,17 @@ This can be added in the future if needed. type-checking-cattrs-enabled = true # default false ``` +### Injector support + +If you're using the injector library, you can enable support. +This will treat any `Inject[Dependency]` types as needed at runtime. + +- **name**: `type-checking-injector-enabled` +- **type**: `bool` +```ini +type-checking-injector-enabled = true # default false +``` + ## Rationale Why did we create this plugin? diff --git a/flake8_type_checking/checker.py b/flake8_type_checking/checker.py index 11523c4..7a01ddd 100644 --- a/flake8_type_checking/checker.py +++ b/flake8_type_checking/checker.py @@ -246,6 +246,53 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: self.visit(argument.annotation) +class InjectorMixin: + """ + Contains the necessary logic for injector (https://github.com/python-injector/injector) support. + + For injected dependencies, we want to treat annotations as needed at runtime. + """ + + if TYPE_CHECKING: + injector_enabled: bool + + def visit(self, node: ast.AST) -> ast.AST: # noqa: D102 + ... + + def visit_FunctionDef(self, node: FunctionDef) -> None: + """Remove and map function arguments and returns.""" + super().visit_FunctionDef(node) # type: ignore[misc] + if self.injector_enabled: + self.handle_injector_declaration(node) + + def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: + """Remove and map function arguments and returns.""" + super().visit_AsyncFunctionDef(node) # type: ignore[misc] + if self.injector_enabled: + self.handle_injector_declaration(node) + + def handle_injector_declaration(self, node: Union[AsyncFunctionDef, FunctionDef]) -> None: + """ + Adjust for injector declaration setting. + + When the injector setting is enabled, treat all annotations from within + a function definition (except for return annotations) as needed at runtime. + + To achieve this, we just visit the annotations to register them as "uses". + """ + for path in [node.args.args, node.args.kwonlyargs]: + for argument in path: + if hasattr(argument, 'annotation') and argument.annotation: + annotation = argument.annotation + if not hasattr(annotation, 'value'): + continue + value = annotation.value + if hasattr(value, 'id') and value.id == 'Inject': + self.visit(argument.annotation) + if hasattr(value, 'attr') and value.attr == 'Inject': + self.visit(argument.annotation) + + class FastAPIMixin: """ Contains the necessary logic for FastAPI support. @@ -522,7 +569,7 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only: return parent.lookup(symbol_name, use, runtime_only) -class ImportVisitor(DunderAllMixin, AttrsMixin, FastAPIMixin, PydanticMixin, ast.NodeVisitor): +class ImportVisitor(DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, ast.NodeVisitor): """Map all imports outside of type-checking blocks.""" #: The currently active scope @@ -534,6 +581,7 @@ def __init__( pydantic_enabled: bool, fastapi_enabled: bool, fastapi_dependency_support_enabled: bool, + injector_enabled: bool, cattrs_enabled: bool, pydantic_enabled_baseclass_passlist: list[str], exempt_modules: Optional[list[str]] = None, @@ -545,6 +593,7 @@ def __init__( self.fastapi_enabled = fastapi_enabled self.fastapi_dependency_support_enabled = fastapi_dependency_support_enabled self.cattrs_enabled = cattrs_enabled + self.injector_enabled = injector_enabled self.pydantic_enabled_baseclass_passlist = pydantic_enabled_baseclass_passlist self.pydantic_validate_arguments_import_name = None self.cwd = cwd # we need to know the current directory to guess at which imports are remote and which are not @@ -1448,6 +1497,7 @@ def __init__(self, node: ast.Module, options: Optional[Namespace]) -> None: fastapi_enabled = getattr(options, 'type_checking_fastapi_enabled', False) fastapi_dependency_support_enabled = getattr(options, 'type_checking_fastapi_dependency_support_enabled', False) cattrs_enabled = getattr(options, 'type_checking_cattrs_enabled', False) + injector_enabled = getattr(options, 'type_checking_injector_enabled', False) if fastapi_enabled and not pydantic_enabled: # FastAPI support must include Pydantic support. @@ -1466,6 +1516,7 @@ def __init__(self, node: ast.Module, options: Optional[Namespace]) -> None: exempt_modules=exempt_modules, fastapi_dependency_support_enabled=fastapi_dependency_support_enabled, pydantic_enabled_baseclass_passlist=pydantic_enabled_baseclass_passlist, + injector_enabled=injector_enabled, ) self.visitor.visit(node) diff --git a/flake8_type_checking/plugin.py b/flake8_type_checking/plugin.py index cd4535f..dd6cdc3 100644 --- a/flake8_type_checking/plugin.py +++ b/flake8_type_checking/plugin.py @@ -85,6 +85,13 @@ def add_options(cls, option_manager: OptionManager) -> None: # pragma: no cover default=False, help='Prevent flagging of annotations on attrs class definitions.', ) + option_manager.add_option( + '--type-checking-injector-enabled', + action='store_true', + parse_from_config=True, + default=False, + help='Prevent flagging of annotations on injector class definitions.', + ) def run(self) -> Flake8Generator: """Run flake8 plugin and return any relevant errors.""" diff --git a/tests/conftest.py b/tests/conftest.py index b5d3811..a941f98 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,6 +41,7 @@ def _get_error(example: str, *, error_code_filter: Optional[str] = None, **kwarg mock_options.type_checking_fastapi_enabled = False mock_options.type_checking_fastapi_dependency_support_enabled = False mock_options.type_checking_pydantic_enabled_baseclass_passlist = [] + mock_options.type_checking_injector_enabled = False mock_options.type_checking_strict = False # kwarg overrides for k, v in kwargs.items(): diff --git a/tests/test_import_visitors.py b/tests/test_import_visitors.py index 67f00d0..781f4d5 100644 --- a/tests/test_import_visitors.py +++ b/tests/test_import_visitors.py @@ -19,6 +19,7 @@ def _visit(example: str) -> ImportVisitor: fastapi_enabled=False, fastapi_dependency_support_enabled=False, cattrs_enabled=False, + injector_enabled=False, pydantic_enabled_baseclass_passlist=[], ) visitor.visit(ast.parse(example.replace('; ', '\n'))) diff --git a/tests/test_injector.py b/tests/test_injector.py new file mode 100644 index 0000000..c306daa --- /dev/null +++ b/tests/test_injector.py @@ -0,0 +1,149 @@ +"""This file tests injector support.""" + +import textwrap + +import pytest + +from flake8_type_checking.constants import TC002 +from tests.conftest import _get_error + + +@pytest.mark.parametrize( + ('enabled', 'expected'), + [ + (True, {'2:0 ' + TC002.format(module='services.Service')}), + (False, {'2:0 ' + TC002.format(module='services.Service')}), + ], +) +def test_non_pydantic_model(enabled, expected): + """A class does not use injector, so error should be risen in both scenarios.""" + example = textwrap.dedent(''' + from services import Service + + class X: + def __init__(self, service: Service) -> None: + self.service = service + ''') + assert _get_error(example, error_code_filter='TC002', type_checking_pydantic_enabled=enabled) == expected + + +@pytest.mark.parametrize( + ('enabled', 'expected'), + [ + (True, set()), + (False, {'2:0 ' + TC002.format(module='injector.Inject'), '3:0 ' + TC002.format(module='services.Service')}), + ], +) +def test_injector_option(enabled, expected): + """When an injector option is enabled, injector should be ignored.""" + example = textwrap.dedent(''' + from injector import Inject + from services import Service + + class X: + def __init__(self, service: Inject[Service]) -> None: + self.service = service + ''') + assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected + + +@pytest.mark.parametrize( + ('enabled', 'expected'), + [ + (True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}), + ( + False, + { + '2:0 ' + TC002.format(module='injector.Inject'), + '3:0 ' + TC002.format(module='services.Service'), + '4:0 ' + TC002.format(module='other_dependency.OtherDependency'), + }, + ), + ], +) +def test_injector_option_only_allows_injected_dependencies(enabled, expected): + """Whenever an injector option is enabled, only injected dependencies should be ignored.""" + example = textwrap.dedent(''' + from injector import Inject + from services import Service + from other_dependency import OtherDependency + + class X: + def __init__(self, service: Inject[Service], other: OtherDependency) -> None: + self.service = service + self.other = other + ''') + assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected + + +@pytest.mark.parametrize( + ('enabled', 'expected'), + [ + (True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}), + ( + False, + { + '2:0 ' + TC002.format(module='injector.Inject'), + '3:0 ' + TC002.format(module='services.Service'), + '4:0 ' + TC002.format(module='other_dependency.OtherDependency'), + }, + ), + ], +) +def test_injector_option_only_allows_injector_slices(enabled, expected): + """ + Whenever an injector option is enabled, only injected dependencies should be ignored, + not any dependencies with slices. + """ + example = textwrap.dedent(""" + from injector import Inject + from services import Service + from other_dependency import OtherDependency + + class X: + def __init__(self, service: Inject[Service], other_deps: list[OtherDependency]) -> None: + self.service = service + self.other_deps = other_deps + """) + assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected + + +@pytest.mark.parametrize( + ('enabled', 'expected'), + [ + (True, set()), + (False, {'2:0 ' + TC002.format(module='injector'), '3:0 ' + TC002.format(module='services.Service')}), + ], +) +def test_injector_option_allows_injector_as_module(enabled, expected): + """Whenever an injector option is enabled, injected dependencies should be ignored, even if import as module.""" + example = textwrap.dedent(''' + import injector + from services import Service + + class X: + def __init__(self, service: injector.Inject[Service]) -> None: + self.service = service + ''') + assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected + + +@pytest.mark.parametrize( + ('enabled', 'expected'), + [ + (True, set()), + (False, {'2:0 ' + TC002.format(module='injector.Inject'), '3:0 ' + TC002.format(module='services.Service')}), + ], +) +def test_injector_option_only_mentioned_second_time(enabled, expected): + """Whenever an injector option is enabled, dependency referenced second time is accepted.""" + example = textwrap.dedent(""" + from injector import Inject + from services import Service + + class X: + def __init__(self, service: Inject[Service], other_deps: list[Service]) -> None: + self.service = service + self.other_deps = other_deps + """) + assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected diff --git a/tests/test_name_visitor.py b/tests/test_name_visitor.py index ecd1d9a..ee79fd8 100644 --- a/tests/test_name_visitor.py +++ b/tests/test_name_visitor.py @@ -19,6 +19,7 @@ def _get_names(example: str) -> Set[str]: fastapi_enabled=False, fastapi_dependency_support_enabled=False, cattrs_enabled=False, + injector_enabled=False, pydantic_enabled_baseclass_passlist=[], ) visitor.visit(ast.parse(example))