From 97b2f94e2069b3ff878ae99bd65b3489f2b768c7 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 22 May 2024 10:14:47 +0200 Subject: [PATCH] Fix style --- posttroll/__init__.py | 2 +- posttroll/address_receiver.py | 12 +++++------- posttroll/backends/zmq/ns.py | 6 ++---- posttroll/backends/zmq/socket.py | 3 +-- posttroll/backends/zmq/subscriber.py | 5 ----- posttroll/bbmcast.py | 7 ++++--- posttroll/message.py | 3 +++ posttroll/message_broadcaster.py | 3 ++- posttroll/subscriber.py | 1 + posttroll/tests/test_bbmcast.py | 13 ++++++++----- posttroll/tests/test_message.py | 12 ------------ posttroll/tests/test_nameserver.py | 1 - posttroll/tests/test_pubsub.py | 14 ++++++++++++-- posttroll/tests/test_secure_zmq_backend.py | 11 ++++++----- posttroll/tests/test_unsecure_zmq_backend.py | 1 + 15 files changed, 46 insertions(+), 48 deletions(-) diff --git a/posttroll/__init__.py b/posttroll/__init__.py index 46b0eaf..df053e3 100644 --- a/posttroll/__init__.py +++ b/posttroll/__init__.py @@ -68,5 +68,5 @@ def strp_isoformat(strg): else: dat, mis = strg.split(".") dat = dt.datetime.strptime(dat, "%Y-%m-%dT%H:%M:%S") - mis = int(float("." + mis)*1000000) + mis = int(float("." + mis) * 1000000) return dat.replace(microsecond=mis) diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index 6ca006e..42c0d4c 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -38,7 +38,7 @@ from zmq import ZMQError from posttroll import config -from posttroll.bbmcast import MulticastReceiver, SocketTimeout, get_configured_broadcast_port +from posttroll.bbmcast import MulticastReceiver, get_configured_broadcast_port from posttroll.message import Message from posttroll.publisher import Publish @@ -57,6 +57,7 @@ def get_configured_address_port(): return config.get("address_publish_port", DEFAULT_ADDRESS_PUBLISH_PORT) + def get_local_ips(): """Get local IP addresses.""" inet_addrs = [netifaces.ifaddresses(iface).get(netifaces.AF_INET) @@ -162,14 +163,11 @@ def _run(self): data, fromaddr = recv() except TimeoutError: if self._do_run: + if self._multicast_enabled: + LOGGER.debug("Multicast socket timed out on recv!") continue else: raise - - except SocketTimeout: - if self._multicast_enabled: - LOGGER.debug("Multicast socket timed out on recv!") - continue except ZMQError: return finally: @@ -229,7 +227,7 @@ def set_up_address_receiver(self, port): from posttroll.backends.zmq.address_receiver import SimpleReceiver recv = SimpleReceiver(port, timeout=2) nameservers = ["localhost"] - return nameservers,recv + return nameservers, recv def _add(self, adr, metadata): """Add an address.""" diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py index 5827920..dc0fcfb 100644 --- a/posttroll/backends/zmq/ns.py +++ b/posttroll/backends/zmq/ns.py @@ -43,10 +43,8 @@ def _fetch_address_using_socket(socket, name, timeout): socket.send_string(str(message)) # Get the reply. - #socket.poll(timeout) - #message = socket.recv(timeout) for message, _ in socket_receiver.receive(socket, timeout=timeout): - return message.data + return message.data except TimeoutError: raise TimeoutError("Didn't get an address after %d seconds." % timeout) @@ -61,6 +59,7 @@ def create_req_socket(timeout, nameserver_address): socket = set_up_client_socket(REQ, nameserver_address, options) return socket + class ZMQNameServer: """The name server.""" @@ -104,7 +103,6 @@ def close_sockets_and_threads(self): with suppress(AttributeError): self._authenticator.stop() - def stop(self): """Stop the name server.""" self.running = False diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py index 132bb08..7adb295 100644 --- a/posttroll/backends/zmq/socket.py +++ b/posttroll/backends/zmq/socket.py @@ -117,10 +117,9 @@ def create_secure_server_socket(socket_type): # Tell authenticator to use the certificate in a directory authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory) - server_socket = ctx.socket(socket_type) - server_public, server_secret =zmq.auth.load_certificate(server_secret_key) + server_public, server_secret = zmq.auth.load_certificate(server_secret_key) server_socket.curve_secretkey = server_secret server_socket.curve_publickey = server_public server_socket.curve_server = True diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index afdaebd..836f590 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -28,7 +28,6 @@ def __init__(self, addresses, topics="", message_filter=None, translate=False): self._hooks = [] self._hooks_cb = {} - #self.poller = Poller() self._sock_receiver = SocketReceiver() self._lock = Lock() @@ -119,7 +118,6 @@ def _add_hook(self, socket, callback): self._hooks.append(socket) self._hooks_cb[socket] = callback - @property def addresses(self): """Get the addresses.""" @@ -165,9 +163,6 @@ def _new_messages(self, timeout): if self._loop: LOGGER.exception("Receive failed: %s", str(err)) - - - def __call__(self, **kwargs): """Handle calls with class instance.""" return self.recv(**kwargs) diff --git a/posttroll/bbmcast.py b/posttroll/bbmcast.py index c2cf7b3..d9f3ae0 100644 --- a/posttroll/bbmcast.py +++ b/posttroll/bbmcast.py @@ -70,12 +70,12 @@ DEFAULT_BROADCAST_PORT = 21200 + def get_configured_broadcast_port(): """Get the configured nameserver port.""" return config.get("broadcast_port", DEFAULT_BROADCAST_PORT) - # ----------------------------------------------------------------------------- # # Sender. @@ -114,8 +114,8 @@ def mcast_sender(mcgroup=None): if _is_broadcast_group(mcgroup): group = "" sock.setsockopt(SOL_SOCKET, SO_BROADCAST, 1) - elif((int(mcgroup.split(".")[0]) > 239) or - (int(mcgroup.split(".")[0]) < 224)): + elif ((int(mcgroup.split(".")[0]) > 239) or + (int(mcgroup.split(".")[0]) < 224)): raise IOError(f"Invalid multicast address {mcgroup}") else: group = mcgroup @@ -130,6 +130,7 @@ def mcast_sender(mcgroup=None): raise return sock, group + def get_mc_group(): try: mcgroup = os.environ["PYTROLL_MC_GROUP"] diff --git a/posttroll/message.py b/posttroll/message.py index 541c0af..ab68484 100644 --- a/posttroll/message.py +++ b/posttroll/message.py @@ -282,12 +282,14 @@ def _decode(rawstr): return msg + def _check_for_version(raw): version = raw[4][:len(_VERSION)] if not _is_valid_version(version): raise MessageError("Invalid Message version: '%s'" % str(version)) return version + def _check_for_element_count(rawstr): raw = re.split(r"\s+", rawstr, maxsplit=6) if len(raw) < 5: @@ -296,6 +298,7 @@ def _check_for_element_count(rawstr): return raw + def _check_for_magic_word(rawstr): """Check for the magick word.""" try: diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index d72dd4c..4990c36 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -34,6 +34,7 @@ LOGGER = logging.getLogger(__name__) + class DesignatedReceiversSender: """Sends message to multiple *receivers* on *port*.""" def __init__(self, default_port, receivers): @@ -51,7 +52,7 @@ def close(self): """Close the sender.""" return self._sender.close() -#----------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # # General thread to broadcast messages. # diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index c32ccbf..fc3a8c1 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -155,6 +155,7 @@ def _magickfy_topics(topics): ts_.append(t__) return ts_ + class NSSubscriber: """Automatically subscribe to *services*. diff --git a/posttroll/tests/test_bbmcast.py b/posttroll/tests/test_bbmcast.py index 1a47b40..b61b26c 100644 --- a/posttroll/tests/test_bbmcast.py +++ b/posttroll/tests/test_bbmcast.py @@ -48,6 +48,7 @@ def test_mcast_sender_works_with_valid_addresses(): socket.close() + def test_mcast_sender_uses_broadcast_for_0s(): """Test mcast_sender uses broadcast for 0.0.0.0.""" mcgroup = "0.0.0.0" @@ -56,6 +57,7 @@ def test_mcast_sender_uses_broadcast_for_0s(): assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 socket.close() + def test_mcast_sender_uses_broadcast_for_255s(): """Test mcast_sender uses broadcast for 255.255.255.255.""" mcgroup = "255.255.255.255" @@ -64,6 +66,7 @@ def test_mcast_sender_uses_broadcast_for_255s(): assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 socket.close() + def test_mcast_sender_raises_for_invalit_adresses(): """Test mcast_sender uses broadcast for 0.0.0.0.""" mcgroup = (str(random.randint(0, 223)) + "." + @@ -78,7 +81,7 @@ def test_mcast_sender_raises_for_invalit_adresses(): str(random.randint(0, 255)) + "." + str(random.randint(0, 255))) with pytest.raises(OSError, match="Invalid multicast address .*"): - bbmcast.mcast_sender(mcgroup) + bbmcast.mcast_sender(mcgroup) def test_mcast_receiver_works_with_valid_addresses(): @@ -126,7 +129,7 @@ def test_multicast_roundtrip(reraise): """Test sending and receiving a multicast message.""" mcgroup = bbmcast.DEFAULT_MC_GROUP mcport = 5555 - rec_socket, rec_group = bbmcast.mcast_receiver(mcport, mcgroup) + rec_socket, _rec_group = bbmcast.mcast_receiver(mcport, mcgroup) rec_socket.settimeout(.1) message = "Ho Ho Ho!" @@ -136,7 +139,7 @@ def check_message(sock, message): data, _ = sock.recvfrom(1024) assert data.decode() == message - snd_socket, snd_group = bbmcast.mcast_sender(mcgroup) + snd_socket, _snd_group = bbmcast.mcast_sender(mcgroup) thr = Thread(target=check_message, args=(rec_socket, message)) thr.start() @@ -152,7 +155,7 @@ def test_broadcast_roundtrip(reraise): """Test sending and receiving a broadcast message.""" mcgroup = "0.0.0.0" mcport = 5555 - rec_socket, rec_group = bbmcast.mcast_receiver(mcport, mcgroup) + rec_socket, _rec_group = bbmcast.mcast_receiver(mcport, mcgroup) message = "Ho Ho Ho!" @@ -161,7 +164,7 @@ def check_message(sock, message): data, _ = sock.recvfrom(1024) assert data.decode() == message - snd_socket, snd_group = bbmcast.mcast_sender(mcgroup) + snd_socket, _snd_group = bbmcast.mcast_sender(mcgroup) thr = Thread(target=check_message, args=(rec_socket, message)) thr.start() diff --git a/posttroll/tests/test_message.py b/posttroll/tests/test_message.py index 5aa88bf..af97236 100644 --- a/posttroll/tests/test_message.py +++ b/posttroll/tests/test_message.py @@ -154,15 +154,3 @@ def test_serialization(self): msg = json.loads(local_dump) for key, val in msg.items(): assert val == metadata.get(key) - - -def suite(): - """Create the suite for test_message.""" - loader = unittest.TestLoader() - mysuite = unittest.TestSuite() - mysuite.addTest(loader.loadTestsFromTestCase(Test)) - - return mysuite - -if __name__ == "__main__": - unittest.main() diff --git a/posttroll/tests/test_nameserver.py b/posttroll/tests/test_nameserver.py index 8f8689b..f4fe81d 100644 --- a/posttroll/tests/test_nameserver.py +++ b/posttroll/tests/test_nameserver.py @@ -60,7 +60,6 @@ def create_nameserver_instance(max_age=3, multicast_enabled=True): thr.join() - class TestAddressReceiver(unittest.TestCase): """Test the AddressReceiver.""" diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index c72011d..c152bf8 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -170,7 +170,7 @@ def test_pub_minmax_port_from_instanciation(self): # Using range of ports defined at instantation time, this # should override environment variables for port in range(50000, 60000): - res = _get_port_from_publish_instance(min_port=port, max_port=port+1) + res = _get_port_from_publish_instance(min_port=port, max_port=port + 1) if res is False: # The port wasn't free, try again continue @@ -231,7 +231,7 @@ def test_listener_container(self): sub.stop() -## Test create_publisher_from_config +# Test create_publisher_from_config def test_publisher_with_invalid_arguments_crashes(): """Test that only valid arguments are passed to Publisher.""" @@ -248,6 +248,7 @@ def test_publisher_is_selected(): assert isinstance(pub, Publisher) assert pub is not None + @mock.patch("posttroll.publisher.Publisher") def test_publisher_all_arguments(Publisher): """Test that only valid arguments are passed to Publisher.""" @@ -258,11 +259,13 @@ def test_publisher_all_arguments(Publisher): assert Publisher.call_args[0][0].startswith("tcp://*:") assert Publisher.call_args[0][0].endswith(str(settings["port"])) + def test_no_name_raises_keyerror(): """Trying to create a NoisyPublisher without a given name will raise KeyError.""" with pytest.raises(KeyError): _ = create_publisher_from_dict_config(dict()) + def test_noisypublisher_is_selected_only_name(): """Test that NoisyPublisher is selected as publisher class.""" from posttroll.publisher import NoisyPublisher @@ -272,6 +275,7 @@ def test_noisypublisher_is_selected_only_name(): pub = create_publisher_from_dict_config(settings) assert isinstance(pub, NoisyPublisher) + def test_noisypublisher_is_selected_name_and_port(): """Test that NoisyPublisher is selected as publisher class.""" from posttroll.publisher import NoisyPublisher @@ -281,6 +285,7 @@ def test_noisypublisher_is_selected_name_and_port(): pub = create_publisher_from_dict_config(settings) assert isinstance(pub, NoisyPublisher) + @mock.patch("posttroll.publisher.NoisyPublisher") def test_noisypublisher_all_arguments(NoisyPublisher): """Test that only valid arguments are passed to NoisyPublisher.""" @@ -293,6 +298,7 @@ def test_noisypublisher_all_arguments(NoisyPublisher): _check_valid_settings_in_call(settings, NoisyPublisher, ignore=["name"]) assert NoisyPublisher.call_args[0][0] == settings["name"] + def test_publish_is_not_noisy(): """Test that Publisher is selected with the context manager when it should be.""" from posttroll.publisher import Publish @@ -300,6 +306,7 @@ def test_publish_is_not_noisy(): with Publish("service_name", port=40000, nameservers=False) as pub: assert isinstance(pub, Publisher) + def test_publish_is_noisy_only_name(): """Test that NoisyPublisher is selected with the context manager when only name is given.""" from posttroll.publisher import NoisyPublisher, Publish @@ -307,6 +314,7 @@ def test_publish_is_noisy_only_name(): with Publish("service_name") as pub: assert isinstance(pub, NoisyPublisher) + def test_publish_is_noisy_with_port(): """Test that NoisyPublisher is selected with the context manager when port is given.""" from posttroll.publisher import NoisyPublisher, Publish @@ -314,6 +322,7 @@ def test_publish_is_noisy_with_port(): with Publish("service_name", port=40001) as pub: assert isinstance(pub, NoisyPublisher) + def test_publish_is_noisy_with_nameservers(): """Test that NoisyPublisher is selected with the context manager when nameservers are given.""" from posttroll.publisher import NoisyPublisher, Publish @@ -412,6 +421,7 @@ def _tcp_keepalive_settings(monkeypatch): with config.set(tcp_keepalive=1, tcp_keepalive_cnt=10, tcp_keepalive_idle=1, tcp_keepalive_intvl=1): yield + @contextmanager def reset_config_for_tests(): """Reset the config for testing.""" diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index 74619c9..f467fe2 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -26,10 +26,10 @@ def create_keys(tmp_path): secret_keys_dir.mkdir() # create new keys in certificates dir - server_public_file, server_secret_file = zmq.auth.create_certificates( + _server_public_file, _server_secret_file = zmq.auth.create_certificates( keys_dir, "server" ) - client_public_file, client_secret_file = zmq.auth.create_certificates( + _client_public_file, _client_secret_file = zmq.auth.create_certificates( keys_dir, "client" ) @@ -66,8 +66,8 @@ def test_ipc_pubsub_with_sec(tmp_path): pub = Publisher(ipc_address) - pub.start() + def delayed_send(msg): time.sleep(.2) from posttroll.message import Message @@ -111,7 +111,7 @@ def test_switch_to_secure_zmq_backend(tmp_path): def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): """Test pub-sub on a secure ipc socket.""" - #create_keys(tmp_path) + # create_keys(tmp_path) server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") @@ -131,6 +131,7 @@ def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): pub = create_publisher_from_dict_config(pub_settings) pub.start() + def delayed_send(msg): time.sleep(.2) from posttroll.message import Message @@ -148,6 +149,7 @@ def delayed_send(msg): thr.join() pub.stop() + def test_switch_to_secure_backend_for_nameserver(tmp_path): """Test switching backend for nameserver.""" server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") @@ -163,7 +165,6 @@ def test_switch_to_secure_backend_for_nameserver(tmp_path): assert res == "" - def test_create_certificates_cli(tmp_path): """Test the certificate creation cli.""" from posttroll.backends.zmq import generate_keys diff --git a/posttroll/tests/test_unsecure_zmq_backend.py b/posttroll/tests/test_unsecure_zmq_backend.py index 66dbd6e..1b2b469 100644 --- a/posttroll/tests/test_unsecure_zmq_backend.py +++ b/posttroll/tests/test_unsecure_zmq_backend.py @@ -18,6 +18,7 @@ def test_ipc_pubsub(tmp_path): sub = create_subscriber_from_dict_config(subscriber_settings) pub = Publisher(ipc_address) pub.start() + def delayed_send(msg): time.sleep(.2) from posttroll.message import Message