Skip to content

Commit

Permalink
Use poll instead of select for Linux
Browse files Browse the repository at this point in the history
  • Loading branch information
eandersson committed Sep 24, 2024
1 parent 206e291 commit 52444d1
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 20 deletions.
9 changes: 9 additions & 0 deletions amqpstorm/compatibility.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Python 2/3 Compatibility layer."""

import socket
import sys

try:
Expand Down Expand Up @@ -39,6 +40,14 @@ class DummyException(Exception):
"""


def get_default_poller():
if hasattr(socket, 'poll'):
return 'poll'
return 'select'


DEFAULT_POLLER = get_default_poller()

SSL_CERT_MAP = {}
SSL_VERSIONS = {}
SSL_OPTIONS = [
Expand Down
4 changes: 3 additions & 1 deletion amqpstorm/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class Connection(Stateful):
:param bool ssl: Enable SSL
:param dict ssl_options: SSL kwargs
:param dict client_properties: None or dict of client properties
:param str poller: select or poll
:param bool lazy: Lazy initialize the connection
:raises AMQPConnectionError: Raises if the connection
Expand All @@ -85,7 +86,8 @@ def __init__(self, hostname, username, password, port=5672, **kwargs):
'timeout': kwargs.get('timeout', DEFAULT_SOCKET_TIMEOUT),
'ssl': kwargs.get('ssl', False),
'ssl_options': kwargs.get('ssl_options', {}),
'client_properties': kwargs.get('client_properties', {})
'client_properties': kwargs.get('client_properties', {}),
'poller': kwargs.get('poller', compatibility.DEFAULT_POLLER),
}
self._validate_parameters()
self._io = IO(self.parameters, exceptions=self._exceptions,
Expand Down
61 changes: 50 additions & 11 deletions amqpstorm/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,10 @@
POLL_TIMEOUT = 1.0


class Poller(object):
"""Socket Read Poller."""

def __init__(self, fileno, exceptions, timeout=5):
self.select = select
class BasePoller(object):
def __init__(self, fileno, exceptions):
self._fileno = fileno
self._exceptions = exceptions
self.timeout = timeout

@property
def fileno(self):
Expand All @@ -35,22 +31,60 @@ def fileno(self):
"""
return self._fileno

def close(self):
pass


class SelectPoller(BasePoller):
""""Socket Read Poller using select.select."""

@property
def is_ready(self):
"""Is Socket Ready.
:rtype: tuple
"""
try:
ready, _, _ = self.select.select([self.fileno], [], [],
POLL_TIMEOUT)
ready, _, _ = select.select([self.fileno], [], [], POLL_TIMEOUT)
return bool(ready)
except self.select.error as why:
except select.error as why:
if why.args[0] != EINTR:
self._exceptions.append(AMQPConnectionError(why))
return False


class Poller(BasePoller):
"""Socket Read Poller using select.poll."""

def __init__(self, fileno, exceptions):
super().__init__(fileno, exceptions)
self.poller = select.poll()
self.poller.register(self._fileno, select.POLLIN | select.POLLPRI)

@property
def is_ready(self):
"""Check if the socket is ready for reading.
:rtype: bool
"""
try:
events = self.poller.poll(POLL_TIMEOUT)
for fd, event in events:
if fd == self.fileno:
return True
except select.error as why:
if why.args[0] != EINTR:
self._exceptions.append(AMQPConnectionError(why))
return False

def close(self):
"""Unregister the file descriptor."""
try:
self.poller.unregister(self.fileno)
except OSError:
pass


class IO(object):
"""Internal Input/Output handler."""

Expand All @@ -66,6 +100,7 @@ def __init__(self, parameters, exceptions=None, on_read_impl=None):
self.poller = None
self.socket = None
self.use_ssl = self._parameters['ssl']
self.poller_type = self._parameters['poller']

def close(self):
"""Close Socket.
Expand Down Expand Up @@ -102,8 +137,10 @@ def open(self):
self._running.set()
sock_addresses = self._get_socket_addresses()
self.socket = self._find_address_and_connect(sock_addresses)
self.poller = Poller(self.socket.fileno(), self._exceptions,
timeout=self._parameters['timeout'])
if self.poller_type == 'poll':
self.poller = Poller(self.socket.fileno(), self._exceptions)
else:
self.poller = SelectPoller(self.socket.fileno(), self._exceptions)
self._inbound_thread = self._create_inbound_thread()
finally:
self._wr_lock.release()
Expand Down Expand Up @@ -147,6 +184,8 @@ def _close_socket(self):
if not self.socket:
return
try:
if self.poller:
self.poller.close()
if self.use_ssl:
self.socket.unwrap()
self.socket.shutdown(socket.SHUT_RDWR)
Expand Down
17 changes: 17 additions & 0 deletions amqpstorm/tests/functional/test_reliability.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,20 @@ def test_functional_publish_and_consume_until_empty(self):
'not all messages consumed')

channel.close()


connection = Connection(
HOST, USERNAME, PASSWORD
)
channel = connection.channel()
channel.confirm_deliveries()
try:
channel.basic.publish(
body='message',
routing_key='self.queue_name',
exchange='invalid'
)
except (AMQPConnectionError, AMQPChannelError):
pass
channel.close()
connection.close()
32 changes: 25 additions & 7 deletions amqpstorm/tests/unit/io/test_io_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from amqpstorm import compatibility
from amqpstorm.io import IO
from amqpstorm.io import Poller
from amqpstorm.io import SelectPoller
from amqpstorm.tests.utility import FakeConnection
from amqpstorm.tests.utility import TestFramework

Expand Down Expand Up @@ -163,21 +164,38 @@ def test_io_raises_gaierror(self, _):
io._get_socket_addresses
)

@mock.patch('amqpstorm.io.select.select',
side_effect=select.error('travis-ci'))
def test_io_poller_raises(self, _):
@mock.patch('select.select')
def test_io_poller_raises(self, mock_select):
mock_select.side_effect = select.error('travis-ci')
exceptions = []
poller = Poller(0, exceptions, 30)
poller = SelectPoller(0, exceptions)
self.assertFalse(poller.is_ready)
self.assertTrue(exceptions)

@mock.patch('amqpstorm.io.select.select', side_effect=select.error(EINTR))
def test_io_poller_eintr(self, _):
@mock.patch('select.select')
def test_io_select_poller_eintr(self, mock_select):
mock_select.side_effect = select.error(EINTR)
exceptions = []
poller = Poller(0, exceptions, 30)
poller = SelectPoller(0, exceptions)
self.assertFalse(poller.is_ready)
self.assertFalse(exceptions)

@mock.patch('select.poll')
def test_io_poll_poller_eintr(self, mock_poll):
mock_poll().poll.side_effect = select.error(EINTR)
exceptions = []
poller = Poller(0, exceptions)
self.assertFalse(poller.is_ready)
self.assertFalse(exceptions)

@mock.patch('select.poll')
def test_io_poll_poller_raises(self, mock_poll):
mock_poll().poll.side_effect = select.error('travis-ci')
exceptions = []
poller = Poller(0, exceptions)
self.assertFalse(poller.is_ready)
self.assertTrue(exceptions)

def test_io_simple_receive_when_socket_not_set(self):
connection = FakeConnection()
io = IO(connection.parameters, exceptions=connection.exceptions)
Expand Down
3 changes: 2 additions & 1 deletion amqpstorm/tests/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def __init__(self, state=Connection.OPEN, on_write=None):
'heartbeat': 60,
'timeout': 30,
'ssl': False,
'ssl_options': {}
'ssl_options': {},
'poller': 'select',
}
self.set_state(state)
self.on_write = on_write
Expand Down
2 changes: 2 additions & 0 deletions amqpstorm/uri_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def _parse_uri_options(self, parsed_uri, use_ssl=False, ssl_options=None):
'virtual_host': vhost,
'heartbeat': int(kwargs.pop('heartbeat',
[DEFAULT_HEARTBEAT_TIMEOUT])[0]),
'poller': kwargs.pop('poller',
[compatibility.DEFAULT_POLLER])[0],
'timeout': int(kwargs.pop('timeout',
[DEFAULT_SOCKET_TIMEOUT])[0])
}
Expand Down

0 comments on commit 52444d1

Please sign in to comment.