Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mraspaud committed May 3, 2024
1 parent fb30354 commit 7bc00e6
Show file tree
Hide file tree
Showing 17 changed files with 658 additions and 519 deletions.
4 changes: 2 additions & 2 deletions posttroll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from donfig import Config

config = Config("posttroll")
config = Config("posttroll", defaults=[dict(backend="unsecure_zmq")])
# context = {}
logger = logging.getLogger(__name__)

Expand All @@ -40,7 +40,7 @@ def get_context():
This function takes care of creating new contexts in case of forks.
"""
backend = config.get("backend", "unsecure_zmq")
backend = config["backend"]
if "zmq" in backend:
from posttroll.backends.zmq import get_context
return get_context()
Expand Down
28 changes: 20 additions & 8 deletions posttroll/address_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@
import netifaces

from posttroll import config
from posttroll.bbmcast import MulticastReceiver, SocketTimeout
from posttroll.bbmcast import MulticastReceiver, SocketTimeout, get_configured_broadcast_port
from posttroll.message import Message
from posttroll.publisher import Publish
from zmq import ZMQError

__all__ = ("AddressReceiver", "getaddress")

LOGGER = logging.getLogger(__name__)

debug = os.environ.get("DEBUG", False)
broadcast_port = 21200

DEFAULT_ADDRESS_PUBLISH_PORT = 16543

Expand Down Expand Up @@ -144,13 +144,13 @@ def _check_age(self, pub, min_interval=zero_seconds):
msg = Message("/address/" + metadata["name"], "info", mda)
to_del.append(addr)
LOGGER.info(f"publish remove '{msg}'")
pub.send(msg.encode())
pub.send(str(msg.encode()))
for addr in to_del:
del self._addresses[addr]

def _run(self):
"""Run the receiver."""
port = broadcast_port
port = get_configured_broadcast_port()
nameservers, recv = self.set_up_address_receiver(port)

self._is_running = True
Expand All @@ -159,7 +159,16 @@ def _run(self):
try:
while self._do_run:
try:
data, fromaddr = recv()
rerun = True
while rerun:
try:
data, fromaddr = recv()
rerun = False
except TimeoutError:
if self._do_run:
continue
else:
raise
if self._multicast_enabled:
ip_, port = fromaddr
if self._restrict_to_localhost and ip_ not in self._local_ips:
Expand All @@ -171,6 +180,8 @@ def _run(self):
if self._multicast_enabled:
LOGGER.debug("Multicast socket timed out on recv!")
continue
except ZMQError:
return
finally:
self._check_age(pub, min_interval=self._max_age / 20)
if self._do_heartbeat:
Expand Down Expand Up @@ -216,9 +227,10 @@ def set_up_address_receiver(self, port):
break

else:
if config.get("backend", "unsecure_zmq") == "unsecure_zmq":
from posttroll.backends.zmq.address_receiver import SimpleReceiver
recv = SimpleReceiver(port)
if config["backend"] not in ["unsecure_zmq", "secure_zmq"]:
raise NotImplementedError
from posttroll.backends.zmq.address_receiver import SimpleReceiver
recv = SimpleReceiver(port, timeout=2)
nameservers = ["localhost"]
return nameservers,recv

Expand Down
55 changes: 48 additions & 7 deletions posttroll/backends/zmq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import zmq

from posttroll import config
from posttroll.message import Message

logger = logging.getLogger(__name__)
context = {}
Expand All @@ -21,14 +22,54 @@ def get_context():
logger.debug("renewed context for PID %d", pid)
return context[pid]

def destroy_context(linger=None):
pid = os.getpid()
context.pop(pid).destroy(linger)

def _set_tcp_keepalive(socket):
"""Set the tcp keepalive parameters on *socket*."""
_set_int_sockopt(socket, zmq.TCP_KEEPALIVE, config.get("tcp_keepalive", None))
_set_int_sockopt(socket, zmq.TCP_KEEPALIVE_CNT, config.get("tcp_keepalive_cnt", None))
_set_int_sockopt(socket, zmq.TCP_KEEPALIVE_IDLE, config.get("tcp_keepalive_idle", None))
_set_int_sockopt(socket, zmq.TCP_KEEPALIVE_INTVL, config.get("tcp_keepalive_intvl", None))
keepalive_options = get_tcp_keepalive_options()
for param, value in keepalive_options.items():
socket.setsockopt(param, value)

def get_tcp_keepalive_options():
"""Get the tcp_keepalive options from config."""
keepalive_options = dict()
for opt in ("tcp_keepalive",
"tcp_keepalive_cnt",
"tcp_keepalive_idle",
"tcp_keepalive_intvl"):
try:
value = int(config[opt])
except (KeyError, TypeError):
continue
param = getattr(zmq, opt.upper())
keepalive_options[param] = value
return keepalive_options


class SocketReceiver:

def __init__(self):
self._poller = zmq.Poller()

def register(self, socket):
"""Register the socket."""
self._poller.register(socket, zmq.POLLIN)

def unregister(self, socket):
"""Unregister the socket."""
self._poller.unregister(socket)

def _set_int_sockopt(socket, param, value):
if value is not None:
socket.setsockopt(param, int(value))
def receive(self, *sockets, timeout=None):
"""Timeout is in seconds."""
if timeout:
timeout *= 1000
socks = dict(self._poller.poll(timeout=timeout))
if socks:
for sock in sockets:
if socks.get(sock) == zmq.POLLIN:
received = sock.recv_string(zmq.NOBLOCK)
yield Message.decode(received), sock
else:
raise TimeoutError("Did not receive anything on sockets.")
24 changes: 17 additions & 7 deletions posttroll/backends/zmq/address_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,35 @@
from zmq import LINGER, REP

from posttroll.address_receiver import get_configured_address_port
from posttroll.backends.zmq import get_context
from posttroll.backends.zmq.socket import set_up_server_socket


class SimpleReceiver(object):
"""Simple listing on port for address messages."""

def __init__(self, port=None):
def __init__(self, port=None, timeout=2):
"""Set up the receiver."""
self._port = port or get_configured_address_port()
self._socket = get_context().socket(REP)
self._socket.bind("tcp://*:" + str(port))
address = "tcp://*:" + str(port)
self._socket, _, self._authenticator = set_up_server_socket(REP, address)
self._running = True
self.timeout = timeout

def __call__(self):
"""Receive a message."""
message = self._socket.recv_string()
self._socket.send_string("ok")
return message, None
while self._running:
try:
message = self._socket.recv_string(self.timeout)
except TimeoutError:
continue
else:
self._socket.send_string("ok")
return message, None

def close(self):
"""Close the receiver."""
self._running = False
self._socket.setsockopt(LINGER, 1)
self._socket.close()
if self._authenticator:
self._authenticator.stop()
21 changes: 11 additions & 10 deletions posttroll/backends/zmq/message_broadcaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@
import logging
import threading

from posttroll.backends.zmq.socket import set_up_client_socket
from zmq import LINGER, NOBLOCK, REQ, ZMQError

from posttroll.backends.zmq import get_context

logger = logging.getLogger(__name__)


class UnsecureZMQDesignatedReceiversSender:
class ZMQDesignatedReceiversSender:
"""Sends message to multiple *receivers* on *port*."""

def __init__(self, default_port, receivers):
"""Set up the sender."""
self.default_port = default_port

self.receivers = receivers
self._shutdown_event = threading.Event()

Expand All @@ -28,13 +27,14 @@ def __call__(self, data):
def _send_to_address(self, address, data, timeout=10):
"""Send data to *address* and *port* without verification of response."""
# Socket to talk to server
socket = get_context().socket(REQ)
if address.find(":") == -1:
full_address = "tcp://%s:%d" % (address, self.default_port)
else:
full_address = "tcp://%s" % address
options = {LINGER: int(timeout * 1000)}
socket = set_up_client_socket(REQ, full_address, options)
try:
socket.setsockopt(LINGER, timeout * 1000)
if address.find(":") == -1:
socket.connect("tcp://%s:%d" % (address, self.default_port))
else:
socket.connect("tcp://%s" % address)

socket.send_string(data)
while not self._shutdown_event.is_set():
try:
Expand All @@ -43,10 +43,11 @@ def _send_to_address(self, address, data, timeout=10):
self._shutdown_event.wait(.1)
continue
if message != "ok":
logger.warn("invalid acknowledge received: %s" % message)
logger.warning("invalid acknowledge received: %s" % message)
break

finally:
socket.setsockopt(LINGER, 1)
socket.close()

def close(self):
Expand Down
Loading

0 comments on commit 7bc00e6

Please sign in to comment.