diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 690638e..29511cf 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -14,6 +14,31 @@ * The `ExponentialBackoff` and `LinearBackoff` classes now require keyword arguments for their constructor. This change was made to make the classes easier to use and to avoid confusion with the order of the arguments. +* The `BaseApiClient` class is not generic anymore, and doesn't take a function to create the stub. Instead, subclasses should create their own stub right after calling the parent constructor. This enables subclasses to cast the stub to the generated `XxxAsyncStub` class, which have proper `async` type hints. To convert you client: + + ```python + # Old + class MyApiClient(BaseApiClient[MyServiceStub]): + def __init__(self, server_url: str, *, ...) -> None: + super().__init__(server_url, MyServiceStub, ...) + ... + + # New + class MyApiClient(BaseApiClient): + def __init__(self, server_url: str, *, ...) -> None: + super().__init__(server_url, connect=connect) + self._stub = cast(MyServiceAsyncStub, MyServiceStub(self.channel)) + ... + + @property + def stub(self) -> ExampleAsyncStub: + if self._channel is None: + raise ClientNotConnected(server_url=self.server_url, operation="stub") + return self._stub + ``` + + After this, you should be able to remove a lot of `cast`s or `type: ignore` from the code when calling the stub `async` methods. + ## New Features diff --git a/src/frequenz/client/base/client.py b/src/frequenz/client/base/client.py index 0f73cfb..cf27533 100644 --- a/src/frequenz/client/base/client.py +++ b/src/frequenz/client/base/client.py @@ -6,18 +6,15 @@ import abc import inspect from collections.abc import Awaitable, Callable -from typing import Any, Generic, Self, TypeVar, overload +from typing import Any, Self, TypeVar, overload from grpc.aio import AioRpcError, Channel from .channel import ChannelOptions, parse_grpc_uri from .exception import ApiClientError, ClientNotConnected -StubT = TypeVar("StubT") -"""The type of the gRPC stub.""" - -class BaseApiClient(abc.ABC, Generic[StubT]): +class BaseApiClient(abc.ABC): """A base class for API clients. This class provides a common interface for API clients that communicate with a API @@ -32,12 +29,31 @@ class BaseApiClient(abc.ABC, Generic[StubT]): a class that helps sending messages from a gRPC stream to a [Broadcast][frequenz.channels.Broadcast] channel. + Note: + Because grpcio doesn't provide proper type hints, a hack is needed to have + propepr async type hints for the stubs generated by protoc. When using + `mypy-protobuf`, a `XxxAsyncStub` class is generated for each `XxxStub` class + but in the `.pyi` file, so the type can be used to specify type hints, but + **not** in any other context, as the class doesn't really exist for the Python + interpreter. This include generics, and because of this, this class can't be + even parametrized using the async class, so the instantiation of the stub can't + be done in the base class. + + Because of this, subclasses need to create the stubs by themselves, using the + real stub class and casting it to the `XxxAsyncStub` class, so `mypy` can use + the async version of the stubs. + + It is recommended to define a `stub` property that returns the async stub, so + this hack is completely hidden from clients, even if they need to access the + stub for more advanced uses. + Example: This example illustrates how to create a simple API client that connects to a gRPC server and calls a method on a stub. ```python from collections.abc import AsyncIterable + from typing import cast from frequenz.client.base.client import BaseApiClient, call_stub_method from frequenz.client.base.streaming import GrpcStreamBroadcaster from frequenz.channels import Receiver @@ -57,18 +73,29 @@ async def example_method( ) -> ExampleResponse: ... - def example_stream(self) -> AsyncIterable[ExampleResponse]: + def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]: + ... + + class ExampleAsyncStub: + async def example_method( + self, + request: ExampleRequest # pylint: disable=unused-argument + ) -> ExampleResponse: + ... + + def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]: ... # End of generated classes class ExampleResponseWrapper: - def __init__(self, response: ExampleResponse): + def __init__(self, response: ExampleResponse) -> None: self.transformed_value = f"{response.float_value:.2f}" - class MyApiClient(BaseApiClient[ExampleStub]): - def __init__(self, server_url: str, *, connect: bool = True): - super().__init__( - server_url, ExampleStub, connect=connect + class MyApiClient(BaseApiClient): + def __init__(self, server_url: str, *, connect: bool = True) -> None: + super().__init__(server_url, connect=connect) + self._stub = cast( + ExampleAsyncStub, ExampleStub(self.channel) ) self._broadcaster = GrpcStreamBroadcaster( "stream", @@ -76,6 +103,12 @@ def __init__(self, server_url: str, *, connect: bool = True): ExampleResponseWrapper, ) + @property + def stub(self) -> ExampleAsyncStub: + if self._channel is None: + raise ClientNotConnected(server_url=self.server_url, operation="stub") + return self._stub + async def example_method( self, int_value: int, str_value: str ) -> ExampleResponseWrapper: @@ -114,7 +147,6 @@ async def main(): def __init__( self, server_url: str, - create_stub: Callable[[Channel], StubT], *, connect: bool = True, channel_defaults: ChannelOptions = ChannelOptions(), @@ -123,7 +155,6 @@ def __init__( Args: server_url: The URL of the server to connect to. - create_stub: A function that creates a stub from a channel. connect: Whether to connect to the server as soon as a client instance is created. If `False`, the client will not connect to the server until [connect()][frequenz.client.base.client.BaseApiClient.connect] is @@ -132,10 +163,8 @@ def __init__( the server URL. """ self._server_url: str = server_url - self._create_stub: Callable[[Channel], StubT] = create_stub self._channel_defaults: ChannelOptions = channel_defaults self._channel: Channel | None = None - self._stub: StubT | None = None if connect: self.connect(server_url) @@ -165,22 +194,6 @@ def channel_defaults(self) -> ChannelOptions: """The default options for the gRPC channel.""" return self._channel_defaults - @property - def stub(self) -> StubT: - """The underlying gRPC stub. - - Warning: - This stub is provided as a last resort for advanced users. It is not - recommended to use this property directly unless you know what you are - doing and you don't care about being tied to a specific gRPC library. - - Raises: - ClientNotConnected: If the client is not connected to the server. - """ - if self._stub is None: - raise ClientNotConnected(server_url=self.server_url, operation="stub") - return self._stub - @property def is_connected(self) -> bool: """Whether the client is connected to the server.""" @@ -202,7 +215,6 @@ def connect(self, server_url: str | None = None) -> None: elif self.is_connected: return self._channel = parse_grpc_uri(self._server_url, self._channel_defaults) - self._stub = self._create_stub(self._channel) async def disconnect(self) -> None: """Disconnect from the server. @@ -227,7 +239,6 @@ async def __aexit__( return None result = await self._channel.__aexit__(_exc_type, _exc_val, _exc_tb) self._channel = None - self._stub = None return result @@ -240,7 +251,7 @@ async def __aexit__( @overload async def call_stub_method( - client: BaseApiClient[StubT], + client: BaseApiClient, stub_method: Callable[[], Awaitable[StubOutT]], *, method_name: str | None = None, @@ -250,7 +261,7 @@ async def call_stub_method( @overload async def call_stub_method( - client: BaseApiClient[StubT], + client: BaseApiClient, stub_method: Callable[[], Awaitable[StubOutT]], *, method_name: str | None = None, @@ -261,7 +272,7 @@ async def call_stub_method( # We need the `noqa: DOC503` because `pydoclint` can't figure out that # `ApiClientError.from_grpc_error()` returns a `GrpcError` instance. async def call_stub_method( # noqa: DOC503 - client: BaseApiClient[StubT], + client: BaseApiClient, stub_method: Callable[[], Awaitable[StubOutT]], *, method_name: str | None = None, diff --git a/tests/test_client.py b/tests/test_client.py index 2cbdcbc..dcf3d5f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,7 +12,7 @@ import pytest_mock from frequenz.client.base.channel import ChannelOptions, SslOptions -from frequenz.client.base.client import BaseApiClient, StubT, call_stub_method +from frequenz.client.base.client import BaseApiClient, call_stub_method from frequenz.client.base.exception import ClientNotConnected, UnknownError @@ -20,7 +20,7 @@ def _auto_connect_name(auto_connect: bool) -> str: return f"{auto_connect=}" -def _assert_is_disconnected(client: BaseApiClient[StubT]) -> None: +def _assert_is_disconnected(client: BaseApiClient) -> None: """Assert that the client is disconnected.""" assert not client.is_connected @@ -30,17 +30,9 @@ def _assert_is_disconnected(client: BaseApiClient[StubT]) -> None: assert exc.server_url == _DEFAULT_SERVER_URL assert exc.operation == "channel" - with pytest.raises(ClientNotConnected, match=r"") as exc_info: - _ = client.stub - exc = exc_info.value - assert exc.server_url == _DEFAULT_SERVER_URL - assert exc.operation == "stub" - @dataclass(kw_only=True, frozen=True) class _ClientMocks: - stub: mock.MagicMock - create_stub: mock.MagicMock channel: mock.MagicMock parse_grpc_uri: mock.MagicMock @@ -54,10 +46,8 @@ def create_client_with_mocks( auto_connect: bool = True, server_url: str = _DEFAULT_SERVER_URL, channel_defaults: ChannelOptions | None = None, -) -> tuple[BaseApiClient[mock.MagicMock], _ClientMocks]: +) -> tuple[BaseApiClient, _ClientMocks]: """Create a BaseApiClient instance with mocks.""" - mock_stub = mock.MagicMock(name="stub") - mock_create_stub = mock.MagicMock(name="create_stub", return_value=mock_stub) mock_channel = mock.MagicMock(name="channel", spec=grpc.aio.Channel) mock_parse_grpc_uri = mocker.patch( "frequenz.client.base.client.parse_grpc_uri", return_value=mock_channel @@ -67,13 +57,10 @@ def create_client_with_mocks( kwargs["channel_defaults"] = channel_defaults client = BaseApiClient( server_url=server_url, - create_stub=mock_create_stub, connect=auto_connect, **kwargs, ) return client, _ClientMocks( - stub=mock_stub, - create_stub=mock_create_stub, channel=mock_channel, parse_grpc_uri=mock_parse_grpc_uri, ) @@ -92,13 +79,10 @@ def test_base_api_client_init( client.server_url, ChannelOptions() ) assert client.channel is mocks.channel - assert client.stub is mocks.stub assert client.is_connected - mocks.create_stub.assert_called_once_with(mocks.channel) else: _assert_is_disconnected(client) mocks.parse_grpc_uri.assert_not_called() - mocks.create_stub.assert_not_called() def test_base_api_client_init_with_channel_defaults( @@ -110,9 +94,7 @@ def test_base_api_client_init_with_channel_defaults( assert client.server_url == _DEFAULT_SERVER_URL mocks.parse_grpc_uri.assert_called_once_with(client.server_url, channel_defaults) assert client.channel is mocks.channel - assert client.stub is mocks.stub assert client.is_connected - mocks.create_stub.assert_called_once_with(mocks.channel) @pytest.mark.parametrize( @@ -129,12 +111,10 @@ def test_base_api_client_connect( # We want to check only what happens when we call connect, so we reset the mocks # that were called during initialization mocks.parse_grpc_uri.reset_mock() - mocks.create_stub.reset_mock() client.connect(new_server_url) assert client.channel is mocks.channel - assert client.stub is mocks.stub assert client.is_connected same_url = new_server_url is None or new_server_url == _DEFAULT_SERVER_URL @@ -148,12 +128,10 @@ def test_base_api_client_connect( # reconnect if auto_connect and same_url: mocks.parse_grpc_uri.assert_not_called() - mocks.create_stub.assert_not_called() else: mocks.parse_grpc_uri.assert_called_once_with( client.server_url, ChannelOptions() ) - mocks.create_stub.assert_called_once_with(mocks.channel) async def test_base_api_client_disconnect(mocker: pytest_mock.MockFixture) -> None: @@ -177,23 +155,19 @@ async def test_base_api_client_async_context_manager( # We want to check only what happens when we enter the context manager, so we reset # the mocks that were called during initialization mocks.parse_grpc_uri.reset_mock() - mocks.create_stub.reset_mock() async with client: assert client.channel is mocks.channel - assert client.stub is mocks.stub assert client.is_connected mocks.channel.__aexit__.assert_not_called() # If we were previously connected, the client should not reconnect when entering # the context manager if auto_connect: mocks.parse_grpc_uri.assert_not_called() - mocks.create_stub.assert_not_called() else: mocks.parse_grpc_uri.assert_called_once_with( client.server_url, ChannelOptions() ) - mocks.create_stub.assert_called_once_with(mocks.channel) mocks.channel.__aexit__.assert_called_once_with(None, None, None) assert client.server_url == _DEFAULT_SERVER_URL