Skip to content

Commit

Permalink
Merge pull request saltstack#320 from saltstack/issue/3002.9/61865
Browse files Browse the repository at this point in the history
[3002.9] Fix bug in tcp transport
  • Loading branch information
Ch3LL authored May 25, 2022
2 parents a9d9220 + 93f101e commit 37d8bb7
Show file tree
Hide file tree
Showing 4 changed files with 363 additions and 2 deletions.
1 change: 1 addition & 0 deletions changelog/61865.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug in tcp transport
4 changes: 2 additions & 2 deletions salt/transport/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,7 +1558,6 @@ def handle_stream(self, stream, address):
# TODO: ACK the publish through IPC
@salt.ext.tornado.gen.coroutine
def publish_payload(self, package, _):
log.debug("TCP PubServer sending payload: %s", package)
package = self.pack_publish(package)
payload = salt.transport.frame.frame_msg(package["payload"])

Expand Down Expand Up @@ -1633,7 +1632,8 @@ def _publish_daemon(self, **kwargs):

# Check if io_loop was set outside
if self.io_loop is None:
self.io_loop = salt.ext.tornado.ioloop.IOLoop.current()
self.io_loop = salt.ext.tornado.ioloop.IOLoop()
self.io_loop.make_current()

# Spin up the publisher
pub_server = PubServer(
Expand Down
31 changes: 31 additions & 0 deletions tests/pytests/functional/transport/zeromq/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
from saltfactories.utils import random_string


@pytest.fixture
def salt_master(salt_factories):
    """
    Build a salt-master daemon factory configured for the ``zeromq`` transport.

    The master id comes from ``random_string("zeromq-master-")``. The factory
    object is returned as created; this fixture does not start it.
    """
    master_id = random_string("zeromq-master-")
    overrides = dict(
        transport="zeromq",
        auto_accept=True,
        sign_pub_messages=False,
    )
    return salt_factories.get_salt_master_daemon(master_id, config_defaults=overrides)


@pytest.fixture
def salt_minion(salt_master):
    """
    Build a salt-minion daemon factory that targets the ``salt_master`` fixture.

    The minion is pointed at ``127.0.0.1`` on the master's ``ret_port`` and is
    limited to a single authentication attempt with a 5 second timeout. The
    factory object is returned as created; this fixture does not start it.
    """
    ret_port = salt_master.config["ret_port"]
    overrides = {
        "transport": "zeromq",
        "master_ip": "127.0.0.1",
        "master_port": ret_port,
        "auth_timeout": 5,
        "auth_tries": 1,
        "master_uri": "tcp://127.0.0.1:{}".format(ret_port),
    }
    minion_id = random_string("zeromq-minion-")
    return salt_master.get_salt_minion_daemon(minion_id, config_defaults=overrides)
329 changes: 329 additions & 0 deletions tests/pytests/functional/transport/zeromq/test_pub_server_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,329 @@
import ctypes
import logging
import multiprocessing
import signal
import socket
import time

import pytest
import salt.config
import salt.exceptions
import salt.ext.tornado.gen
import salt.ext.tornado.ioloop
import salt.log.setup
import salt.master
import salt.transport.client
import salt.transport.server
import salt.transport.tcp
import salt.transport.zeromq
import salt.utils.msgpack
import salt.utils.platform
import salt.utils.process
import salt.utils.stringutils
import zmq.eventloop.ioloop
from saltfactories.utils.processes import terminate_process

log = logging.getLogger(__name__)


class RecvError(Exception):
    """
    Raised by the Collector's ``_recv`` method when there is a problem
    getting publishes from the publisher.
    """


class Collector(salt.utils.process.SignalHandlingProcess):
    """
    Process that listens on the publish port and records received job ids.

    The process connects to ``interface:port`` (a ZeroMQ ``SUB`` socket when the
    minion transport is ``zeromq``, a plain TCP stream otherwise), decrypts each
    publish with ``aes_key`` and appends its ``jid`` to ``results``, a
    ``multiprocessing.Manager`` list that the parent can read.

    It is meant to be used as a context manager: ``__enter__`` starts the
    process and waits for the listener to be set up, ``__exit__`` joins or
    terminates it and converts ``results`` into a plain list.

    :param minion_config: minion options; ``minion_config["transport"]`` picks
        the listener type.
    :param interface: address of the publisher.
    :param port: publish port of the publisher.
    :param aes_key: key string handed to ``salt.crypt.Crypticle``.
    :param timeout: seconds without a recorded message before collecting stops.
    :param zmq_filtering: when true, deserialization errors are logged and
        skipped instead of stopping the collector.
    """

    def __init__(
        self, minion_config, interface, port, aes_key, timeout=300, zmq_filtering=False
    ):
        super().__init__()
        self.minion_config = minion_config
        self.interface = interface
        self.port = port
        self.aes_key = aes_key
        self.timeout = timeout
        # Absolute deadline, measured from construction time, after which the
        # collector stops no matter what it is still receiving.
        self.hard_timeout = time.time() + timeout + 30
        self.manager = multiprocessing.Manager()
        self.results = self.manager.list()
        self.zmq_filtering = zmq_filtering
        self.stopped = multiprocessing.Event()
        self.started = multiprocessing.Event()
        self.running = multiprocessing.Event()
        # The unpacker keyword depends on the msgpack version in use:
        # ``raw`` from 0.5.2 onwards, ``encoding`` before that.
        if salt.utils.msgpack.version >= (0, 5, 2):
            msgpack_kwargs = {"raw": False}
        else:
            msgpack_kwargs = {"encoding": "utf-8"}
        self.unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs)

    @property
    def transport(self):
        """
        The transport name taken from the minion configuration.
        """
        return self.minion_config["transport"]

    def _rotate_secrets(self, now=None):
        """
        Install a freshly generated ``aes`` entry in ``salt.master.SMaster.secrets``.

        ``now`` is accepted but not used by this implementation.
        """
        salt.master.SMaster.secrets["aes"] = {
            "secret": multiprocessing.Array(
                ctypes.c_char,
                salt.utils.stringutils.to_bytes(
                    salt.crypt.Crypticle.generate_key_string()
                ),
            ),
            "serial": multiprocessing.Value(
                ctypes.c_longlong, lock=False  # We'll use the lock from 'secret'
            ),
            "reload": salt.crypt.Crypticle.generate_key_string,
            "rotate_master_key": self._rotate_secrets,
        }

    def _setup_listener(self):
        """
        Connect ``self.sock`` to the publisher.

        For ``zeromq`` a ``SUB`` socket subscribed to every topic is connected.
        For any other transport a TCP connection is retried once per second for
        up to 300 seconds while the publisher refuses connections, and the
        connected socket is wrapped in a tornado ``IOStream``.

        :raises ConnectionRefusedError: if the TCP publisher is still refusing
            connections once the retry window has elapsed.
        """
        if self.transport == "zeromq":
            ctx = zmq.Context()
            self.sock = ctx.socket(zmq.SUB)
            self.sock.setsockopt(zmq.LINGER, -1)
            self.sock.setsockopt(zmq.SUBSCRIBE, b"")
            pub_uri = "tcp://{}:{}".format(self.interface, self.port)
            self.sock.connect(pub_uri)
        else:
            end = time.time() + 300
            while True:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                try:
                    sock.connect((self.interface, self.port))
                except ConnectionRefusedError:
                    # A new socket is created for every attempt, so release
                    # the refused one instead of leaking its file descriptor.
                    sock.close()
                    if time.time() >= end:
                        raise
                    time.sleep(1)
                else:
                    break
            self.sock = salt.ext.tornado.iostream.IOStream(sock)

    @salt.ext.tornado.gen.coroutine
    def _recv(self):
        """
        Receive one publish and return its deserialized payload.

        The payload is delivered through ``salt.ext.tornado.gen.Return``.

        :raises RecvError: if no complete message could be obtained on this
            attempt (ZeroMQ error or deserialization error for ``zeromq``; no
            complete frame after one read for TCP).
        """
        exc = None
        if self.transport == "zeromq":
            try:
                payload = self.sock.recv(zmq.NOBLOCK)
                serial_payload = salt.payload.Serial({}).loads(payload)
                raise salt.ext.tornado.gen.Return(serial_payload)
            except (zmq.ZMQError, salt.exceptions.SaltDeserializationError):
                exc = RecvError("ZMQ Error")
        else:
            # Hand out a message already buffered by a previous read before
            # touching the stream again.
            for msg in self.unpacker:
                serial_payload = salt.payload.Serial({}).loads(msg["body"])
                raise salt.ext.tornado.gen.Return(serial_payload)
            byts = yield self.sock.read_bytes(8096, partial=True)
            self.unpacker.feed(byts)
            for msg in self.unpacker:
                serial_payload = salt.payload.Serial({}).loads(msg["body"])
                raise salt.ext.tornado.gen.Return(serial_payload)
            exc = RecvError("TCP Error")
        raise exc

    @salt.ext.tornado.gen.coroutine
    def _run(self, loop):
        """
        Set up the listener and collect job ids until told to stop.

        Collecting ends on a ``stop`` payload, on the receive timeout, on the
        hard timeout, or on a deserialization error when ``zmq_filtering`` is
        off. ``loop`` is always stopped on the way out, including when the
        listener cannot be set up or an unexpected exception escapes, so that
        ``run()`` returns and the process exits instead of idling forever.
        """
        try:
            try:
                self._setup_listener()
            except Exception:  # pylint: disable=broad-except
                # Unblock the parent waiting in __enter__ even though we failed.
                self.started.set()
                log.exception("Failed to start listening")
                return
            self.started.set()
            last_msg = time.time()
            # NOTE(review): ``serial`` is not used below; kept in case building
            # it matters to the surrounding code - confirm and drop if not.
            serial = salt.payload.Serial(self.minion_config)
            crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key)
            while True:
                curr_time = time.time()
                if curr_time > self.hard_timeout:
                    log.error("Hard timeout reaced in test collector!")
                    break
                if curr_time - last_msg >= self.timeout:
                    log.error("Receive timeout reaced in test collector!")
                    break
                try:
                    payload = yield self._recv()
                except RecvError:
                    # Nothing available yet; back off briefly before polling again.
                    time.sleep(0.01)
                else:
                    try:
                        payload = crypticle.loads(payload["load"])
                        if not payload:
                            continue
                        if "start" in payload:
                            log.info("Collector started")
                            self.running.set()
                            continue
                        if "stop" in payload:
                            log.info("Collector stopped")
                            break
                        last_msg = time.time()
                        self.results.append(payload["jid"])
                    except salt.exceptions.SaltDeserializationError:
                        log.error("Deserializer Error")
                        if not self.zmq_filtering:
                            log.exception("Failed to deserialize...")
                            break
        finally:
            loop.stop()

    def run(self):
        """
        Gather results until the number of seconds specified by timeout passes
        without receiving a message
        """
        loop = salt.ext.tornado.ioloop.IOLoop()
        loop.add_callback(self._run, loop)
        loop.start()

    def __enter__(self):
        """
        Start the manager and the process, then wait for the listener setup.
        """
        self.manager.__enter__()
        self.start()
        # Wait until we can start receiving events
        self.started.wait()
        self.started.clear()
        return self

    def __exit__(self, *args):
        """
        Wait for the process to finish (bounded by the hard timeout), terminate
        it, and replace ``results`` with a plain list before the manager exits.
        """
        # Wait until we either processed all expected messages or we reach the hard timeout
        join_secs = self.hard_timeout - time.time()
        log.info("Waiting at most %s seconds before exiting the collector", join_secs)
        self.join(join_secs)
        self.terminate()
        # Cast our manager.list into a plain list
        self.results = list(self.results)
        # Terminate our multiprocessing manager
        self.manager.__exit__(*args)
        log.debug("The collector has exited")
        self.stopped.set()


class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
    """
    Process that feeds queued payloads to a ``PubServerChannel``.

    Construction sets up the master ``aes`` secret, the publish server channel
    (including its ``pre_fork`` step) and a ``Collector`` listening on the
    master's publish port. Used as a context manager, it starts itself and the
    collector on entry and tears both down on exit.

    :param master_config: master options used to build the publish channel.
    :param minion_config: minion options handed to the collector.
    :param collector_kwargs: extra keyword arguments for ``Collector``.
    """

    def __init__(self, master_config, minion_config, **collector_kwargs):
        super().__init__()
        self._closing = False
        self.master_config = master_config
        self.minion_config = minion_config
        self.collector_kwargs = collector_kwargs
        self.aes_key = salt.crypt.Crypticle.generate_key_string()

        key_bytes = salt.utils.stringutils.to_bytes(self.aes_key)
        shared_secret = multiprocessing.Array(ctypes.c_char, key_bytes)
        # We'll use the lock from 'secret'
        shared_serial = multiprocessing.Value(ctypes.c_longlong, lock=False)
        salt.master.SMaster.secrets["aes"] = {
            "secret": shared_secret,
            "serial": shared_serial,
        }

        self.process_manager = salt.utils.process.ProcessManager(
            name="ZMQ-PubServer-ProcessManager"
        )
        self.pub_server_channel = salt.transport.server.PubServerChannel.factory(
            self.master_config
        )
        log_queue = salt.log.setup.get_multiprocessing_logging_queue()
        self.pub_server_channel.pre_fork(
            self.process_manager, kwargs={"log_queue": log_queue}
        )
        self.queue = multiprocessing.Queue()
        self.stopped = multiprocessing.Event()
        self.collector = Collector(
            self.minion_config,
            self.master_config["interface"],
            self.master_config["publish_port"],
            self.aes_key,
            **self.collector_kwargs
        )

    def run(self):
        """
        Publish every payload taken from the queue until ``None`` is received.

        ``stopped`` is always set on the way out.
        """
        try:
            payload = self.queue.get()
            while payload is not None:
                self.pub_server_channel.publish(payload)
                payload = self.queue.get()
            log.debug("We received the stop sentinal")
        except KeyboardInterrupt:
            pass
        finally:
            self.stopped.set()

    def _handle_signals(self, signum, sigframe):
        """
        Close the publisher resources before the inherited signal handling runs.
        """
        self.close()
        super()._handle_signals(signum, sigframe)

    def close(self):
        """
        Stop the process manager and its processes. Only the first call acts.
        """
        if self._closing:
            return
        self._closing = True
        manager = self.process_manager
        if manager is None:
            return
        manager.stop_restarting()
        manager.send_signal_to_processes(signal.SIGTERM)
        if hasattr(self.pub_server_channel, "pub_close"):
            self.pub_server_channel.pub_close()
        # Really terminate any process still left behind
        for pid in manager._process_map:
            terminate_process(pid=pid, kill_children=True, slow_stop=False)
        self.process_manager = None

    def publish(self, payload):
        """
        Queue ``payload`` for publishing by the running process.
        """
        self.queue.put(payload)

    def __enter__(self):
        """
        Start this process and the collector, then publish ``start`` markers
        (one per second, at most 300) until the collector reports it is running.

        The test is failed through ``pytest.fail`` if that never happens.
        """
        self.start()
        self.collector.__enter__()
        for _ in range(300):
            self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "start": True})
            if self.collector.running.wait(1) is True:
                break
        else:
            pytest.fail("Failed to confirm the collector has started")
        return self

    def __exit__(self, *args):
        """
        Tell the collector to stop, wait for it, then shut this process down.
        """
        # Publish a payload to tell the collection it's done processing
        stop_marker = {"tgt_type": "glob", "tgt": "*", "jid": -1, "stop": True}
        self.publish(stop_marker)
        # Now trigger the collector to also exit
        collector = self.collector
        collector.__exit__(*args)
        # We can safely wait here without a timeout because the Collector instance has a
        # hard timeout set, so eventually Collector.stopped will be set
        collector.stopped.wait()
        collector.join()
        # Stop our own processing
        self.queue.put(None)
        # Wait at most 10 secs for the above `None` in the queue to be processed
        self.stopped.wait(10)
        self.close()
        self.terminate()
        self.join()
        log.info("The PubServerChannelProcess has terminated")


@pytest.fixture(params=["tcp", "zeromq"])
def transport(request):
    """
    Parametrize dependent tests over the ``tcp`` and ``zeromq`` transports.
    """
    selected = request.param
    yield selected


@pytest.mark.skip_on_windows
@pytest.mark.slow_test
def test_publish_to_pubserv_ipc(salt_master, salt_minion, transport):
    """
    Publish 10K messages through the pub server channel with ``ipc_mode`` set
    to ``ipc`` and check that the collector recorded as many jids as were sent.

    ZMQ's ipc transport not supported on Windows
    """
    send_num = 10000
    master_opts = dict(
        salt_master.config.copy(), ipc_mode="ipc", pub_hwm=0, transport=transport
    )
    minion_opts = dict(salt_minion.config.copy(), transport=transport)
    expect = list(range(send_num))
    with PubServerChannelProcess(master_opts, minion_opts) as server_channel:
        for jid in expect:
            server_channel.publish({"tgt_type": "glob", "tgt": "*", "jid": jid})
    # The collector turns its results into a plain list while exiting
    results = server_channel.collector.results
    received = len(results)
    assert received == send_num, "{} != {}, difference: {}".format(
        received, send_num, set(expect).difference(results)
    )

0 comments on commit 37d8bb7

Please sign in to comment.