From 7bc00e66d9884cd1befbe109afa671c4b430090e Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Fri, 3 May 2024 15:52:08 +0200 Subject: [PATCH] Refactor --- posttroll/__init__.py | 4 +- posttroll/address_receiver.py | 28 +- posttroll/backends/zmq/__init__.py | 55 +++- posttroll/backends/zmq/address_receiver.py | 24 +- posttroll/backends/zmq/message_broadcaster.py | 21 +- posttroll/backends/zmq/ns.py | 108 ++++---- posttroll/backends/zmq/publisher.py | 97 +------ posttroll/backends/zmq/socket.py | 106 ++++++++ posttroll/backends/zmq/subscriber.py | 120 +++------ posttroll/bbmcast.py | 10 +- posttroll/message_broadcaster.py | 15 +- posttroll/ns.py | 17 +- posttroll/publisher.py | 15 +- posttroll/subscriber.py | 18 +- posttroll/tests/test_nameserver.py | 253 ++++++++++++++++++ posttroll/tests/test_pubsub.py | 219 +-------------- posttroll/tests/test_secure_zmq_backend.py | 67 +++-- 17 files changed, 658 insertions(+), 519 deletions(-) create mode 100644 posttroll/backends/zmq/socket.py create mode 100644 posttroll/tests/test_nameserver.py diff --git a/posttroll/__init__.py b/posttroll/__init__.py index aece644..46b0eaf 100644 --- a/posttroll/__init__.py +++ b/posttroll/__init__.py @@ -30,7 +30,7 @@ from donfig import Config -config = Config("posttroll") +config = Config("posttroll", defaults=[dict(backend="unsecure_zmq")]) # context = {} logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def get_context(): This function takes care of creating new contexts in case of forks. """ - backend = config.get("backend", "unsecure_zmq") + backend = config["backend"] if "zmq" in backend: from posttroll.backends.zmq import get_context return get_context() diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index 2f0a24d..d2cef04 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -37,16 +37,16 @@ import netifaces from posttroll import config -from posttroll.bbmcast import MulticastReceiver, SocketTimeout +from posttroll.bbmcast import MulticastReceiver, SocketTimeout, get_configured_broadcast_port from posttroll.message import Message from posttroll.publisher import Publish +from zmq import ZMQError __all__ = ("AddressReceiver", "getaddress") LOGGER = logging.getLogger(__name__) debug = os.environ.get("DEBUG", False) -broadcast_port = 21200 DEFAULT_ADDRESS_PUBLISH_PORT = 16543 @@ -144,13 +144,13 @@ def _check_age(self, pub, min_interval=zero_seconds): msg = Message("/address/" + metadata["name"], "info", mda) to_del.append(addr) LOGGER.info(f"publish remove '{msg}'") - pub.send(msg.encode()) + pub.send(str(msg.encode())) for addr in to_del: del self._addresses[addr] def _run(self): """Run the receiver.""" - port = broadcast_port + port = get_configured_broadcast_port() nameservers, recv = self.set_up_address_receiver(port) self._is_running = True @@ -159,7 +159,16 @@ def _run(self): try: while self._do_run: try: - data, fromaddr = recv() + rerun = True + while rerun: + try: + data, fromaddr = recv() + rerun = False + except TimeoutError: + if self._do_run: + continue + else: + raise if self._multicast_enabled: ip_, port = fromaddr if self._restrict_to_localhost and ip_ not in self._local_ips: @@ -171,6 +180,8 @@ def _run(self): if self._multicast_enabled: LOGGER.debug("Multicast socket timed out on recv!") continue + except ZMQError: + return finally: self._check_age(pub, min_interval=self._max_age / 20) if self._do_heartbeat: @@ -216,9 +227,10 @@ def set_up_address_receiver(self, port): break else: - if config.get("backend", "unsecure_zmq") == "unsecure_zmq": - from posttroll.backends.zmq.address_receiver import SimpleReceiver - recv = SimpleReceiver(port) + if config["backend"] not in ["unsecure_zmq", "secure_zmq"]: + raise NotImplementedError + from posttroll.backends.zmq.address_receiver import SimpleReceiver + recv = SimpleReceiver(port, timeout=2) nameservers = ["localhost"] return nameservers,recv diff --git a/posttroll/backends/zmq/__init__.py b/posttroll/backends/zmq/__init__.py index 17a60f9..2cd6597 100644 --- a/posttroll/backends/zmq/__init__.py +++ b/posttroll/backends/zmq/__init__.py @@ -5,6 +5,7 @@ import zmq from posttroll import config +from posttroll.message import Message logger = logging.getLogger(__name__) context = {} @@ -21,14 +22,54 @@ def get_context(): logger.debug("renewed context for PID %d", pid) return context[pid] +def destroy_context(linger=None): + pid = os.getpid() + context.pop(pid).destroy(linger) + def _set_tcp_keepalive(socket): """Set the tcp keepalive parameters on *socket*.""" - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE, config.get("tcp_keepalive", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_CNT, config.get("tcp_keepalive_cnt", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_IDLE, config.get("tcp_keepalive_idle", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_INTVL, config.get("tcp_keepalive_intvl", None)) + keepalive_options = get_tcp_keepalive_options() + for param, value in keepalive_options.items(): + socket.setsockopt(param, value) + +def get_tcp_keepalive_options(): + """Get the tcp_keepalive options from config.""" + keepalive_options = dict() + for opt in ("tcp_keepalive", + "tcp_keepalive_cnt", + "tcp_keepalive_idle", + "tcp_keepalive_intvl"): + try: + value = int(config[opt]) + except (KeyError, TypeError): + continue + param = getattr(zmq, opt.upper()) + keepalive_options[param] = value + return keepalive_options + + +class SocketReceiver: + + def __init__(self): + self._poller = zmq.Poller() + + def register(self, socket): + """Register the socket.""" + self._poller.register(socket, zmq.POLLIN) + def unregister(self, socket): + """Unregister the socket.""" + self._poller.unregister(socket) -def _set_int_sockopt(socket, param, value): - if value is not None: - socket.setsockopt(param, int(value)) + def receive(self, *sockets, timeout=None): + """Timeout is in seconds.""" + if timeout: + timeout *= 1000 + socks = dict(self._poller.poll(timeout=timeout)) + if socks: + for sock in sockets: + if socks.get(sock) == zmq.POLLIN: + received = sock.recv_string(zmq.NOBLOCK) + yield Message.decode(received), sock + else: + raise TimeoutError("Did not receive anything on sockets.") diff --git a/posttroll/backends/zmq/address_receiver.py b/posttroll/backends/zmq/address_receiver.py index 8eb22f6..f926747 100644 --- a/posttroll/backends/zmq/address_receiver.py +++ b/posttroll/backends/zmq/address_receiver.py @@ -3,25 +3,35 @@ from zmq import LINGER, REP from posttroll.address_receiver import get_configured_address_port -from posttroll.backends.zmq import get_context +from posttroll.backends.zmq.socket import set_up_server_socket class SimpleReceiver(object): """Simple listing on port for address messages.""" - def __init__(self, port=None): + def __init__(self, port=None, timeout=2): """Set up the receiver.""" self._port = port or get_configured_address_port() - self._socket = get_context().socket(REP) - self._socket.bind("tcp://*:" + str(port)) + address = "tcp://*:" + str(port) + self._socket, _, self._authenticator = set_up_server_socket(REP, address) + self._running = True + self.timeout = timeout def __call__(self): """Receive a message.""" - message = self._socket.recv_string() - self._socket.send_string("ok") - return message, None + while self._running: + try: + message = self._socket.recv_string(self.timeout) + except TimeoutError: + continue + else: + self._socket.send_string("ok") + return message, None def close(self): """Close the receiver.""" + self._running = False self._socket.setsockopt(LINGER, 1) self._socket.close() + if self._authenticator: + self._authenticator.stop() diff --git a/posttroll/backends/zmq/message_broadcaster.py b/posttroll/backends/zmq/message_broadcaster.py index 060e9ae..fe2ddfe 100644 --- a/posttroll/backends/zmq/message_broadcaster.py +++ b/posttroll/backends/zmq/message_broadcaster.py @@ -3,20 +3,19 @@ import logging import threading +from posttroll.backends.zmq.socket import set_up_client_socket from zmq import LINGER, NOBLOCK, REQ, ZMQError -from posttroll.backends.zmq import get_context logger = logging.getLogger(__name__) -class UnsecureZMQDesignatedReceiversSender: +class ZMQDesignatedReceiversSender: """Sends message to multiple *receivers* on *port*.""" def __init__(self, default_port, receivers): """Set up the sender.""" self.default_port = default_port - self.receivers = receivers self._shutdown_event = threading.Event() @@ -28,13 +27,14 @@ def __call__(self, data): def _send_to_address(self, address, data, timeout=10): """Send data to *address* and *port* without verification of response.""" # Socket to talk to server - socket = get_context().socket(REQ) + if address.find(":") == -1: + full_address = "tcp://%s:%d" % (address, self.default_port) + else: + full_address = "tcp://%s" % address + options = {LINGER: int(timeout * 1000)} + socket = set_up_client_socket(REQ, full_address, options) try: - socket.setsockopt(LINGER, timeout * 1000) - if address.find(":") == -1: - socket.connect("tcp://%s:%d" % (address, self.default_port)) - else: - socket.connect("tcp://%s" % address) + socket.send_string(data) while not self._shutdown_event.is_set(): try: @@ -43,10 +43,11 @@ def _send_to_address(self, address, data, timeout=10): self._shutdown_event.wait(.1) continue if message != "ok": - logger.warn("invalid acknowledge received: %s" % message) + logger.warning("invalid acknowledge received: %s" % message) break finally: + socket.setsockopt(LINGER, 1) socket.close() def close(self): diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py index 4f7214c..3325272 100644 --- a/posttroll/backends/zmq/ns.py +++ b/posttroll/backends/zmq/ns.py @@ -1,11 +1,13 @@ """ZMQ implexentation of ns.""" import logging +from contextlib import suppress from threading import Lock -from zmq import LINGER, NOBLOCK, POLLIN, REP, REQ, Poller +from posttroll.backends.zmq.socket import set_up_client_socket, set_up_server_socket +from zmq import LINGER, REP, REQ +from posttroll.backends.zmq import SocketReceiver -from posttroll.backends.zmq import get_context from posttroll.message import Message from posttroll.ns import get_active_address, get_configured_nameserver_port @@ -14,78 +16,94 @@ nslock = Lock() -def unsecure_zmq_get_pub_address(name, timeout=10, nameserver="localhost"): +def zmq_get_pub_address(name, timeout=10, nameserver="localhost"): """Get the address of the publisher. For a given publisher *name* from the nameserver on *nameserver* (localhost by default). """ + nameserver_address = create_nameserver_address(nameserver) # Socket to talk to server - socket = get_context().socket(REQ) + logger.debug(f"Connecting to {nameserver_address}") + socket = create_req_socket(timeout, nameserver_address) + return _fetch_address_using_socket(socket, name, timeout) + + +def create_nameserver_address(nameserver): + port = get_configured_nameserver_port() + nameserver_address = "tcp://" + nameserver + ":" + str(port) + return nameserver_address + + +def _fetch_address_using_socket(socket, name, timeout): try: - port = get_configured_nameserver_port() - socket.setsockopt(LINGER, int(timeout * 1000)) - socket.connect("tcp://" + nameserver + ":" + str(port)) - logger.debug("Connecting to %s", - "tcp://" + nameserver + ":" + str(port)) - poller = Poller() - poller.register(socket, POLLIN) + socket_receiver = SocketReceiver() + socket_receiver.register(socket) message = Message("/oper/ns", "request", {"service": name}) socket.send_string(str(message)) # Get the reply. - sock = poller.poll(timeout=timeout * 1000) - if sock: - if sock[0][0] == socket: - message = Message.decode(socket.recv_string(NOBLOCK)) - return message.data - else: - raise TimeoutError("Didn't get an address after %d seconds." - % timeout) + #socket.poll(timeout) + #message = socket.recv(timeout) + for message, _ in socket_receiver.receive(socket, timeout=timeout): + return message.data + except TimeoutError: + raise TimeoutError("Didn't get an address after %d seconds." + % timeout) finally: + socket_receiver.unregister(socket) + socket.setsockopt(LINGER, 1) socket.close() +def create_req_socket(timeout, nameserver_address): + options = {LINGER: int(timeout * 1000)} + socket = set_up_client_socket(REQ, nameserver_address, options) + return socket -class UnsecureZMQNameServer: +class ZMQNameServer: """The name server.""" def __init__(self): """Set up the nameserver.""" - self.loop = True + self.running = True self.listener = None - def run(self, arec): + def run(self, address_receiver): """Run the listener and answer to requests.""" port = get_configured_nameserver_port() try: - with nslock: - self.listener = get_context().socket(REP) - self.listener.bind("tcp://*:" + str(port)) - logger.debug(f"Nameserver listening on port {port}") - poller = Poller() - poller.register(self.listener, POLLIN) - while self.loop: - with nslock: - socks = dict(poller.poll(1000)) - if socks: - if socks.get(self.listener) == POLLIN: - msg = self.listener.recv_string() - else: - continue - logger.debug("Replying to request: " + str(msg)) - msg = Message.decode(msg) - active_address = get_active_address(msg.data["service"], arec) - self.listener.send_unicode(str(active_address)) + # stop was called before we could start running, exit + if not self.running: + return + address = "tcp://*:" + str(port) + self.listener, _, self._authenticator = set_up_server_socket(REP, address) + logger.debug(f"Nameserver listening on port {port}") + socket_receiver = SocketReceiver() + socket_receiver.register(self.listener) + while self.running: + try: + for msg, _ in socket_receiver.receive(self.listener, timeout=1): + logger.debug("Replying to request: " + str(msg)) + active_address = get_active_address(msg.data["service"], address_receiver) + self.listener.send_unicode(str(active_address)) + except TimeoutError: + continue except KeyboardInterrupt: # Needed to stop the nameserver. pass finally: - self.stop() + socket_receiver.unregister(self.listener) + self.close_sockets_and_threads() + + def close_sockets_and_threads(self): + with suppress(AttributeError): + self.listener.setsockopt(LINGER, 1) + self.listener.close() + with suppress(AttributeError): + self._authenticator.stop() + def stop(self): """Stop the name server.""" - self.listener.setsockopt(LINGER, 1) - self.loop = False - with nslock: - self.listener.close() + self.running = False diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index ffbf2d6..37a4898 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -1,18 +1,18 @@ """ZMQ implementation of the publisher.""" +from contextlib import suppress import logging from threading import Lock -from urllib.parse import urlsplit, urlunsplit +from posttroll.backends.zmq.socket import set_up_server_socket import zmq -from zmq.auth.thread import ThreadAuthenticator -from posttroll.backends.zmq import _set_tcp_keepalive, get_context +from posttroll.backends.zmq import get_tcp_keepalive_options LOGGER = logging.getLogger(__name__) -class UnsecureZMQPublisher: +class ZMQPublisher: """Unsecure ZMQ implementation of the publisher class.""" def __init__(self, address, name="", min_port=None, max_port=None): @@ -32,33 +32,20 @@ def __init__(self, address, name="", min_port=None, max_port=None): self.max_port = max_port self.port_number = None self._pub_lock = Lock() + self._authenticator = None def start(self): """Start the publisher.""" - self.publish_socket = get_context().socket(zmq.PUB) - _set_tcp_keepalive(self.publish_socket) - - self._bind() + self._create_socket() LOGGER.info(f"Publisher for {self.destination} started on port {self.port_number}.") + return self - def _bind(self): - # Check for port 0 (random port) - u__ = urlsplit(self.destination) - port = u__.port - if port == 0: - dest = urlunsplit((u__.scheme, u__.hostname, - u__.path, u__.query, u__.fragment)) - self.port_number = self.publish_socket.bind_to_random_port( - dest, - min_port=self.min_port, - max_port=self.max_port) - netloc = u__.hostname + ":" + str(self.port_number) - self.destination = urlunsplit((u__.scheme, netloc, u__.path, - u__.query, u__.fragment)) - else: - self.publish_socket.bind(self.destination) - self.port_number = port + def _create_socket(self): + options = get_tcp_keepalive_options() + self.publish_socket, port, self._authenticator = set_up_server_socket(zmq.PUB, self.destination, options, + (self.min_port, self.max_port)) + self.port_number = port def send(self, msg): """Send the given message.""" @@ -69,61 +56,5 @@ def stop(self): """Stop the publisher.""" self.publish_socket.setsockopt(zmq.LINGER, 1) self.publish_socket.close() - - -class SecureZMQPublisher(UnsecureZMQPublisher): - """Secure ZMQ implementation of the publisher class.""" - - def __init__(self, *args, server_secret_key=None, public_keys_directory=None, authorized_sub_addresses=None, **kwargs): # noqa - """Set up the secure ZMQ publisher. - - Args: - address: the address to connect to. - server_secret_key: the secret key for this publisher. - public_keys_directory: the directory containing the public keys of the subscribers that are allowed to - connect. - authorized_sub_addresses: the list of addresse allowed to subscibe to this publisher. By default, all are - allowed. - kwargs: passed to the underlying UnsecureZMQPublisher instance. - - """ - if server_secret_key is None: - raise TypeError("Missing server_secret_key argument.") - if public_keys_directory is None: - raise TypeError("Missing public_keys_directory argument.") - self._server_secret_key = server_secret_key - self._authorized_sub_addresses = authorized_sub_addresses or [] - self._pub_keys_dir = public_keys_directory - self._authenticator = None - - super().__init__(*args, **kwargs) - - def start(self): - """Start the publisher.""" - ctx = get_context() - - # Start an authenticator for this context. - auth = ThreadAuthenticator(ctx) - auth.start() - auth.allow(*self._authorized_sub_addresses) - # Tell authenticator to use the certificate in a directory - auth.configure_curve(domain="*", location=self._pub_keys_dir) - self._authenticator = auth - - self.publish_socket = ctx.socket(zmq.PUB) - - server_public, server_secret =zmq.auth.load_certificate(self._server_secret_key) - self.publish_socket.curve_secretkey = server_secret - self.publish_socket.curve_publickey = server_public - self.publish_socket.curve_server = True - - _set_tcp_keepalive(self.publish_socket) - - self._bind() - LOGGER.info(f"Publisher for {self.destination} started on port {self.port_number}.") - return self - - def stop(self): - """Stop the publisher.""" - super().stop() - self._authenticator.stop() + with suppress(AttributeError): + self._authenticator.stop() diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py new file mode 100644 index 0000000..16ae3da --- /dev/null +++ b/posttroll/backends/zmq/socket.py @@ -0,0 +1,106 @@ +from posttroll import get_context, config +import zmq +from zmq.auth.thread import ThreadAuthenticator +from urllib.parse import urlsplit, urlunsplit + + + +def set_up_client_socket(socket_type, address, options=None): + backend = config["backend"] + if backend == "unsecure_zmq": + sock = create_unsecure_client_socket(socket_type) + elif backend == "secure_zmq": + sock = create_secure_client_socket(socket_type) + add_options(sock, options) + sock.connect(address) + return sock + + +def create_unsecure_client_socket(socket_type): + return get_context().socket(socket_type) + + +def add_options(sock, options=None): + if not options: + return + for param, val in options.items(): + sock.setsockopt(param, val) + + +def create_secure_client_socket(socket_type): + subscriber = get_context().socket(socket_type) + + client_secret_key_file = config["client_secret_key_file"] + server_public_key_file = config["server_public_key_file"] + client_public, client_secret = zmq.auth.load_certificate(client_secret_key_file) + subscriber.curve_secretkey = client_secret + subscriber.curve_publickey = client_public + + server_public, _ = zmq.auth.load_certificate(server_public_key_file) + # The client must know the server's public key to make a CURVE connection. + subscriber.curve_serverkey = server_public + return subscriber + + +def set_up_server_socket(socket_type, destination, options=None, port_interval=(None, None)): + if options is None: + options = {} + backend = config["backend"] + if backend == "unsecure_zmq": + sock = create_unsecure_server_socket(socket_type) + authenticator = None + elif backend == "secure_zmq": + sock, authenticator = create_secure_server_socket(socket_type) + + add_options(sock, options) + + port = bind(sock, destination, port_interval) + return sock, port, authenticator + + +def create_unsecure_server_socket(socket_type): + return get_context().socket(socket_type) + + +def bind(sock, destination, port_interval): + # Check for port 0 (random port) + min_port, max_port = port_interval + u__ = urlsplit(destination) + port = u__.port + if port == 0: + dest = urlunsplit((u__.scheme, u__.hostname, + u__.path, u__.query, u__.fragment)) + port_number = sock.bind_to_random_port(dest, + min_port=min_port, + max_port=max_port) + netloc = u__.hostname + ":" + str(port_number) + destination = urlunsplit((u__.scheme, netloc, u__.path, + u__.query, u__.fragment)) + else: + sock.bind(destination) + port_number = port + return port_number + + +def create_secure_server_socket(socket_type): + server_secret_key = config["server_secret_key_file"] + clients_public_keys_directory = config["clients_public_keys_directory"] + authorized_sub_addresses = config.get("authorized_client_addresses", []) + + ctx = get_context() + + # Start an authenticator for this context. + authenticator_thread = ThreadAuthenticator(ctx) + authenticator_thread.start() + authenticator_thread.allow(*authorized_sub_addresses) + # Tell authenticator to use the certificate in a directory + authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory) + + + server_socket = ctx.socket(socket_type) + + server_public, server_secret =zmq.auth.load_certificate(server_secret_key) + server_socket.curve_secretkey = server_secret + server_socket.curve_publickey = server_public + server_socket.curve_server = True + return server_socket, authenticator_thread diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index 5b04e4d..8186f69 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -5,15 +5,15 @@ from time import sleep from urllib.parse import urlsplit -from zmq import LINGER, NOBLOCK, POLLIN, PULL, SUB, SUBSCRIBE, Poller, ZMQError +from zmq import LINGER, PULL, SUB, SUBSCRIBE, ZMQError +from posttroll.backends.zmq.socket import set_up_client_socket -from posttroll.backends.zmq import _set_tcp_keepalive, get_context -from posttroll.message import Message +from posttroll.backends.zmq import SocketReceiver, get_tcp_keepalive_options LOGGER = logging.getLogger(__name__) -class _ZMQSubscriber: +class ZMQSubscriber: def __init__(self, addresses, topics="", message_filter=None, translate=False): """Initialize the subscriber.""" @@ -27,7 +27,8 @@ def __init__(self, addresses, topics="", message_filter=None, translate=False): self._hooks = [] self._hooks_cb = {} - self.poller = Poller() + #self.poller = Poller() + self._sock_receiver = SocketReceiver() self._lock = Lock() self.update(addresses) @@ -68,8 +69,8 @@ def remove(self, address): self._remove_sub_socket(subscriber) def _remove_sub_socket(self, subscriber): - if self.poller: - self.poller.unregister(subscriber) + if self._sock_receiver: + self._sock_receiver.unregister(subscriber) subscriber.close() def update(self, addresses): @@ -107,10 +108,9 @@ def add_hook_pull(self, address, callback): specified subscription. Good for pushed 'inproc' messages from another thread. """ LOGGER.info("Subscriber adding PULL hook %s", str(address)) - socket = get_context().socket(PULL) - socket.connect(address) - if self.poller: - self.poller.register(socket, POLLIN) + socket = self._create_socket(PULL, address) + if self._sock_receiver: + self._sock_receiver.register(socket) self._add_hook(socket, callback) def _add_hook(self, socket, callback): @@ -131,11 +131,9 @@ def subscribers(self): def recv(self, timeout=None): """Receive, optionally with *timeout* in seconds.""" - if timeout: - timeout *= 1000. for sub in list(self.subscribers) + self._hooks: - self.poller.register(sub, POLLIN) + self._sock_receiver.register(sub) self._loop = True try: while self._loop: @@ -143,38 +141,33 @@ def recv(self, timeout=None): yield from self._new_messages(timeout) finally: for sub in list(self.subscribers) + self._hooks: - self.poller.unregister(sub) + self._sock_receiver.unregister(sub) + # self.poller.unregister(sub) def _new_messages(self, timeout): """Check for new messages to yield and pass to the callbacks.""" + all_subs = list(self.subscribers) + self._hooks try: - socks = dict(self.poller.poll(timeout=timeout)) - if socks: - for sub in self.subscribers: - if sub in socks and socks[sub] == POLLIN: - received = sub.recv_string(NOBLOCK) - m__ = Message.decode(received) - if not self._filter or self._filter(m__): - if self._translate: - url = urlsplit(self.sub_addr[sub]) - host = url[1].split(":")[0] - m__.sender = (m__.sender.split("@")[0] - + "@" + host) - yield m__ - - for sub in self._hooks: - if sub in socks and socks[sub] == POLLIN: - m__ = Message.decode(sub.recv_string(NOBLOCK)) - self._hooks_cb[sub](m__) - else: - # timeout - yield None + for m__, sock in self._sock_receiver.receive(*all_subs, timeout=timeout): + if sock in self.subscribers: + if not self._filter or self._filter(m__): + if self._translate: + url = urlsplit(self.sub_addr[sock]) + host = url[1].split(":")[0] + m__.sender = (m__.sender.split("@")[0] + + "@" + host) + yield m__ + elif sock in self._hooks: + self._hooks_cb[sock](m__) + except TimeoutError: + yield None except ZMQError as err: if self._loop: LOGGER.exception("Receive failed: %s", str(err)) + def __call__(self, **kwargs): """Handle calls with class instance.""" return self.recv(**kwargs) @@ -201,54 +194,21 @@ def __del__(self): except Exception: # noqa: E722 pass - -class UnsecureZMQSubscriber(_ZMQSubscriber): - """Unsecure ZMQ implementation of the subscriber.""" - def _add_sub_socket(self, address, topics): - subscriber = get_context().socket(SUB) - _set_tcp_keepalive(subscriber) - for t__ in topics: - subscriber.setsockopt_string(SUBSCRIBE, str(t__)) - subscriber.connect(address) - - if self.poller: - self.poller.register(subscriber, POLLIN) - return subscriber - - -class SecureZMQSubscriber(_ZMQSubscriber): - """Secure ZMQ implementation of the subscriber, using Curve.""" - def __init__(self, *args, client_secret_key_file=None, server_public_key_file=None, **kwargs): - """Initialize the subscriber.""" - if client_secret_key_file is None: - raise TypeError("Missing client_secret_key_file argument.") - if server_public_key_file is None: - raise TypeError("Missing server_public_key_file argument.") - self._client_secret_file = client_secret_key_file - self._server_public_key_file = server_public_key_file - - super().__init__(*args, **kwargs) - - def _add_sub_socket(self, address, topics): - import zmq.auth - subscriber = get_context().socket(SUB) + options = get_tcp_keepalive_options() - client_public, client_secret = zmq.auth.load_certificate(self._client_secret_file) - subscriber.curve_secretkey = client_secret - subscriber.curve_publickey = client_public + subscriber = self._create_socket(SUB, address, options) + add_subscriptions(subscriber, topics) - server_public, _ = zmq.auth.load_certificate(self._server_public_key_file) - # The client must know the server's public key to make a CURVE connection. - subscriber.curve_serverkey = server_public + if self._sock_receiver: + self._sock_receiver.register(subscriber) + return subscriber + def _create_socket(self, socket_type, address, options): + return set_up_client_socket(socket_type, address, options) - _set_tcp_keepalive(subscriber) - for t__ in topics: - subscriber.setsockopt_string(SUBSCRIBE, str(t__)) - subscriber.connect(address) - if self.poller: - self.poller.register(subscriber, POLLIN) - return subscriber +def add_subscriptions(socket, topics): + for t__ in topics: + socket.setsockopt_string(SUBSCRIBE, str(t__)) diff --git a/posttroll/bbmcast.py b/posttroll/bbmcast.py index da759f5..c2cf7b3 100644 --- a/posttroll/bbmcast.py +++ b/posttroll/bbmcast.py @@ -68,6 +68,14 @@ SocketTimeout = timeout # for easy access to socket.timeout +DEFAULT_BROADCAST_PORT = 21200 + +def get_configured_broadcast_port(): + """Get the configured nameserver port.""" + return config.get("broadcast_port", DEFAULT_BROADCAST_PORT) + + + # ----------------------------------------------------------------------------- # # Sender. @@ -139,7 +147,7 @@ def get_mc_group(): # ----------------------------------------------------------------------------- -class MulticastReceiver(object): +class MulticastReceiver: """Multicast receiver on *port* for an *mcgroup*.""" BUFSIZE = 1024 diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index b3e1501..d72dd4c 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -28,23 +28,20 @@ import threading from posttroll import config, message -from posttroll.bbmcast import MulticastSender +from posttroll.bbmcast import MulticastSender, get_configured_broadcast_port __all__ = ("MessageBroadcaster", "AddressBroadcaster", "sendaddress") LOGGER = logging.getLogger(__name__) -broadcast_port = 21200 - - class DesignatedReceiversSender: """Sends message to multiple *receivers* on *port*.""" def __init__(self, default_port, receivers): """Set settings.""" backend = config.get("backend", "unsecure_zmq") if backend == "unsecure_zmq": - from posttroll.backends.zmq.message_broadcaster import UnsecureZMQDesignatedReceiversSender - self._sender = UnsecureZMQDesignatedReceiversSender(default_port, receivers) + from posttroll.backends.zmq.message_broadcaster import ZMQDesignatedReceiversSender + self._sender = ZMQDesignatedReceiversSender(default_port, receivers) def __call__(self, data): """Send messages from all receivers.""" @@ -61,7 +58,7 @@ def close(self): # ---------------------------------------------------------------------------- -class MessageBroadcaster(object): +class MessageBroadcaster: """Class to broadcast stuff. If *interval* is 0 or negative, no broadcasting is done. @@ -135,7 +132,7 @@ def __init__(self, name, address, interval, nameservers): """Set up the Address broadcaster.""" msg = message.Message("/address/%s" % name, "info", {"URI": "%s:%d" % address}).encode() - MessageBroadcaster.__init__(self, msg, broadcast_port, interval, + MessageBroadcaster.__init__(self, msg, get_configured_broadcast_port(), interval, nameservers) @@ -158,7 +155,7 @@ def __init__(self, name, address, data_type, interval=2, nameservers=None): msg = message.Message("/address/%s" % name, "info", {"URI": address, "service": data_type}).encode() - MessageBroadcaster.__init__(self, msg, broadcast_port, interval, + MessageBroadcaster.__init__(self, msg, get_configured_broadcast_port(), interval, nameservers) diff --git a/posttroll/ns.py b/posttroll/ns.py index 9296dd2..0221bf1 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -88,10 +88,10 @@ def get_pub_address(name, timeout=10, nameserver="localhost"): timeout: how long to wait for an address, in seconds. nameserver: nameserver address to query the publishers from (default: localhost). """ - backend = config.get("backend", "unsecure_zmq") - if backend == "unsecure_zmq": - from posttroll.backends.zmq.ns import unsecure_zmq_get_pub_address - return unsecure_zmq_get_pub_address(name, timeout, nameserver) + if config["backend"] not in ["unsecure_zmq", "secure_zmq"]: + raise NotImplementedError(f"Did not recognize backend: {config['backend']}") + from posttroll.backends.zmq.ns import zmq_get_pub_address + return zmq_get_pub_address(name, timeout, nameserver) # Server part. @@ -116,10 +116,11 @@ def __init__(self, max_age=None, multicast_enabled=True, restrict_to_localhost=F self._max_age = max_age or dt.timedelta(minutes=10) self._multicast_enabled = multicast_enabled self._restrict_to_localhost = restrict_to_localhost - backend = config.get("backend", "unsecure_zmq") - if backend == "unsecure_zmq": - from posttroll.backends.zmq.ns import UnsecureZMQNameServer - self._ns = UnsecureZMQNameServer() + backend = config["backend"] + if backend not in ["unsecure_zmq", "secure_zmq"]: + raise NotImplementedError(f"Did not recognize backend: {backend}") + from posttroll.backends.zmq.ns import ZMQNameServer + self._ns = ZMQNameServer() def run(self, *args): """Run the listener and answer to requests.""" diff --git a/posttroll/publisher.py b/posttroll/publisher.py index 51a89c9..dee85cc 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -85,7 +85,7 @@ class Publisher: """ - def __init__(self, address, name="", min_port=None, max_port=None, **kwargs): + def __init__(self, address, name="", min_port=None, max_port=None): """Bind the publisher class to a port.""" # Limit port range or use the defaults when no port is defined # by the user @@ -95,17 +95,10 @@ def __init__(self, address, name="", min_port=None, max_port=None, **kwargs): self._heartbeat = None backend = config.get("backend", "unsecure_zmq") - if backend == "unsecure_zmq": - if kwargs: - raise TypeError(f"Unexpected keyword arguments: {kwargs}") - from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - self._publisher = UnsecureZMQPublisher(address, name=name, min_port=min_port, max_port=max_port) - elif backend == "secure_zmq": - from posttroll.backends.zmq.publisher import SecureZMQPublisher - self._publisher = SecureZMQPublisher(address, name=name, min_port=min_port, max_port=max_port, - **kwargs) - else: + if backend not in ["unsecure_zmq", "secure_zmq"]: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") + from posttroll.backends.zmq.publisher import ZMQPublisher + self._publisher = ZMQPublisher(address, name=name, min_port=min_port, max_port=max_port) def start(self): """Start the publisher.""" diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 5528fa4..9a04008 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -60,23 +60,17 @@ class Subscriber: """ - def __init__(self, addresses, topics="", message_filter=None, translate=False, **kwargs): + def __init__(self, addresses, topics="", message_filter=None, translate=False): """Initialize the subscriber.""" topics = self._magickfy_topics(topics) backend = config.get("backend", "unsecure_zmq") - if backend == "unsecure_zmq": - if kwargs: - raise TypeError(f"Unexpected keyword arguments: {kwargs}") - from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber - self._subscriber = UnsecureZMQSubscriber(addresses, topics=topics, - message_filter=message_filter, translate=translate) - elif backend == "secure_zmq": - from posttroll.backends.zmq.subscriber import SecureZMQSubscriber - self._subscriber = SecureZMQSubscriber(addresses, topics=topics, - message_filter=message_filter, translate=translate, **kwargs) - else: + if backend not in ["unsecure_zmq", "secure_zmq"]: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") + from posttroll.backends.zmq.subscriber import ZMQSubscriber + self._subscriber = ZMQSubscriber(addresses, topics=topics, + message_filter=message_filter, translate=translate) + def add(self, address, topics=None): """Add *address* to the subscribing list for *topics*. diff --git a/posttroll/tests/test_nameserver.py b/posttroll/tests/test_nameserver.py new file mode 100644 index 0000000..123ea53 --- /dev/null +++ b/posttroll/tests/test_nameserver.py @@ -0,0 +1,253 @@ +"""Tests for communication involving the nameserver for service discovery.""" + +import os +import time +import unittest +from contextlib import contextmanager +from datetime import timedelta +from threading import Thread +from unittest import mock + +import pytest + +from posttroll import config +from posttroll.message import Message +from posttroll.ns import NameServer, get_pub_address +from posttroll.publisher import Publish +from posttroll.subscriber import Subscribe + + +def free_port(): + """Get a free port. + + From https://gist.github.com/bertjwregeer/0be94ced48383a42e70c3d9fff1f4ad0 + + Returns a factory that finds the next free port that is available on the OS + This is a bit of a hack, it does this by creating a new socket, and calling + bind with the 0 port. The operating system will assign a brand new port, + which we can find out using getsockname(). Once we have the new port + information we close the socket thereby returning it to the free pool. + This means it is technically possible for this function to return the same + port twice (for example if run in very quick succession), however operating + systems return a random port number in the default range (1024 - 65535), + and it is highly unlikely for two processes to get the same port number. + In other words, it is possible to flake, but incredibly unlikely. + """ + import socket + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("0.0.0.0", 0)) + portnum = s.getsockname()[1] + s.close() + + return portnum + + +@contextmanager +def create_nameserver_instance(max_age=3, multicast_enabled=True): + """Create a nameserver instance.""" + config.set(nameserver_port=free_port()) + config.set(address_publish_port=free_port()) + ns = NameServer(max_age=timedelta(seconds=max_age), multicast_enabled=multicast_enabled) + thr = Thread(target=ns.run) + thr.start() + + try: + yield + finally: + ns.stop() + thr.join() + + + +class TestAddressReceiver(unittest.TestCase): + """Test the AddressReceiver.""" + + @mock.patch("posttroll.address_receiver.Message") + @mock.patch("posttroll.address_receiver.Publish") + @mock.patch("posttroll.address_receiver.MulticastReceiver") + def test_localhost_restriction(self, mcrec, pub, msg): + """Test address receiver restricted only to localhost.""" + mocked_publish_instance = mock.Mock() + pub.return_value.__enter__.return_value = mocked_publish_instance + mcr_instance = mock.Mock() + mcrec.return_value = mcr_instance + mcr_instance.return_value = "blabla", ("255.255.255.255", 12) + + from posttroll.address_receiver import AddressReceiver + adr = AddressReceiver(restrict_to_localhost=True) + adr.start() + time.sleep(3) + try: + msg.decode.assert_not_called() + mocked_publish_instance.send.assert_not_called() + finally: + adr.stop() + + +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_addresses(multicast_enabled): + """Test retrieving addresses.""" + from posttroll.ns import get_pub_addresses + from posttroll.publisher import Publish + + if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") + nameservers = None + else: + nameservers = ["localhost"] + with config.set(broadcast_port=free_port()): + with create_nameserver_instance(multicast_enabled=multicast_enabled): + with Publish(str("data_provider"), 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1): + time.sleep(.3) + res = get_pub_addresses(["this_data"], timeout=.5) + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] + res = get_pub_addresses([str("data_provider")]) + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] + + +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_sub_ctx(multicast_enabled): + """Test publish and subscribe.""" + if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") + nameservers = None + else: + nameservers = ["localhost"] + + with config.set(broadcast_port=free_port()): + with create_nameserver_instance(multicast_enabled=multicast_enabled): + with Publish("data_provider", 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1) as pub: + with Subscribe("this_data", "counter") as sub: + for counter in range(5): + message = Message("/counter", "info", str(counter)) + pub.send(str(message)) + time.sleep(.1) + msg = next(sub.recv(.2)) + if msg is not None: + assert str(msg) == str(message) + tested = True + assert tested + + +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_sub_add_rm(multicast_enabled): + """Test adding and removing publishers.""" + if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") + nameservers = None + else: + nameservers = ["localhost"] + + max_age = 0.5 + with config.set(broadcast_port=free_port()): + with create_nameserver_instance(max_age=max_age, multicast_enabled=multicast_enabled): + with Subscribe("this_data", "counter", addr_listener=True, timeout=.2) as sub: + assert len(sub.addresses) == 0 + with Publish("data_provider", 0, ["this_data"], nameservers=nameservers): + time.sleep(.1) + next(sub.recv(.1)) + assert len(sub.addresses) == 1 + time.sleep(max_age * 4) + for msg in sub.recv(.1): + if msg is None: + break + time.sleep(.3) + assert len(sub.addresses) == 0 + with Publish("data_provider_2", 0, ["another_data"], nameservers=nameservers): + time.sleep(.1) + next(sub.recv(.1)) + assert len(sub.addresses) == 0 + + +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) +def test_listener_container(): + """Test listener container.""" + from posttroll.listener import ListenerContainer + from posttroll.message import Message + from posttroll.publisher import NoisyPublisher + + with create_nameserver_instance(): + pub = NoisyPublisher("test", broadcast_interval=0.1) + pub.start() + sub = ListenerContainer(topics=["/counter"]) + time.sleep(.1) + for counter in range(5): + tested = False + msg_out = Message("/counter", "info", str(counter)) + pub.send(str(msg_out)) + + msg_in = sub.output_queue.get(True, 1) + if msg_in is not None: + assert str(msg_in) == str(msg_out) + tested = True + assert tested + pub.stop() + sub.stop() + + +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) +def test_noisypublisher_heartbeat(): + """Test that the heartbeat in the NoisyPublisher works.""" + from posttroll.publisher import NoisyPublisher + from posttroll.subscriber import Subscribe + + ns_ = NameServer() + thr = Thread(target=ns_.run) + thr.start() + + pub = NoisyPublisher("test") + pub.start() + time.sleep(0.2) + + with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub: + time.sleep(0.2) + pub.heartbeat(min_interval=10) + msg = next(sub.recv(1)) + assert msg.type == "beat" + assert msg.data == {"min_interval": 10} + pub.stop() + ns_.stop() + thr.join() + + +def test_switch_backend_for_nameserver(): + """Test switching backend for nameserver.""" + with config.set(backend="spurious_backend"): + with pytest.raises(NotImplementedError): + NameServer() + with pytest.raises(NotImplementedError): + get_pub_address("some_name") diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 2f88943..c72011d 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -23,12 +23,10 @@ """Test the publishing and subscribing facilities.""" -import os import time import unittest from contextlib import contextmanager -from datetime import timedelta -from threading import Lock, Thread +from threading import Lock from unittest import mock import pytest @@ -37,7 +35,6 @@ import posttroll from posttroll import config from posttroll.message import Message -from posttroll.ns import NameServer from posttroll.publisher import Publish, Publisher, create_publisher_from_dict_config from posttroll.subscriber import Subscribe, Subscriber @@ -71,125 +68,6 @@ def free_port(): return portnum -@contextmanager -def create_nameserver_instance(max_age=3, multicast_enabled=True): - """Create a nameserver instance.""" - config.set(nameserver_port=free_port()) - config.set(address_publish_port=free_port()) - ns = NameServer(max_age=timedelta(seconds=max_age), multicast_enabled=multicast_enabled) - thr = Thread(target=ns.run) - thr.start() - - try: - yield - finally: - ns.stop() - thr.join() - - -@pytest.mark.parametrize( - "multicast_enabled", - [True, False] -) -def test_pub_addresses(multicast_enabled): - """Test retrieving addresses.""" - from posttroll.ns import get_pub_addresses - from posttroll.publisher import Publish - - if multicast_enabled: - if os.getenv("DISABLED_MULTICAST"): - pytest.skip("Multicast tests disabled.") - nameservers = None - else: - nameservers = ["localhost"] - - - with create_nameserver_instance(multicast_enabled=multicast_enabled): - with Publish(str("data_provider"), 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1): - time.sleep(.3) - res = get_pub_addresses(["this_data"], timeout=.5) - assert len(res) == 1 - expected = {u"status": True, - u"service": [u"data_provider", u"this_data"], - u"name": u"address"} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] - res = get_pub_addresses([str("data_provider")]) - assert len(res) == 1 - expected = {u"status": True, - u"service": [u"data_provider", u"this_data"], - u"name": u"address"} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] - - -@pytest.mark.parametrize( - "multicast_enabled", - [True, False] -) -def test_pub_sub_ctx(multicast_enabled): - """Test publish and subscribe.""" - if multicast_enabled: - if os.getenv("DISABLED_MULTICAST"): - pytest.skip("Multicast tests disabled.") - nameservers = None - else: - nameservers = ["localhost"] - - with create_nameserver_instance(multicast_enabled=multicast_enabled): - with Publish("data_provider", 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1) as pub: - with Subscribe("this_data", "counter") as sub: - for counter in range(5): - message = Message("/counter", "info", str(counter)) - pub.send(str(message)) - time.sleep(.1) - msg = next(sub.recv(.2)) - if msg is not None: - assert str(msg) == str(message) - tested = True - sub.close() - assert tested - - -@pytest.mark.parametrize( - "multicast_enabled", - [True, False] -) -def test_pub_sub_add_rm(multicast_enabled): - """Test adding and removing publishers.""" - if multicast_enabled: - if os.getenv("DISABLED_MULTICAST"): - pytest.skip("Multicast tests disabled.") - nameservers = None - else: - nameservers = ["localhost"] - - max_age = 0.5 - - with create_nameserver_instance(max_age=max_age, multicast_enabled=multicast_enabled): - with Subscribe("this_data", "counter", addr_listener=True, timeout=.2) as sub: - assert len(sub.addresses) == 0 - with Publish("data_provider", 0, ["this_data"], nameservers=nameservers): - time.sleep(.1) - next(sub.recv(.1)) - assert len(sub.addresses) == 1 - time.sleep(max_age * 4) - for msg in sub.recv(.1): - if msg is None: - break - time.sleep(.3) - assert len(sub.addresses) == 0 - with Publish("data_provider_2", 0, ["another_data"], nameservers=nameservers): - time.sleep(.1) - next(sub.recv(.1)) - assert len(sub.addresses) == 0 - sub.close() - - class TestPubSub(unittest.TestCase): """Testing the publishing and subscribing capabilities.""" @@ -317,35 +195,6 @@ def _get_port_from_publish_instance(min_port=None, max_port=None): return False -@pytest.mark.skipif( - os.getenv("DISABLED_MULTICAST"), - reason="Multicast tests disabled.", -) -def test_listener_container(): - """Test listener container.""" - from posttroll.listener import ListenerContainer - from posttroll.message import Message - from posttroll.publisher import NoisyPublisher - - with create_nameserver_instance(): - pub = NoisyPublisher("test", broadcast_interval=0.1) - pub.start() - sub = ListenerContainer(topics=["/counter"]) - time.sleep(.1) - for counter in range(5): - tested = False - msg_out = Message("/counter", "info", str(counter)) - pub.send(str(msg_out)) - - msg_in = sub.output_queue.get(True, 1) - if msg_in is not None: - assert str(msg_in) == str(msg_out) - tested = True - assert tested - pub.stop() - sub.stop() - - class TestListenerContainerNoNameserver(unittest.TestCase): """Testing listener container without nameserver.""" @@ -382,27 +231,6 @@ def test_listener_container(self): sub.stop() -class TestAddressReceiver(unittest.TestCase): - """Test the AddressReceiver.""" - - @mock.patch("posttroll.address_receiver.Message") - @mock.patch("posttroll.address_receiver.Publish") - @mock.patch("posttroll.address_receiver.MulticastReceiver") - def test_localhost_restriction(self, mcrec, pub, msg): - """Test address receiver restricted only to localhost.""" - mcr_instance = mock.Mock() - mcrec.return_value = mcr_instance - mcr_instance.return_value = "blabla", ("255.255.255.255", 12) - from posttroll.address_receiver import AddressReceiver - adr = AddressReceiver(restrict_to_localhost=True) - adr.start() - time.sleep(3) - msg.decode.assert_not_called() - adr.stop() - - - - ## Test create_publisher_from_config def test_publisher_with_invalid_arguments_crashes(): @@ -603,8 +431,8 @@ def _tcp_keepalive_no_settings(): @pytest.mark.usefixtures("_tcp_keepalive_settings") def test_publisher_tcp_keepalive(): """Test that TCP Keepalive is set for Publisher if the environment variables are present.""" - from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - pub = UnsecureZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() + from posttroll.backends.zmq.publisher import ZMQPublisher + pub = ZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() _assert_tcp_keepalive(pub.publish_socket) pub.stop() @@ -612,8 +440,8 @@ def test_publisher_tcp_keepalive(): @pytest.mark.usefixtures("_tcp_keepalive_no_settings") def test_publisher_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" - from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - pub = UnsecureZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() + from posttroll.backends.zmq.publisher import ZMQPublisher + pub = ZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() _assert_no_tcp_keepalive(pub.publish_socket) pub.stop() @@ -621,8 +449,8 @@ def test_publisher_tcp_keepalive_not_set(): @pytest.mark.usefixtures("_tcp_keepalive_settings") def test_subscriber_tcp_keepalive(): """Test that TCP Keepalive is set for Subscriber if the environment variables are present.""" - from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber - sub = UnsecureZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") + from posttroll.backends.zmq.subscriber import ZMQSubscriber + sub = ZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") assert len(sub.addr_sub.values()) == 1 _assert_tcp_keepalive(list(sub.addr_sub.values())[0]) sub.stop() @@ -631,8 +459,8 @@ def test_subscriber_tcp_keepalive(): @pytest.mark.usefixtures("_tcp_keepalive_no_settings") def test_subscriber_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" - from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber - sub = UnsecureZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") + from posttroll.backends.zmq.subscriber import ZMQSubscriber + sub = ZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") assert len(sub.addr_sub.values()) == 1 _assert_no_tcp_keepalive(list(sub.addr_sub.values())[0]) sub.close() @@ -656,35 +484,6 @@ def _assert_no_tcp_keepalive(socket): assert socket.getsockopt(zmq.TCP_KEEPALIVE_INTVL) == -1 -@pytest.mark.skipif( - os.getenv("DISABLED_MULTICAST"), - reason="Multicast tests disabled.", -) -def test_noisypublisher_heartbeat(): - """Test that the heartbeat in the NoisyPublisher works.""" - from posttroll.ns import NameServer - from posttroll.publisher import NoisyPublisher - from posttroll.subscriber import Subscribe - - ns_ = NameServer() - thr = Thread(target=ns_.run) - thr.start() - - pub = NoisyPublisher("test") - pub.start() - time.sleep(0.2) - - with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub: - time.sleep(0.2) - pub.heartbeat(min_interval=10) - msg = next(sub.recv(1)) - assert msg.type == "beat" - assert msg.data == {"min_interval": 10} - pub.stop() - ns_.stop() - thr.join() - - def test_switch_to_unknown_backend(): """Test switching to unknown backend.""" from posttroll.publisher import Publisher diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index 4c6a903..38d442e 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -10,6 +10,8 @@ from posttroll import config from posttroll.publisher import Publisher from posttroll.subscriber import Subscriber, create_subscriber_from_dict_config +from posttroll.tests.test_nameserver import create_nameserver_instance +from posttroll.ns import get_pub_address def create_keys(tmp_path): @@ -48,21 +50,21 @@ def create_keys(tmp_path): def test_ipc_pubsub_with_sec(tmp_path): """Test pub-sub on a secure ipc socket.""" - server_public_key, server_secret_key = zmq.auth.create_certificates(tmp_path, "server") - client_public_key, client_secret_key = zmq.auth.create_certificates(tmp_path, "client") + server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") + client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" - with config.set(backend="secure_zmq"): - subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202, - client_secret_key_file=client_secret_key, - server_public_key_file=server_public_key) + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key_file, + clients_public_keys_directory=os.path.dirname(client_public_key_file), + server_public_key_file=server_public_key_file, + server_secret_key_file=server_secret_key_file): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202) sub = create_subscriber_from_dict_config(subscriber_settings) from posttroll.publisher import Publisher - pub = Publisher(ipc_address, - server_secret_key=server_secret_key, - public_keys_directory=os.path.dirname(client_public_key)) + pub = Publisher(ipc_address) pub.start() @@ -94,38 +96,37 @@ def test_switch_to_secure_zmq_backend(tmp_path): server_secret_key = secret_keys_dir / "server.key_secret" public_keys_directory = public_keys_dir - publisher_key_args = dict(server_secret_key=server_secret_key, - public_keys_directory=public_keys_directory) client_secret_key = secret_keys_dir / "client.key_secret" server_public_key = public_keys_dir / "server.key" - subscriber_key_args = dict(client_secret_key_file=client_secret_key, - server_public_key_file=server_public_key) - with config.set(backend="secure_zmq"): - Publisher("ipc://bla.ipc", **publisher_key_args) - Subscriber("ipc://bla.ipc", **subscriber_key_args) + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key, + clients_public_keys_directory=public_keys_directory, + server_public_key_file=server_public_key, + server_secret_key_file=server_secret_key): + Publisher("ipc://bla.ipc") + Subscriber("ipc://bla.ipc") def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): """Test pub-sub on a secure ipc socket.""" - base_dir = tmp_path - public_keys_dir = base_dir / "public_keys" - secret_keys_dir = base_dir / "private_keys" + #create_keys(tmp_path) - create_keys(tmp_path) + server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") + client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" - with config.set(backend="secure_zmq"): - subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202, - client_secret_key_file=secret_keys_dir / "client.key_secret", - server_public_key_file=public_keys_dir / "server.key") + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key_file, + clients_public_keys_directory=os.path.dirname(client_public_key_file), + server_public_key_file=server_public_key_file, + server_secret_key_file=server_secret_key_file): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202) sub = create_subscriber_from_dict_config(subscriber_settings) from posttroll.publisher import create_publisher_from_dict_config pub_settings = dict(address=ipc_address, - server_secret_key=secret_keys_dir / "server.key_secret", - public_keys_directory=public_keys_dir, nameservers=False, port=1789) pub = create_publisher_from_dict_config(pub_settings) @@ -146,3 +147,17 @@ def delayed_send(msg): sub.stop() thr.join() pub.stop() + +def test_switch_to_secure_backend_for_nameserver(tmp_path): + """Test switching backend for nameserver.""" + server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") + client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key_file, + clients_public_keys_directory=os.path.dirname(client_public_key_file), + server_public_key_file=server_public_key_file, + server_secret_key_file=server_secret_key_file): + + with create_nameserver_instance(): + res = get_pub_address("some_name") + assert res == ""