Skip to content

Commit

Permalink
Remove generic type from BaseApiClient
Browse files Browse the repository at this point in the history
The `BaseApiClient` class is not generic anymore, and doesn't take a
function to create the stub. Instead, subclasses should create their own
stub right after calling the parent constructor. This enables subclasses
to cast the stub to the generated `XxxAsyncStub` class, which have
proper `async` type hints.

Signed-off-by: Leandro Lucarella <[email protected]>
  • Loading branch information
llucax committed Oct 29, 2024
1 parent 92028a8 commit 8048e1a
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 65 deletions.
25 changes: 25 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,31 @@

* The `ExponentialBackoff` and `LinearBackoff` classes now require keyword arguments for their constructor. This change was made to make the classes easier to use and to avoid confusion with the order of the arguments.

* The `BaseApiClient` class is not generic anymore, and doesn't take a function to create the stub. Instead, subclasses should create their own stub right after calling the parent constructor. This enables subclasses to cast the stub to the generated `XxxAsyncStub` class, which have proper `async` type hints. To convert you client:

```python
# Old
class MyApiClient(BaseApiClient[MyServiceStub]):
def __init__(self, server_url: str, *, ...) -> None:
super().__init__(server_url, MyServiceStub, ...)
...

# New
class MyApiClient(BaseApiClient):
def __init__(self, server_url: str, *, ...) -> None:
super().__init__(server_url, connect=connect)
self._stub = cast(MyServiceAsyncStub, MyServiceStub(self.channel))
...

@property
def stub(self) -> ExampleAsyncStub:
if self._channel is None:
raise ClientNotConnected(server_url=self.server_url, operation="stub")
return self._stub
```

After this, you should be able to remove a lot of `cast`s or `type: ignore` from the code when calling the stub `async` methods.

## New Features

<!-- Here goes the main new features and examples or instructions on how to use them -->
Expand Down
83 changes: 47 additions & 36 deletions src/frequenz/client/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,15 @@
import abc
import inspect
from collections.abc import Awaitable, Callable
from typing import Any, Generic, Self, TypeVar, overload
from typing import Any, Self, TypeVar, overload

from grpc.aio import AioRpcError, Channel

from .channel import ChannelOptions, parse_grpc_uri
from .exception import ApiClientError, ClientNotConnected

StubT = TypeVar("StubT")
"""The type of the gRPC stub."""


class BaseApiClient(abc.ABC, Generic[StubT]):
class BaseApiClient(abc.ABC):
"""A base class for API clients.
This class provides a common interface for API clients that communicate with a API
Expand All @@ -32,12 +29,31 @@ class BaseApiClient(abc.ABC, Generic[StubT]):
a class that helps sending messages from a gRPC stream to
a [Broadcast][frequenz.channels.Broadcast] channel.
Note:
Because grpcio doesn't provide proper type hints, a hack is needed to have
propepr async type hints for the stubs generated by protoc. When using
`mypy-protobuf`, a `XxxAsyncStub` class is generated for each `XxxStub` class
but in the `.pyi` file, so the type can be used to specify type hints, but
**not** in any other context, as the class doesn't really exist for the Python
interpreter. This include generics, and because of this, this class can't be
even parametrized using the async class, so the instantiation of the stub can't
be done in the base class.
Because of this, subclasses need to create the stubs by themselves, using the
real stub class and casting it to the `XxxAsyncStub` class, so `mypy` can use
the async version of the stubs.
It is recommended to define a `stub` property that returns the async stub, so
this hack is completely hidden from clients, even if they need to access the
stub for more advanced uses.
Example:
This example illustrates how to create a simple API client that connects to a
gRPC server and calls a method on a stub.
```python
from collections.abc import AsyncIterable
from typing import cast
from frequenz.client.base.client import BaseApiClient, call_stub_method
from frequenz.client.base.streaming import GrpcStreamBroadcaster
from frequenz.channels import Receiver
Expand All @@ -57,25 +73,42 @@ async def example_method(
) -> ExampleResponse:
...
def example_stream(self) -> AsyncIterable[ExampleResponse]:
def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]:
...
class ExampleAsyncStub:
async def example_method(
self,
request: ExampleRequest # pylint: disable=unused-argument
) -> ExampleResponse:
...
def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]:
...
# End of generated classes
class ExampleResponseWrapper:
def __init__(self, response: ExampleResponse):
def __init__(self, response: ExampleResponse) -> None:
self.transformed_value = f"{response.float_value:.2f}"
class MyApiClient(BaseApiClient[ExampleStub]):
def __init__(self, server_url: str, *, connect: bool = True):
super().__init__(
server_url, ExampleStub, connect=connect
class MyApiClient(BaseApiClient):
def __init__(self, server_url: str, *, connect: bool = True) -> None:
super().__init__(server_url, connect=connect)
self._stub = cast(
ExampleAsyncStub, ExampleStub(self.channel)
)
self._broadcaster = GrpcStreamBroadcaster(
"stream",
lambda: self.stub.example_stream(ExampleRequest()),
ExampleResponseWrapper,
)
@property
def stub(self) -> ExampleAsyncStub:
if self._channel is None:
raise ClientNotConnected(server_url=self.server_url, operation="stub")
return self._stub
async def example_method(
self, int_value: int, str_value: str
) -> ExampleResponseWrapper:
Expand Down Expand Up @@ -114,7 +147,6 @@ async def main():
def __init__(
self,
server_url: str,
create_stub: Callable[[Channel], StubT],
*,
connect: bool = True,
channel_defaults: ChannelOptions = ChannelOptions(),
Expand All @@ -123,7 +155,6 @@ def __init__(
Args:
server_url: The URL of the server to connect to.
create_stub: A function that creates a stub from a channel.
connect: Whether to connect to the server as soon as a client instance is
created. If `False`, the client will not connect to the server until
[connect()][frequenz.client.base.client.BaseApiClient.connect] is
Expand All @@ -132,10 +163,8 @@ def __init__(
the server URL.
"""
self._server_url: str = server_url
self._create_stub: Callable[[Channel], StubT] = create_stub
self._channel_defaults: ChannelOptions = channel_defaults
self._channel: Channel | None = None
self._stub: StubT | None = None
if connect:
self.connect(server_url)

Expand Down Expand Up @@ -165,22 +194,6 @@ def channel_defaults(self) -> ChannelOptions:
"""The default options for the gRPC channel."""
return self._channel_defaults

@property
def stub(self) -> StubT:
"""The underlying gRPC stub.
Warning:
This stub is provided as a last resort for advanced users. It is not
recommended to use this property directly unless you know what you are
doing and you don't care about being tied to a specific gRPC library.
Raises:
ClientNotConnected: If the client is not connected to the server.
"""
if self._stub is None:
raise ClientNotConnected(server_url=self.server_url, operation="stub")
return self._stub

@property
def is_connected(self) -> bool:
"""Whether the client is connected to the server."""
Expand All @@ -202,7 +215,6 @@ def connect(self, server_url: str | None = None) -> None:
elif self.is_connected:
return
self._channel = parse_grpc_uri(self._server_url, self._channel_defaults)
self._stub = self._create_stub(self._channel)

async def disconnect(self) -> None:
"""Disconnect from the server.
Expand All @@ -227,7 +239,6 @@ async def __aexit__(
return None
result = await self._channel.__aexit__(_exc_type, _exc_val, _exc_tb)
self._channel = None
self._stub = None
return result


Expand All @@ -240,7 +251,7 @@ async def __aexit__(

@overload
async def call_stub_method(
client: BaseApiClient[StubT],
client: BaseApiClient,
stub_method: Callable[[], Awaitable[StubOutT]],
*,
method_name: str | None = None,
Expand All @@ -250,7 +261,7 @@ async def call_stub_method(

@overload
async def call_stub_method(
client: BaseApiClient[StubT],
client: BaseApiClient,
stub_method: Callable[[], Awaitable[StubOutT]],
*,
method_name: str | None = None,
Expand All @@ -261,7 +272,7 @@ async def call_stub_method(
# We need the `noqa: DOC503` because `pydoclint` can't figure out that
# `ApiClientError.from_grpc_error()` returns a `GrpcError` instance.
async def call_stub_method( # noqa: DOC503
client: BaseApiClient[StubT],
client: BaseApiClient,
stub_method: Callable[[], Awaitable[StubOutT]],
*,
method_name: str | None = None,
Expand Down
32 changes: 3 additions & 29 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
import pytest_mock

from frequenz.client.base.channel import ChannelOptions, SslOptions
from frequenz.client.base.client import BaseApiClient, StubT, call_stub_method
from frequenz.client.base.client import BaseApiClient, call_stub_method
from frequenz.client.base.exception import ClientNotConnected, UnknownError


def _auto_connect_name(auto_connect: bool) -> str:
return f"{auto_connect=}"


def _assert_is_disconnected(client: BaseApiClient[StubT]) -> None:
def _assert_is_disconnected(client: BaseApiClient) -> None:
"""Assert that the client is disconnected."""
assert not client.is_connected

Expand All @@ -30,17 +30,9 @@ def _assert_is_disconnected(client: BaseApiClient[StubT]) -> None:
assert exc.server_url == _DEFAULT_SERVER_URL
assert exc.operation == "channel"

with pytest.raises(ClientNotConnected, match=r"") as exc_info:
_ = client.stub
exc = exc_info.value
assert exc.server_url == _DEFAULT_SERVER_URL
assert exc.operation == "stub"


@dataclass(kw_only=True, frozen=True)
class _ClientMocks:
stub: mock.MagicMock
create_stub: mock.MagicMock
channel: mock.MagicMock
parse_grpc_uri: mock.MagicMock

Expand All @@ -54,10 +46,8 @@ def create_client_with_mocks(
auto_connect: bool = True,
server_url: str = _DEFAULT_SERVER_URL,
channel_defaults: ChannelOptions | None = None,
) -> tuple[BaseApiClient[mock.MagicMock], _ClientMocks]:
) -> tuple[BaseApiClient, _ClientMocks]:
"""Create a BaseApiClient instance with mocks."""
mock_stub = mock.MagicMock(name="stub")
mock_create_stub = mock.MagicMock(name="create_stub", return_value=mock_stub)
mock_channel = mock.MagicMock(name="channel", spec=grpc.aio.Channel)
mock_parse_grpc_uri = mocker.patch(
"frequenz.client.base.client.parse_grpc_uri", return_value=mock_channel
Expand All @@ -67,13 +57,10 @@ def create_client_with_mocks(
kwargs["channel_defaults"] = channel_defaults
client = BaseApiClient(
server_url=server_url,
create_stub=mock_create_stub,
connect=auto_connect,
**kwargs,
)
return client, _ClientMocks(
stub=mock_stub,
create_stub=mock_create_stub,
channel=mock_channel,
parse_grpc_uri=mock_parse_grpc_uri,
)
Expand All @@ -92,13 +79,10 @@ def test_base_api_client_init(
client.server_url, ChannelOptions()
)
assert client.channel is mocks.channel
assert client.stub is mocks.stub
assert client.is_connected
mocks.create_stub.assert_called_once_with(mocks.channel)
else:
_assert_is_disconnected(client)
mocks.parse_grpc_uri.assert_not_called()
mocks.create_stub.assert_not_called()


def test_base_api_client_init_with_channel_defaults(
Expand All @@ -110,9 +94,7 @@ def test_base_api_client_init_with_channel_defaults(
assert client.server_url == _DEFAULT_SERVER_URL
mocks.parse_grpc_uri.assert_called_once_with(client.server_url, channel_defaults)
assert client.channel is mocks.channel
assert client.stub is mocks.stub
assert client.is_connected
mocks.create_stub.assert_called_once_with(mocks.channel)


@pytest.mark.parametrize(
Expand All @@ -129,12 +111,10 @@ def test_base_api_client_connect(
# We want to check only what happens when we call connect, so we reset the mocks
# that were called during initialization
mocks.parse_grpc_uri.reset_mock()
mocks.create_stub.reset_mock()

client.connect(new_server_url)

assert client.channel is mocks.channel
assert client.stub is mocks.stub
assert client.is_connected

same_url = new_server_url is None or new_server_url == _DEFAULT_SERVER_URL
Expand All @@ -148,12 +128,10 @@ def test_base_api_client_connect(
# reconnect
if auto_connect and same_url:
mocks.parse_grpc_uri.assert_not_called()
mocks.create_stub.assert_not_called()
else:
mocks.parse_grpc_uri.assert_called_once_with(
client.server_url, ChannelOptions()
)
mocks.create_stub.assert_called_once_with(mocks.channel)


async def test_base_api_client_disconnect(mocker: pytest_mock.MockFixture) -> None:
Expand All @@ -177,23 +155,19 @@ async def test_base_api_client_async_context_manager(
# We want to check only what happens when we enter the context manager, so we reset
# the mocks that were called during initialization
mocks.parse_grpc_uri.reset_mock()
mocks.create_stub.reset_mock()

async with client:
assert client.channel is mocks.channel
assert client.stub is mocks.stub
assert client.is_connected
mocks.channel.__aexit__.assert_not_called()
# If we were previously connected, the client should not reconnect when entering
# the context manager
if auto_connect:
mocks.parse_grpc_uri.assert_not_called()
mocks.create_stub.assert_not_called()
else:
mocks.parse_grpc_uri.assert_called_once_with(
client.server_url, ChannelOptions()
)
mocks.create_stub.assert_called_once_with(mocks.channel)

mocks.channel.__aexit__.assert_called_once_with(None, None, None)
assert client.server_url == _DEFAULT_SERVER_URL
Expand Down

0 comments on commit 8048e1a

Please sign in to comment.