diff --git a/src/pytest_fluent/additional_information.py b/src/pytest_fluent/additional_information.py index 47760dc..3793abe 100644 --- a/src/pytest_fluent/additional_information.py +++ b/src/pytest_fluent/additional_information.py @@ -98,6 +98,23 @@ def get_additional_information_callback( return info +def check_type_with_optional( + annotation: typing.Any, exptected_type: typing.Type +) -> bool: + """Check if is type with optional. + + Args: + annotation (typing.Any): Annotation to check + exptected_type (typing.Type): Expected type + + Returns: + bool: True if it is the expected type. + """ + is_t = annotation is exptected_type + is_opt_t = annotation is typing.Optional[exptected_type] + return is_t or is_opt_t + + def check_allowed_input(func: typing.Callable) -> None: """Check that the given function has a specific signature. @@ -108,7 +125,9 @@ def check_allowed_input(func: typing.Callable) -> None: raise TypeError("Not a function") annotations = func.__annotations__ args = set(annotations.keys()) - if "item" in args and not annotations["item"] is pytest.Item: + if "item" in args and not check_type_with_optional( + annotations["item"], pytest.Item + ): raise TypeError("Invalid function signature for 'item'") if not ("return" in args and annotations["return"] is dict): raise TypeError("Invalid function signature for return type. Expecting a dict.") diff --git a/tests/test_additional_information.py b/tests/test_additional_information.py index e0904f4..af627be 100644 --- a/tests/test_additional_information.py +++ b/tests/test_additional_information.py @@ -1,3 +1,4 @@ +import typing from unittest.mock import patch import pytest @@ -19,6 +20,9 @@ def add_1() -> dict: def add_2(item: pytest.Item) -> dict: return {} + def add_2_opt(item: typing.Optional[pytest.Item] = None) -> dict: + return {} + def add_3(item: int) -> dict: return {} @@ -27,6 +31,7 @@ def add_4() -> int: check_allowed_input(add_1) check_allowed_input(add_2) + check_allowed_input(add_2_opt) with pytest.raises(TypeError, match="Invalid function signature for 'item'"): check_allowed_input(add_3)