Skip to content

Commit

Permalink
feat: support for injector library
Browse files Browse the repository at this point in the history
  • Loading branch information
OlehChyhyryn authored and sondrelg committed Nov 29, 2023
1 parent 2f478a4 commit 1ede4bd
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 2 deletions.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ More examples can be found in the [examples](#examples) section.
<br>

If you're using [pydantic](https://pydantic-docs.helpmanual.io/),
[fastapi](https://fastapi.tiangolo.com/), or [cattrs](https://github.com/python-attrs/cattrs)
[fastapi](https://fastapi.tiangolo.com/), [cattrs](https://github.com/python-attrs/cattrs),
or [injector](https://github.com/python-injector/injector)
see the [configuration](#configuration) for how to enable support.

## Primary features
Expand Down Expand Up @@ -244,6 +245,17 @@ This can be added in the future if needed.
type-checking-cattrs-enabled = true # default false
```

### Injector support

If you're using the injector library, you can enable support.
This will treat any `Inject[Dependency]` types as needed at runtime.

- **name**: `type-checking-injector-enabled`
- **type**: `bool`
```ini
type-checking-injector-enabled = true # default false
```

## Rationale

Why did we create this plugin?
Expand Down
53 changes: 52 additions & 1 deletion flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,53 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
self.visit(argument.annotation)


class InjectorMixin:
"""
Contains the necessary logic for injector (https://github.com/python-injector/injector) support.
For injected dependencies, we want to treat annotations as needed at runtime.
"""

if TYPE_CHECKING:
injector_enabled: bool

def visit(self, node: ast.AST) -> ast.AST: # noqa: D102
...

def visit_FunctionDef(self, node: FunctionDef) -> None:
"""Remove and map function arguments and returns."""
super().visit_FunctionDef(node) # type: ignore[misc]
if self.injector_enabled:
self.handle_injector_declaration(node)

def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
"""Remove and map function arguments and returns."""
super().visit_AsyncFunctionDef(node) # type: ignore[misc]
if self.injector_enabled:
self.handle_injector_declaration(node)

def handle_injector_declaration(self, node: Union[AsyncFunctionDef, FunctionDef]) -> None:
"""
Adjust for injector declaration setting.
When the injector setting is enabled, treat all annotations from within
a function definition (except for return annotations) as needed at runtime.
To achieve this, we just visit the annotations to register them as "uses".
"""
for path in [node.args.args, node.args.kwonlyargs]:
for argument in path:
if hasattr(argument, 'annotation') and argument.annotation:
annotation = argument.annotation
if not hasattr(annotation, 'value'):
continue
value = annotation.value
if hasattr(value, 'id') and value.id == 'Inject':
self.visit(argument.annotation)
if hasattr(value, 'attr') and value.attr == 'Inject':
self.visit(argument.annotation)


class FastAPIMixin:
"""
Contains the necessary logic for FastAPI support.
Expand Down Expand Up @@ -522,7 +569,7 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only:
return parent.lookup(symbol_name, use, runtime_only)


class ImportVisitor(DunderAllMixin, AttrsMixin, FastAPIMixin, PydanticMixin, ast.NodeVisitor):
class ImportVisitor(DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, ast.NodeVisitor):
"""Map all imports outside of type-checking blocks."""

#: The currently active scope
Expand All @@ -534,6 +581,7 @@ def __init__(
pydantic_enabled: bool,
fastapi_enabled: bool,
fastapi_dependency_support_enabled: bool,
injector_enabled: bool,
cattrs_enabled: bool,
pydantic_enabled_baseclass_passlist: list[str],
exempt_modules: Optional[list[str]] = None,
Expand All @@ -545,6 +593,7 @@ def __init__(
self.fastapi_enabled = fastapi_enabled
self.fastapi_dependency_support_enabled = fastapi_dependency_support_enabled
self.cattrs_enabled = cattrs_enabled
self.injector_enabled = injector_enabled
self.pydantic_enabled_baseclass_passlist = pydantic_enabled_baseclass_passlist
self.pydantic_validate_arguments_import_name = None
self.cwd = cwd # we need to know the current directory to guess at which imports are remote and which are not
Expand Down Expand Up @@ -1448,6 +1497,7 @@ def __init__(self, node: ast.Module, options: Optional[Namespace]) -> None:
fastapi_enabled = getattr(options, 'type_checking_fastapi_enabled', False)
fastapi_dependency_support_enabled = getattr(options, 'type_checking_fastapi_dependency_support_enabled', False)
cattrs_enabled = getattr(options, 'type_checking_cattrs_enabled', False)
injector_enabled = getattr(options, 'type_checking_injector_enabled', False)

if fastapi_enabled and not pydantic_enabled:
# FastAPI support must include Pydantic support.
Expand All @@ -1466,6 +1516,7 @@ def __init__(self, node: ast.Module, options: Optional[Namespace]) -> None:
exempt_modules=exempt_modules,
fastapi_dependency_support_enabled=fastapi_dependency_support_enabled,
pydantic_enabled_baseclass_passlist=pydantic_enabled_baseclass_passlist,
injector_enabled=injector_enabled,
)
self.visitor.visit(node)

Expand Down
7 changes: 7 additions & 0 deletions flake8_type_checking/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ def add_options(cls, option_manager: OptionManager) -> None: # pragma: no cover
default=False,
help='Prevent flagging of annotations on attrs class definitions.',
)
option_manager.add_option(
'--type-checking-injector-enabled',
action='store_true',
parse_from_config=True,
default=False,
help='Prevent flagging of annotations on injector class definitions.',
)

def run(self) -> Flake8Generator:
"""Run flake8 plugin and return any relevant errors."""
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _get_error(example: str, *, error_code_filter: Optional[str] = None, **kwarg
mock_options.type_checking_fastapi_enabled = False
mock_options.type_checking_fastapi_dependency_support_enabled = False
mock_options.type_checking_pydantic_enabled_baseclass_passlist = []
mock_options.type_checking_injector_enabled = False
mock_options.type_checking_strict = False
# kwarg overrides
for k, v in kwargs.items():
Expand Down
1 change: 1 addition & 0 deletions tests/test_import_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def _visit(example: str) -> ImportVisitor:
fastapi_enabled=False,
fastapi_dependency_support_enabled=False,
cattrs_enabled=False,
injector_enabled=False,
pydantic_enabled_baseclass_passlist=[],
)
visitor.visit(ast.parse(example.replace('; ', '\n')))
Expand Down
149 changes: 149 additions & 0 deletions tests/test_injector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""This file tests injector support."""

import textwrap

import pytest

from flake8_type_checking.constants import TC002
from tests.conftest import _get_error


@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, {'2:0 ' + TC002.format(module='services.Service')}),
(False, {'2:0 ' + TC002.format(module='services.Service')}),
],
)
def test_non_pydantic_model(enabled, expected):
"""A class does not use injector, so error should be risen in both scenarios."""
example = textwrap.dedent('''
from services import Service
class X:
def __init__(self, service: Service) -> None:
self.service = service
''')
assert _get_error(example, error_code_filter='TC002', type_checking_pydantic_enabled=enabled) == expected


@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, set()),
(False, {'2:0 ' + TC002.format(module='injector.Inject'), '3:0 ' + TC002.format(module='services.Service')}),
],
)
def test_injector_option(enabled, expected):
"""When an injector option is enabled, injector should be ignored."""
example = textwrap.dedent('''
from injector import Inject
from services import Service
class X:
def __init__(self, service: Inject[Service]) -> None:
self.service = service
''')
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected


@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}),
(
False,
{
'2:0 ' + TC002.format(module='injector.Inject'),
'3:0 ' + TC002.format(module='services.Service'),
'4:0 ' + TC002.format(module='other_dependency.OtherDependency'),
},
),
],
)
def test_injector_option_only_allows_injected_dependencies(enabled, expected):
"""Whenever an injector option is enabled, only injected dependencies should be ignored."""
example = textwrap.dedent('''
from injector import Inject
from services import Service
from other_dependency import OtherDependency
class X:
def __init__(self, service: Inject[Service], other: OtherDependency) -> None:
self.service = service
self.other = other
''')
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected


@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}),
(
False,
{
'2:0 ' + TC002.format(module='injector.Inject'),
'3:0 ' + TC002.format(module='services.Service'),
'4:0 ' + TC002.format(module='other_dependency.OtherDependency'),
},
),
],
)
def test_injector_option_only_allows_injector_slices(enabled, expected):
"""
Whenever an injector option is enabled, only injected dependencies should be ignored,
not any dependencies with slices.
"""
example = textwrap.dedent("""
from injector import Inject
from services import Service
from other_dependency import OtherDependency
class X:
def __init__(self, service: Inject[Service], other_deps: list[OtherDependency]) -> None:
self.service = service
self.other_deps = other_deps
""")
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected


@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, set()),
(False, {'2:0 ' + TC002.format(module='injector'), '3:0 ' + TC002.format(module='services.Service')}),
],
)
def test_injector_option_allows_injector_as_module(enabled, expected):
"""Whenever an injector option is enabled, injected dependencies should be ignored, even if import as module."""
example = textwrap.dedent('''
import injector
from services import Service
class X:
def __init__(self, service: injector.Inject[Service]) -> None:
self.service = service
''')
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected


@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, set()),
(False, {'2:0 ' + TC002.format(module='injector.Inject'), '3:0 ' + TC002.format(module='services.Service')}),
],
)
def test_injector_option_only_mentioned_second_time(enabled, expected):
"""Whenever an injector option is enabled, dependency referenced second time is accepted."""
example = textwrap.dedent("""
from injector import Inject
from services import Service
class X:
def __init__(self, service: Inject[Service], other_deps: list[Service]) -> None:
self.service = service
self.other_deps = other_deps
""")
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected
1 change: 1 addition & 0 deletions tests/test_name_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def _get_names(example: str) -> Set[str]:
fastapi_enabled=False,
fastapi_dependency_support_enabled=False,
cattrs_enabled=False,
injector_enabled=False,
pydantic_enabled_baseclass_passlist=[],
)
visitor.visit(ast.parse(example))
Expand Down

0 comments on commit 1ede4bd

Please sign in to comment.