Skip to content

Commit

Permalink
Fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
mraspaud committed May 22, 2024
1 parent 96dea2f commit 97b2f94
Show file tree
Hide file tree
Showing 15 changed files with 46 additions and 48 deletions.
2 changes: 1 addition & 1 deletion posttroll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ def strp_isoformat(strg):
else:
dat, mis = strg.split(".")
dat = dt.datetime.strptime(dat, "%Y-%m-%dT%H:%M:%S")
mis = int(float("." + mis)*1000000)
mis = int(float("." + mis) * 1000000)

Check warning on line 71 in posttroll/__init__.py

View check run for this annotation

Codecov / codecov/patch

posttroll/__init__.py#L71

Added line #L71 was not covered by tests
return dat.replace(microsecond=mis)
12 changes: 5 additions & 7 deletions posttroll/address_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from zmq import ZMQError

from posttroll import config
from posttroll.bbmcast import MulticastReceiver, SocketTimeout, get_configured_broadcast_port
from posttroll.bbmcast import MulticastReceiver, get_configured_broadcast_port
from posttroll.message import Message
from posttroll.publisher import Publish

Expand All @@ -57,6 +57,7 @@
def get_configured_address_port():
return config.get("address_publish_port", DEFAULT_ADDRESS_PUBLISH_PORT)


def get_local_ips():
"""Get local IP addresses."""
inet_addrs = [netifaces.ifaddresses(iface).get(netifaces.AF_INET)
Expand Down Expand Up @@ -162,14 +163,11 @@ def _run(self):
data, fromaddr = recv()
except TimeoutError:
if self._do_run:
if self._multicast_enabled:
LOGGER.debug("Multicast socket timed out on recv!")
continue
else:
raise

except SocketTimeout:
if self._multicast_enabled:
LOGGER.debug("Multicast socket timed out on recv!")
continue
except ZMQError:
return

Check warning on line 172 in posttroll/address_receiver.py

View check run for this annotation

Codecov / codecov/patch

posttroll/address_receiver.py#L171-L172

Added lines #L171 - L172 were not covered by tests
finally:
Expand Down Expand Up @@ -229,7 +227,7 @@ def set_up_address_receiver(self, port):
from posttroll.backends.zmq.address_receiver import SimpleReceiver
recv = SimpleReceiver(port, timeout=2)
nameservers = ["localhost"]
return nameservers,recv
return nameservers, recv

def _add(self, adr, metadata):
"""Add an address."""
Expand Down
6 changes: 2 additions & 4 deletions posttroll/backends/zmq/ns.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,8 @@ def _fetch_address_using_socket(socket, name, timeout):
socket.send_string(str(message))

# Get the reply.
#socket.poll(timeout)
#message = socket.recv(timeout)
for message, _ in socket_receiver.receive(socket, timeout=timeout):
return message.data
return message.data
except TimeoutError:
raise TimeoutError("Didn't get an address after %d seconds."
% timeout)
Expand All @@ -61,6 +59,7 @@ def create_req_socket(timeout, nameserver_address):
socket = set_up_client_socket(REQ, nameserver_address, options)
return socket


class ZMQNameServer:
"""The name server."""

Expand Down Expand Up @@ -104,7 +103,6 @@ def close_sockets_and_threads(self):
with suppress(AttributeError):
self._authenticator.stop()


def stop(self):
"""Stop the name server."""
self.running = False
3 changes: 1 addition & 2 deletions posttroll/backends/zmq/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,9 @@ def create_secure_server_socket(socket_type):
# Tell authenticator to use the certificate in a directory
authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory)


server_socket = ctx.socket(socket_type)

server_public, server_secret =zmq.auth.load_certificate(server_secret_key)
server_public, server_secret = zmq.auth.load_certificate(server_secret_key)
server_socket.curve_secretkey = server_secret
server_socket.curve_publickey = server_public
server_socket.curve_server = True
Expand Down
5 changes: 0 additions & 5 deletions posttroll/backends/zmq/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(self, addresses, topics="", message_filter=None, translate=False):
self._hooks = []
self._hooks_cb = {}

#self.poller = Poller()
self._sock_receiver = SocketReceiver()
self._lock = Lock()

Expand Down Expand Up @@ -119,7 +118,6 @@ def _add_hook(self, socket, callback):
self._hooks.append(socket)
self._hooks_cb[socket] = callback


@property
def addresses(self):
"""Get the addresses."""
Expand Down Expand Up @@ -165,9 +163,6 @@ def _new_messages(self, timeout):
if self._loop:
LOGGER.exception("Receive failed: %s", str(err))

Check warning on line 164 in posttroll/backends/zmq/subscriber.py

View check run for this annotation

Codecov / codecov/patch

posttroll/backends/zmq/subscriber.py#L163-L164

Added lines #L163 - L164 were not covered by tests




def __call__(self, **kwargs):
"""Handle calls with class instance."""
return self.recv(**kwargs)

Check warning on line 168 in posttroll/backends/zmq/subscriber.py

View check run for this annotation

Codecov / codecov/patch

posttroll/backends/zmq/subscriber.py#L168

Added line #L168 was not covered by tests
Expand Down
7 changes: 4 additions & 3 deletions posttroll/bbmcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@

DEFAULT_BROADCAST_PORT = 21200


def get_configured_broadcast_port():
"""Get the configured nameserver port."""
return config.get("broadcast_port", DEFAULT_BROADCAST_PORT)



# -----------------------------------------------------------------------------
#
# Sender.
Expand Down Expand Up @@ -114,8 +114,8 @@ def mcast_sender(mcgroup=None):
if _is_broadcast_group(mcgroup):
group = "<broadcast>"
sock.setsockopt(SOL_SOCKET, SO_BROADCAST, 1)
elif((int(mcgroup.split(".")[0]) > 239) or
(int(mcgroup.split(".")[0]) < 224)):
elif ((int(mcgroup.split(".")[0]) > 239) or
(int(mcgroup.split(".")[0]) < 224)):
raise IOError(f"Invalid multicast address {mcgroup}")
else:
group = mcgroup
Expand All @@ -130,6 +130,7 @@ def mcast_sender(mcgroup=None):
raise
return sock, group


def get_mc_group():
try:
mcgroup = os.environ["PYTROLL_MC_GROUP"]
Expand Down
3 changes: 3 additions & 0 deletions posttroll/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,14 @@ def _decode(rawstr):

return msg


def _check_for_version(raw):
version = raw[4][:len(_VERSION)]
if not _is_valid_version(version):
raise MessageError("Invalid Message version: '%s'" % str(version))

Check warning on line 289 in posttroll/message.py

View check run for this annotation

Codecov / codecov/patch

posttroll/message.py#L289

Added line #L289 was not covered by tests
return version


def _check_for_element_count(rawstr):
raw = re.split(r"\s+", rawstr, maxsplit=6)
if len(raw) < 5:
Expand All @@ -296,6 +298,7 @@ def _check_for_element_count(rawstr):

return raw


def _check_for_magic_word(rawstr):
"""Check for the magick word."""
try:
Expand Down
3 changes: 2 additions & 1 deletion posttroll/message_broadcaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

LOGGER = logging.getLogger(__name__)


class DesignatedReceiversSender:
"""Sends message to multiple *receivers* on *port*."""
def __init__(self, default_port, receivers):
Expand All @@ -51,7 +52,7 @@ def close(self):
"""Close the sender."""
return self._sender.close()

#-----------------------------------------------------------------------------
# ----------------------------------------------------------------------------
#
# General thread to broadcast messages.
#
Expand Down
1 change: 1 addition & 0 deletions posttroll/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def _magickfy_topics(topics):
ts_.append(t__)
return ts_


class NSSubscriber:
"""Automatically subscribe to *services*.
Expand Down
13 changes: 8 additions & 5 deletions posttroll/tests/test_bbmcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_mcast_sender_works_with_valid_addresses():

socket.close()


def test_mcast_sender_uses_broadcast_for_0s():
"""Test mcast_sender uses broadcast for 0.0.0.0."""
mcgroup = "0.0.0.0"
Expand All @@ -56,6 +57,7 @@ def test_mcast_sender_uses_broadcast_for_0s():
assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1
socket.close()


def test_mcast_sender_uses_broadcast_for_255s():
"""Test mcast_sender uses broadcast for 255.255.255.255."""
mcgroup = "255.255.255.255"
Expand All @@ -64,6 +66,7 @@ def test_mcast_sender_uses_broadcast_for_255s():
assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1
socket.close()


def test_mcast_sender_raises_for_invalit_adresses():
"""Test mcast_sender uses broadcast for 0.0.0.0."""
mcgroup = (str(random.randint(0, 223)) + "." +
Expand All @@ -78,7 +81,7 @@ def test_mcast_sender_raises_for_invalit_adresses():
str(random.randint(0, 255)) + "." +
str(random.randint(0, 255)))
with pytest.raises(OSError, match="Invalid multicast address .*"):
bbmcast.mcast_sender(mcgroup)
bbmcast.mcast_sender(mcgroup)


def test_mcast_receiver_works_with_valid_addresses():
Expand Down Expand Up @@ -126,7 +129,7 @@ def test_multicast_roundtrip(reraise):
"""Test sending and receiving a multicast message."""
mcgroup = bbmcast.DEFAULT_MC_GROUP
mcport = 5555
rec_socket, rec_group = bbmcast.mcast_receiver(mcport, mcgroup)
rec_socket, _rec_group = bbmcast.mcast_receiver(mcport, mcgroup)
rec_socket.settimeout(.1)

message = "Ho Ho Ho!"
Expand All @@ -136,7 +139,7 @@ def check_message(sock, message):
data, _ = sock.recvfrom(1024)
assert data.decode() == message

snd_socket, snd_group = bbmcast.mcast_sender(mcgroup)
snd_socket, _snd_group = bbmcast.mcast_sender(mcgroup)

thr = Thread(target=check_message, args=(rec_socket, message))
thr.start()
Expand All @@ -152,7 +155,7 @@ def test_broadcast_roundtrip(reraise):
"""Test sending and receiving a broadcast message."""
mcgroup = "0.0.0.0"
mcport = 5555
rec_socket, rec_group = bbmcast.mcast_receiver(mcport, mcgroup)
rec_socket, _rec_group = bbmcast.mcast_receiver(mcport, mcgroup)

message = "Ho Ho Ho!"

Expand All @@ -161,7 +164,7 @@ def check_message(sock, message):
data, _ = sock.recvfrom(1024)
assert data.decode() == message

snd_socket, snd_group = bbmcast.mcast_sender(mcgroup)
snd_socket, _snd_group = bbmcast.mcast_sender(mcgroup)

thr = Thread(target=check_message, args=(rec_socket, message))
thr.start()
Expand Down
12 changes: 0 additions & 12 deletions posttroll/tests/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,3 @@ def test_serialization(self):
msg = json.loads(local_dump)
for key, val in msg.items():
assert val == metadata.get(key)


def suite():
"""Create the suite for test_message."""
loader = unittest.TestLoader()
mysuite = unittest.TestSuite()
mysuite.addTest(loader.loadTestsFromTestCase(Test))

return mysuite

if __name__ == "__main__":
unittest.main()
1 change: 0 additions & 1 deletion posttroll/tests/test_nameserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def create_nameserver_instance(max_age=3, multicast_enabled=True):
thr.join()



class TestAddressReceiver(unittest.TestCase):
"""Test the AddressReceiver."""

Expand Down
14 changes: 12 additions & 2 deletions posttroll/tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_pub_minmax_port_from_instanciation(self):
# Using range of ports defined at instantation time, this
# should override environment variables
for port in range(50000, 60000):
res = _get_port_from_publish_instance(min_port=port, max_port=port+1)
res = _get_port_from_publish_instance(min_port=port, max_port=port + 1)
if res is False:
# The port wasn't free, try again
continue
Expand Down Expand Up @@ -231,7 +231,7 @@ def test_listener_container(self):
sub.stop()


## Test create_publisher_from_config
# Test create_publisher_from_config

def test_publisher_with_invalid_arguments_crashes():
"""Test that only valid arguments are passed to Publisher."""
Expand All @@ -248,6 +248,7 @@ def test_publisher_is_selected():
assert isinstance(pub, Publisher)
assert pub is not None


@mock.patch("posttroll.publisher.Publisher")
def test_publisher_all_arguments(Publisher):
"""Test that only valid arguments are passed to Publisher."""
Expand All @@ -258,11 +259,13 @@ def test_publisher_all_arguments(Publisher):
assert Publisher.call_args[0][0].startswith("tcp://*:")
assert Publisher.call_args[0][0].endswith(str(settings["port"]))


def test_no_name_raises_keyerror():
"""Trying to create a NoisyPublisher without a given name will raise KeyError."""
with pytest.raises(KeyError):
_ = create_publisher_from_dict_config(dict())


def test_noisypublisher_is_selected_only_name():
"""Test that NoisyPublisher is selected as publisher class."""
from posttroll.publisher import NoisyPublisher
Expand All @@ -272,6 +275,7 @@ def test_noisypublisher_is_selected_only_name():
pub = create_publisher_from_dict_config(settings)
assert isinstance(pub, NoisyPublisher)


def test_noisypublisher_is_selected_name_and_port():
"""Test that NoisyPublisher is selected as publisher class."""
from posttroll.publisher import NoisyPublisher
Expand All @@ -281,6 +285,7 @@ def test_noisypublisher_is_selected_name_and_port():
pub = create_publisher_from_dict_config(settings)
assert isinstance(pub, NoisyPublisher)


@mock.patch("posttroll.publisher.NoisyPublisher")
def test_noisypublisher_all_arguments(NoisyPublisher):
"""Test that only valid arguments are passed to NoisyPublisher."""
Expand All @@ -293,27 +298,31 @@ def test_noisypublisher_all_arguments(NoisyPublisher):
_check_valid_settings_in_call(settings, NoisyPublisher, ignore=["name"])
assert NoisyPublisher.call_args[0][0] == settings["name"]


def test_publish_is_not_noisy():
"""Test that Publisher is selected with the context manager when it should be."""
from posttroll.publisher import Publish

with Publish("service_name", port=40000, nameservers=False) as pub:
assert isinstance(pub, Publisher)


def test_publish_is_noisy_only_name():
"""Test that NoisyPublisher is selected with the context manager when only name is given."""
from posttroll.publisher import NoisyPublisher, Publish

with Publish("service_name") as pub:
assert isinstance(pub, NoisyPublisher)


def test_publish_is_noisy_with_port():
"""Test that NoisyPublisher is selected with the context manager when port is given."""
from posttroll.publisher import NoisyPublisher, Publish

with Publish("service_name", port=40001) as pub:
assert isinstance(pub, NoisyPublisher)


def test_publish_is_noisy_with_nameservers():
"""Test that NoisyPublisher is selected with the context manager when nameservers are given."""
from posttroll.publisher import NoisyPublisher, Publish
Expand Down Expand Up @@ -412,6 +421,7 @@ def _tcp_keepalive_settings(monkeypatch):
with config.set(tcp_keepalive=1, tcp_keepalive_cnt=10, tcp_keepalive_idle=1, tcp_keepalive_intvl=1):
yield


@contextmanager
def reset_config_for_tests():
"""Reset the config for testing."""
Expand Down
Loading

0 comments on commit 97b2f94

Please sign in to comment.