diff --git a/rclpy/rclpy/lifecycle/publisher.py b/rclpy/rclpy/lifecycle/publisher.py index 7c31868a8..63999fab2 100644 --- a/rclpy/rclpy/lifecycle/publisher.py +++ b/rclpy/rclpy/lifecycle/publisher.py @@ -14,8 +14,8 @@ from typing import Union -from rclpy.publisher import MsgType from rclpy.publisher import Publisher +from rclpy.type_support import MsgT from .managed_entity import SimpleManagedEntity @@ -28,7 +28,7 @@ def __init__(self, *args, **kwargs): Publisher.__init__(self, *args, **kwargs) @SimpleManagedEntity.when_enabled - def publish(self, msg: Union[MsgType, bytes]) -> None: + def publish(self, msg: Union[MsgT, bytes]) -> None: """ Publish a message if the lifecycle publisher is enabled. diff --git a/rclpy/rclpy/node.py b/rclpy/rclpy/node.py index ec5576435..e21f523b6 100644 --- a/rclpy/rclpy/node.py +++ b/rclpy/rclpy/node.py @@ -83,6 +83,7 @@ from rclpy.type_description_service import TypeDescriptionService from rclpy.type_support import check_is_valid_msg_type from rclpy.type_support import check_is_valid_srv_type +from rclpy.type_support import MsgT from rclpy.utilities import get_default_context from rclpy.validate_full_topic_name import validate_full_topic_name from rclpy.validate_namespace import validate_namespace @@ -93,8 +94,9 @@ HIDDEN_NODE_PREFIX = '_' -# Used for documentation purposes only +# Left to support Legacy TypeVar. MsgType = TypeVar('MsgType') + SrvType = TypeVar('SrvType') SrvTypeRequest = TypeVar('SrvTypeRequest') SrvTypeResponse = TypeVar('SrvTypeResponse') @@ -1499,7 +1501,7 @@ def resolve_service_name( def create_publisher( self, - msg_type, + msg_type: Type[MsgT], topic: str, qos_profile: Union[QoSProfile, int], *, @@ -1507,7 +1509,7 @@ def create_publisher( event_callbacks: Optional[PublisherEventCallbacks] = None, qos_overriding_options: Optional[QoSOverridingOptions] = None, publisher_class: Type[Publisher] = Publisher, - ) -> Publisher: + ) -> Publisher[MsgT]: """ Create a new publisher. @@ -1573,16 +1575,16 @@ def create_publisher( def create_subscription( self, - msg_type, + msg_type: Type[MsgT], topic: str, - callback: Callable[[MsgType], None], + callback: Callable[[MsgT], None], qos_profile: Union[QoSProfile, int], *, callback_group: Optional[CallbackGroup] = None, event_callbacks: Optional[SubscriptionEventCallbacks] = None, qos_overriding_options: Optional[QoSOverridingOptions] = None, raw: bool = False - ) -> Subscription: + ) -> Subscription[MsgT]: """ Create a new subscription. diff --git a/rclpy/rclpy/publisher.py b/rclpy/rclpy/publisher.py index 2570461fd..d1a14ee7e 100644 --- a/rclpy/rclpy/publisher.py +++ b/rclpy/rclpy/publisher.py @@ -12,24 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TypeVar, Union +from typing import Generic, List, Type, TypeVar, Union from rclpy.callback_groups import CallbackGroup from rclpy.duration import Duration -from rclpy.event_handler import EventHandler -from rclpy.event_handler import PublisherEventCallbacks +from rclpy.event_handler import EventHandler, PublisherEventCallbacks from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy from rclpy.qos import QoSProfile +from rclpy.type_support import MsgT +# Left to support Legacy TypeVars. MsgType = TypeVar('MsgType') -class Publisher: +class Publisher(Generic[MsgT]): def __init__( self, publisher_impl: _rclpy.Publisher, - msg_type: MsgType, + msg_type: Type[MsgT], topic: str, qos_profile: QoSProfile, event_callbacks: PublisherEventCallbacks, @@ -54,10 +55,10 @@ def __init__( self.topic = topic self.qos_profile = qos_profile - self.event_handlers: EventHandler = event_callbacks.create_event_handlers( + self.event_handlers: List[EventHandler] = event_callbacks.create_event_handlers( callback_group, publisher_impl, topic) - def publish(self, msg: Union[MsgType, bytes]) -> None: + def publish(self, msg: Union[MsgT, bytes]) -> None: """ Send a message to the topic for the publisher. diff --git a/rclpy/rclpy/serialization.py b/rclpy/rclpy/serialization.py index 2c835c8a7..7b3a8fc2e 100644 --- a/rclpy/rclpy/serialization.py +++ b/rclpy/rclpy/serialization.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Type from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy -from rclpy.type_support import check_for_type_support +from rclpy.type_support import check_for_type_support, Msg, MsgT -def serialize_message(message) -> bytes: +def serialize_message(message: Msg) -> bytes: """ Serialize a ROS message. @@ -29,7 +30,7 @@ def serialize_message(message) -> bytes: return _rclpy.rclpy_serialize(message, message_type) -def deserialize_message(serialized_message: bytes, message_type): +def deserialize_message(serialized_message: bytes, message_type: Type[MsgT]) -> MsgT: """ Deserialize a ROS message. diff --git a/rclpy/rclpy/subscription.py b/rclpy/rclpy/subscription.py index 9f993a147..31eecdd67 100644 --- a/rclpy/rclpy/subscription.py +++ b/rclpy/rclpy/subscription.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. + from enum import Enum import inspect -from typing import Callable -from typing import TypeVar +from typing import Callable, Generic, List, Type, TypeVar from rclpy.callback_groups import CallbackGroup -from rclpy.event_handler import EventHandler -from rclpy.event_handler import SubscriptionEventCallbacks +from rclpy.event_handler import EventHandler, SubscriptionEventCallbacks from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy from rclpy.qos import QoSProfile +from rclpy.type_support import MsgT -# For documentation only +# Left to support Legacy TypeVars. MsgType = TypeVar('MsgType') -class Subscription: +class Subscription(Generic[MsgT]): class CallbackType(Enum): MessageOnly = 0 @@ -37,9 +37,9 @@ class CallbackType(Enum): def __init__( self, subscription_impl: _rclpy.Subscription, - msg_type: MsgType, + msg_type: Type[MsgT], topic: str, - callback: Callable, + callback: Callable[[MsgT], None], callback_group: CallbackGroup, qos_profile: QoSProfile, raw: bool, @@ -73,7 +73,7 @@ def __init__( self.qos_profile = qos_profile self.raw = raw - self.event_handlers: EventHandler = event_callbacks.create_event_handlers( + self.event_handlers: List[EventHandler] = event_callbacks.create_event_handlers( callback_group, subscription_impl, topic) def get_publisher_count(self) -> int: @@ -102,11 +102,11 @@ def topic_name(self): return self.__subscription.get_topic_name() @property - def callback(self): + def callback(self) -> Callable[[MsgT], None]: return self._callback @callback.setter - def callback(self, value): + def callback(self, value: Callable[[MsgT], None]) -> None: self._callback = value self._callback_type = Subscription.CallbackType.MessageOnly try: diff --git a/rclpy/rclpy/type_support.py b/rclpy/rclpy/type_support.py index 9deddf5d8..29a4e1ae2 100644 --- a/rclpy/rclpy/type_support.py +++ b/rclpy/rclpy/type_support.py @@ -12,12 +12,64 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Protocol, Type, TypeVar, Union + from rclpy.exceptions import NoTypeSupportImportedException -def check_for_type_support(msg_or_srv_type): +class PyCapsule(Protocol): + """Alias for PyCapsule Pybind object.""" + + pass + + +# Done because metaclasses need to inherit from type +ProtocolType: Type = type(Protocol) + + +class CommonMsgSrvMetaClass(ProtocolType): + """Shared attributes between messages and services.""" + + _TYPE_SUPPORT: Optional[PyCapsule] + + @classmethod + def __import_type_support__(cls) -> None: + ... + + +class MsgMetaClass(CommonMsgSrvMetaClass): + """Generic Message Metaclass Alias.""" + + _CREATE_ROS_MESSAGE: Optional[PyCapsule] + _CONVERT_FROM_PY: Optional[PyCapsule] + _CONVERT_TO_PY: Optional[PyCapsule] + _DESTROY_ROS_MESSAGE: Optional[PyCapsule] + + +class Msg(Protocol, metaclass=MsgMetaClass): + """Generic Message Type Alias.""" + + pass + + +# Could likely be improved if generic across Request, Response, Event +class Srv(Protocol, metaclass=CommonMsgSrvMetaClass): + """Generic Service Type Alias.""" + + pass + + +MsgT = TypeVar('MsgT', bound=Msg) +SrvT = TypeVar('SrvT', bound=Srv) + +SrvRequestT = TypeVar('SrvRequestT', bound=Msg) +SrvResponseT = TypeVar('SrvResponseT', bound=Msg) +SrvEventT = TypeVar('SrvEventT', bound=Msg) + + +def check_for_type_support(msg_or_srv_type: Type[Union[Msg, Srv]]) -> None: try: - ts = msg_or_srv_type.__class__._TYPE_SUPPORT + ts = msg_or_srv_type._TYPE_SUPPORT except AttributeError as e: e.args = ( e.args[0] + @@ -26,19 +78,19 @@ def check_for_type_support(msg_or_srv_type): *e.args[1:]) raise if ts is None: - msg_or_srv_type.__class__.__import_type_support__() - if msg_or_srv_type.__class__._TYPE_SUPPORT is None: + msg_or_srv_type.__import_type_support__() + if msg_or_srv_type._TYPE_SUPPORT is None: raise NoTypeSupportImportedException() -def check_is_valid_msg_type(msg_type): +def check_is_valid_msg_type(msg_type: Type[Msg]) -> None: check_for_type_support(msg_type) try: assert None not in ( - msg_type.__class__._CREATE_ROS_MESSAGE, - msg_type.__class__._CONVERT_FROM_PY, - msg_type.__class__._CONVERT_TO_PY, - msg_type.__class__._DESTROY_ROS_MESSAGE, + msg_type._CREATE_ROS_MESSAGE, + msg_type._CONVERT_FROM_PY, + msg_type._CONVERT_TO_PY, + msg_type._DESTROY_ROS_MESSAGE, ) except (AssertionError, AttributeError): raise RuntimeError( @@ -47,7 +99,7 @@ def check_is_valid_msg_type(msg_type): ) from None -def check_is_valid_srv_type(srv_type): +def check_is_valid_srv_type(srv_type: Type[Srv]) -> None: check_for_type_support(srv_type) try: assert None not in (