Skip to content

Commit

Permalink
Adding optional check
Browse files Browse the repository at this point in the history
  • Loading branch information
casabre committed Jun 5, 2024
1 parent c340611 commit fa3888a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
21 changes: 20 additions & 1 deletion src/pytest_fluent/additional_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,23 @@ def get_additional_information_callback(
return info


def check_type_with_optional(
annotation: typing.Any, exptected_type: typing.Type
) -> bool:
"""Check if is type with optional.
Args:
annotation (typing.Any): Annotation to check
exptected_type (typing.Type): Expected type
Returns:
bool: True if it is the expected type.
"""
is_t = annotation is exptected_type
is_opt_t = annotation is typing.Optional[exptected_type]
return is_t or is_opt_t


def check_allowed_input(func: typing.Callable) -> None:
"""Check that the given function has a specific signature.
Expand All @@ -108,7 +125,9 @@ def check_allowed_input(func: typing.Callable) -> None:
raise TypeError("Not a function")
annotations = func.__annotations__
args = set(annotations.keys())
if "item" in args and not annotations["item"] is pytest.Item:
if "item" in args and not check_type_with_optional(
annotations["item"], pytest.Item
):
raise TypeError("Invalid function signature for 'item'")
if not ("return" in args and annotations["return"] is dict):
raise TypeError("Invalid function signature for return type. Expecting a dict.")
5 changes: 5 additions & 0 deletions tests/test_additional_information.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from unittest.mock import patch

import pytest
Expand All @@ -19,6 +20,9 @@ def add_1() -> dict:
def add_2(item: pytest.Item) -> dict:
return {}

def add_2_opt(item: typing.Optional[pytest.Item] = None) -> dict:
return {}

def add_3(item: int) -> dict:
return {}

Expand All @@ -27,6 +31,7 @@ def add_4() -> int:

check_allowed_input(add_1)
check_allowed_input(add_2)
check_allowed_input(add_2_opt)

with pytest.raises(TypeError, match="Invalid function signature for 'item'"):
check_allowed_input(add_3)
Expand Down

0 comments on commit fa3888a

Please sign in to comment.