Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to paho-mqtt 2.0 #286

Merged
merged 3 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions aiomqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@
Generator,
Iterable,
Iterator,
Literal,
TypeVar,
cast,
)

import paho.mqtt.client as mqtt
from paho.mqtt.enums import CallbackAPIVersion
from paho.mqtt.properties import Properties
from paho.mqtt.reasoncodes import ReasonCode
from paho.mqtt.subscribeoptions import SubscribeOptions

from .exceptions import MqttCodeError, MqttConnectError, MqttError, MqttReentrantError
from .message import Message
Expand Down Expand Up @@ -116,7 +121,7 @@ class Will:
payload: PayloadType | None = None
qos: int = 0
retain: bool = False
properties: mqtt.Properties | None = None
properties: Properties | None = None


class Client:
Expand Down Expand Up @@ -185,17 +190,17 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
protocol: ProtocolVersion | None = None,
will: Will | None = None,
clean_session: bool | None = None,
transport: str = "tcp",
transport: Literal["tcp", "websockets"] = "tcp",
timeout: float | None = None,
keepalive: int = 60,
bind_address: str = "",
bind_port: int = 0,
clean_start: int = mqtt.MQTT_CLEAN_START_FIRST_ONLY,
clean_start: mqtt.CleanStartOption = mqtt.MQTT_CLEAN_START_FIRST_ONLY,
max_queued_incoming_messages: int | None = None,
max_queued_outgoing_messages: int | None = None,
max_inflight_messages: int | None = None,
max_concurrent_outgoing_calls: int | None = None,
properties: mqtt.Properties | None = None,
properties: Properties | None = None,
tls_context: ssl.SSLContext | None = None,
tls_params: TLSParameters | None = None,
tls_insecure: bool | None = None,
Expand All @@ -220,7 +225,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915

# Pending subscribe, unsubscribe, and publish calls
self._pending_subscribes: dict[
int, asyncio.Future[tuple[int] | list[mqtt.ReasonCodes]]
int, asyncio.Future[tuple[int, ...] | list[ReasonCode]]
] = {}
self._pending_unsubscribes: dict[int, asyncio.Event] = {}
self._pending_publishes: dict[int, asyncio.Event] = {}
Expand All @@ -247,7 +252,8 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915

# Create the underlying paho-mqtt client instance
self._client: mqtt.Client = mqtt.Client(
client_id=identifier,
callback_api_version=CallbackAPIVersion.VERSION1,
client_id=identifier, # type: ignore[arg-type]
protocol=protocol,
clean_session=clean_session,
transport=transport,
Expand Down Expand Up @@ -322,7 +328,7 @@ def identifier(self) -> str:
Note that paho-mqtt stores the client ID as `bytes` internally. We assume that
the client ID is a UTF8-encoded string and decode it first.
"""
return cast(bytes, self._client._client_id).decode() # type: ignore[attr-defined] # noqa: SLF001
return self._client._client_id.decode() # noqa: SLF001

@property
def _pending_calls(self) -> Generator[int, None, None]:
Expand All @@ -337,12 +343,12 @@ async def subscribe( # noqa: PLR0913
/,
topic: SubscribeTopic,
qos: int = 0,
options: mqtt.SubscribeOptions | None = None,
properties: mqtt.Properties | None = None,
options: SubscribeOptions | None = None,
properties: Properties | None = None,
*args: Any,
timeout: float | None = None,
**kwargs: Any,
) -> tuple[int] | list[mqtt.ReasonCodes]:
) -> tuple[int, ...] | list[ReasonCode]:
"""Subscribe to a topic or wildcard.

Args:
Expand All @@ -362,11 +368,11 @@ async def subscribe( # noqa: PLR0913
topic, qos, options, properties, *args, **kwargs
)
# Early out on error
if result != mqtt.MQTT_ERR_SUCCESS:
if result != mqtt.MQTT_ERR_SUCCESS or mid is None:
raise MqttCodeError(result, "Could not subscribe to topic")
# Create future for when the on_subscribe callback is called
callback_result: asyncio.Future[
tuple[int] | list[mqtt.ReasonCodes]
tuple[int, ...] | list[ReasonCode]
] = asyncio.Future()
with self._pending_call(mid, callback_result, self._pending_subscribes):
# Wait for callback_result
Expand All @@ -377,7 +383,7 @@ async def unsubscribe(
self,
/,
topic: str | list[str],
properties: mqtt.Properties | None = None,
properties: Properties | None = None,
*args: Any,
timeout: float | None = None,
**kwargs: Any,
Expand All @@ -394,9 +400,9 @@ async def unsubscribe(
**kwargs: Additional keyword arguments to pass to paho-mqtt's unsubscribe
method.
"""
result, mid = self._client.unsubscribe(topic, properties, *args, **kwargs)
result, mid = self._client.unsubscribe(topic, properties, *args, **kwargs) # type: ignore[arg-type]
# Early out on error
if result != mqtt.MQTT_ERR_SUCCESS:
if result != mqtt.MQTT_ERR_SUCCESS or mid is None:
raise MqttCodeError(result, "Could not unsubscribe from topic")
# Create event for when the on_unsubscribe callback is called
confirmation = asyncio.Event()
Expand All @@ -412,7 +418,7 @@ async def publish( # noqa: PLR0913
payload: PayloadType = None,
qos: int = 0,
retain: bool = False,
properties: mqtt.Properties | None = None,
properties: Properties | None = None,
*args: Any,
timeout: float | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -518,8 +524,8 @@ def _on_connect( # noqa: PLR0913
client: mqtt.Client,
userdata: Any,
flags: dict[str, int],
rc: int | mqtt.ReasonCodes,
properties: mqtt.Properties | None = None,
rc: int | ReasonCode,
properties: Properties | None = None,
) -> None:
"""Called when we receive a CONNACK message from the broker."""
# Return early if already connected. Sometimes, paho-mqtt calls _on_connect
Expand All @@ -538,8 +544,8 @@ def _on_disconnect(
self,
client: mqtt.Client,
userdata: Any,
rc: int | mqtt.ReasonCodes | None,
properties: mqtt.Properties | None = None,
rc: int | ReasonCode | None,
properties: Properties | None = None,
) -> None:
# Return early if the disconnect is already acknowledged.
# Sometimes (e.g., due to timeouts), paho-mqtt calls _on_disconnect
Expand Down Expand Up @@ -570,8 +576,8 @@ def _on_subscribe( # noqa: PLR0913
client: mqtt.Client,
userdata: Any,
mid: int,
granted_qos: tuple[int] | list[mqtt.ReasonCodes],
properties: mqtt.Properties | None = None,
granted_qos: tuple[int, ...] | list[ReasonCode],
properties: Properties | None = None,
) -> None:
"""Called when we receive a SUBACK message from the broker."""
try:
Expand All @@ -588,8 +594,8 @@ def _on_unsubscribe( # noqa: PLR0913
client: mqtt.Client,
userdata: Any,
mid: int,
properties: mqtt.Properties | None = None,
reason_codes: list[mqtt.ReasonCodes] | mqtt.ReasonCodes | None = None,
properties: Properties | None = None,
reason_codes: list[ReasonCode] | ReasonCode | None = None,
) -> None:
"""Called when we receive an UNSUBACK message from the broker."""
try:
Expand Down
11 changes: 6 additions & 5 deletions aiomqtt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,29 @@
from typing import Any

import paho.mqtt.client as mqtt
from paho.mqtt.reasoncodes import ReasonCode


class MqttError(Exception):
pass


class MqttCodeError(MqttError):
def __init__(self, rc: int | mqtt.ReasonCodes | None, *args: Any) -> None:
def __init__(self, rc: int | ReasonCode | None, *args: Any) -> None:
super().__init__(*args)
self.rc = rc

def __str__(self) -> str:
if isinstance(self.rc, mqtt.ReasonCodes):
if isinstance(self.rc, ReasonCode):
return f"[code:{self.rc.value}] {self.rc!s}"
if isinstance(self.rc, int):
return f"[code:{self.rc}] {mqtt.error_string(self.rc)}"
return f"[code:{self.rc}] {mqtt.error_string(self.rc)}" # type: ignore[arg-type]
return f"[code:{self.rc}] {super().__str__()}"


class MqttConnectError(MqttCodeError):
def __init__(self, rc: int | mqtt.ReasonCodes) -> None:
if isinstance(rc, mqtt.ReasonCodes):
def __init__(self, rc: int | ReasonCode) -> None:
if isinstance(rc, ReasonCode):
super().__init__(rc)
return
msg = "Connection refused"
Expand Down
3 changes: 2 additions & 1 deletion aiomqtt/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys

import paho.mqtt.client as mqtt
from paho.mqtt.properties import Properties

if sys.version_info >= (3, 11):
from typing import Self
Expand Down Expand Up @@ -50,7 +51,7 @@ def __init__( # noqa: PLR0913
qos: int,
retain: bool,
mid: int,
properties: mqtt.Properties | None,
properties: Properties | None,
) -> None:
self.topic = Topic(topic) if not isinstance(topic, Topic) else topic
self.payload = payload
Expand Down
6 changes: 3 additions & 3 deletions aiomqtt/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
from typing import Any, Callable, TypeVar

import paho.mqtt.client as mqtt
from paho.mqtt.subscribeoptions import SubscribeOptions

if sys.version_info >= (3, 10):
from typing import ParamSpec, TypeAlias
Expand All @@ -18,10 +18,10 @@
P = ParamSpec("P")

PayloadType: TypeAlias = "str | bytes | bytearray | int | float | None"
SubscribeTopic: TypeAlias = "str | tuple[str, mqtt.SubscribeOptions] | list[tuple[str, mqtt.SubscribeOptions]] | list[tuple[str, int]]"
SubscribeTopic: TypeAlias = "str | tuple[str, SubscribeOptions] | list[tuple[str, SubscribeOptions]] | list[tuple[str, int]]"
WebSocketHeaders: TypeAlias = (
"dict[str, str] | Callable[[dict[str, str]], dict[str, str]]"
)
_PahoSocket: TypeAlias = "socket.socket | ssl.SSLSocket | mqtt.WebsocketWrapper | Any"
_PahoSocket: TypeAlias = "socket.socket | ssl.SSLSocket | Any"
# See the overloads of `socket.setsockopt` for details.
SocketOption: TypeAlias = "tuple[int, int, int | bytes] | tuple[int, int, None, int]"
Loading
Loading