Skip to content

Commit

Permalink
feat: Draft reconnection background task and publication retry
Browse files Browse the repository at this point in the history
  • Loading branch information
empicano committed Mar 22, 2024
1 parent f7697de commit d92fd30
Showing 1 changed file with 74 additions and 35 deletions.
109 changes: 74 additions & 35 deletions aiomqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ class Client:
password: The password to authenticate with.
logger: Custom logger instance.
identifier: The client identifier. Generated automatically if ``None``.
reconnect: If ``True``, the client will automatically reconnect to the broker
if the connection is lost. Defaults to ``False``.
queue_type: The class to use for the queue. The default is
``asyncio.Queue``, which stores messages in FIFO order. For LIFO order,
you can use ``asyncio.LifoQueue``; For priority order you can subclass
Expand Down Expand Up @@ -181,6 +183,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
password: str | None = None,
logger: logging.Logger | None = None,
identifier: str | None = None,
reconnect: bool = False,
queue_type: type[asyncio.Queue[Message]] | None = None,
protocol: ProtocolVersion | None = None,
will: Will | None = None,
Expand All @@ -206,6 +209,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
) -> None:
self._hostname = hostname
self._port = port
self._reconnect = reconnect
self._keepalive = keepalive
self._bind_address = bind_address
self._bind_port = bind_port
Expand All @@ -225,7 +229,10 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
self._pending_unsubscribes: dict[int, asyncio.Event] = {}
self._pending_publishes: dict[int, asyncio.Event] = {}
self.pending_calls_threshold: int = 10

# Background tasks
self._misc_task: asyncio.Task[None] | None = None
self._reconnection_task: asyncio.Task[None] | None = None

# Queue that holds incoming messages
if queue_type is None:
Expand Down Expand Up @@ -432,9 +439,17 @@ async def publish( # noqa: PLR0913
**kwargs: Additional keyword arguments to pass to paho-mqtt's publish
method.
"""
info = self._client.publish(
topic, payload, qos, retain, properties, *args, **kwargs
) # [2]
while True:
info = self._client.publish(
topic, payload, qos, retain, properties, *args, **kwargs
) # [2]
if not (info.rc == mqtt.MQTT_ERR_NO_CONN and self._reconnect):
break
while True:
with contextlib.suppress(asyncio.CancelledError):
await self._connected
break
self._connected = asyncio.Future()
# Early out on error
if info.rc != mqtt.MQTT_ERR_SUCCESS:
raise MqttCodeError(info.rc, "Could not publish message")
Expand Down Expand Up @@ -677,43 +692,65 @@ async def _misc_loop(self) -> None:
while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
await asyncio.sleep(1)

async def _connect(self) -> None:
    """Open the connection to the broker and wait until it is established.

    A single attempt consists of calling the underlying client's ``connect``
    in an executor thread, applying the socket defaults, and awaiting
    ``self._connected``. When ``self._reconnect`` is ``True`` a failed attempt
    is repeated after a two second pause, without an upper bound on the number
    of attempts.

    Raises:
        MqttError: If an attempt fails and ``self._reconnect`` is ``False``.
            ``self._lock`` is released before the error propagates.
    """
    while True:
        event_loop = asyncio.get_running_loop()
        # Read the connection settings anew for every attempt
        connect_args = (
            self._hostname,
            self._port,
            self._keepalive,
            self._bind_address,
            self._bind_port,
            self._clean_start,
            self._properties,
        )
        try:
            try:
                # [3] connect() blocks on the socket connection for up to
                # `keepalive` seconds (https://git.io/Jt5Yc), so it is run in
                # an executor thread to keep the event loop responsive
                await event_loop.run_in_executor(
                    None, self._client.connect, *connect_args
                )
                _set_client_socket_defaults(
                    self._client.socket(), self._socket_options
                )
            except (OSError, mqtt.WebsocketConnectionError) as exc:
                # Convert all possible paho-mqtt Client.connect exceptions to
                # our MqttError. See:
                # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1770
                raise MqttError(str(exc)) from None
            await self._wait_for(self._connected, timeout=None)
        except MqttError:
            # The attempt failed: replace any completed future so that the
            # next attempt starts from a clean internal state
            if self._connected.done():
                self._connected = asyncio.Future()
            if self._disconnected.done():
                self._disconnected = asyncio.Future()
            if self._reconnect:
                self._logger.warning("Failed to connect. Trying again in 2 seconds...")
                await asyncio.sleep(2)
                continue
            # No retry requested: give up the lock and surface the error
            self._lock.release()
            raise
        else:
            self._logger.info("Successfully connected to the broker.")
            return

async def _reconnection(self) -> None:
    """Re-establish the connection every time it is lost.

    Loops indefinitely: waits for ``self._disconnected`` to complete, replaces
    the connection state futures, and calls ``self._connect()``. The loop has
    no exit of its own; it ends only when an exception (for example the
    cancellation of the surrounding task) propagates out of it.
    """
    while True:
        try:
            await self._disconnected
        except MqttError:
            # A disconnect that completes with an MqttError is handled the
            # same way as any other disconnect: by reconnecting
            pass
        self._logger.warning("Connection lost. Reconnecting...")
        # Fresh futures for the upcoming connect/disconnect cycle
        self._connected, self._disconnected = asyncio.Future(), asyncio.Future()
        await self._connect()

async def __aenter__(self) -> Self:
    """Connect to the broker.

    The connection attempt itself is delegated to ``self._connect()``, which
    converts connection errors to ``MqttError``, resets the connection state
    futures after a failed attempt, and releases ``self._lock`` when it gives
    up. Performing the inline connection sequence here as well would attempt
    the connection a second time, so only the delegated call is kept.

    Returns:
        The client itself, for use as the ``async with`` target.

    Raises:
        MqttReentrantError: If the context manager is entered while
            ``self._lock`` is already held.
        MqttError: Propagated from ``self._connect()`` when the connection
            attempt fails and ``self._reconnect`` is ``False``.
    """
    if self._lock.locked():
        msg = "The client context manager is reusable, but not reentrant"
        raise MqttReentrantError(msg)
    await self._lock.acquire()
    await self._connect()
    # Reset `_disconnected` if it's already in completed state after connecting,
    # so that the reconnection task does not react to a stale completed future
    if self._disconnected.done():
        self._disconnected = asyncio.Future()
    # Start the reconnection task
    if self._reconnect:
        self._reconnection_task = asyncio.create_task(self._reconnection())
    return self

async def __aexit__(
Expand All @@ -723,8 +760,10 @@ async def __aexit__(
tb: TracebackType | None,
) -> None:
"""Disconnect from the broker."""
if self._reconnect:
self._reconnection_task.cancel()
# Return early if the client is already disconnected
if self._disconnected.done():
# Return early if the client is already disconnected
if self._lock.locked():
self._lock.release()
if (exc := self._disconnected.exception()) is not None:
Expand Down

0 comments on commit d92fd30

Please sign in to comment.