diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 2b836a7..0000000 --- a/.coveragerc +++ /dev/null @@ -1,8 +0,0 @@ -[run] -branch = True -source = token_bucket - -parallel = True - -[report] -show_missing = True diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 6cb3f47..97c0a36 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -11,39 +11,26 @@ on: jobs: run_tox: - name: tox -e ${{ matrix.toxenv }} (${{matrix.python-version}} on ${{ matrix.os }}) + name: tox run (${{ matrix.python-version }} on ${{ matrix.os }}) runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - python-version: - - "3.8" - os: - - "ubuntu-20.04" - toxenv: - - "pep8" + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "pypy3.9"] + os: ["ubuntu-22.04"] include: - - python-version: "3.5" - os: ubuntu-20.04 - toxenv: py35 - - python-version: "3.6" - os: ubuntu-20.04 - toxenv: py36 - python-version: "3.7" - os: ubuntu-20.04 - toxenv: py37 - - python-version: "3.8" - os: ubuntu-20.04 - toxenv: py38 - - python-version: "3.9" - os: ubuntu-20.04 - toxenv: py39 - - python-version: "3.10" - os: ubuntu-20.04 - toxenv: py310 - - python-version: pypy3 - os: ubuntu-20.04 - toxenv: pypy3 + os: "ubuntu-22.04" + coverage: true + mypy: true + pep8: true + - python-version: "3.11" + os: "ubuntu-22.04" + mypy: true + + env: + OS: ${{ matrix.os }} + PYTHON: ${{ matrix.python-version }} # Steps to run in each job. # Some are GitHub actions, others run shell commands. 
@@ -52,35 +39,41 @@ jobs: uses: actions/checkout@v2 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - - name: Set up Python 3.8 env - if: ${{ matrix.toxenv == 'py35' }} - run: | - sudo apt-get update - sudo apt-get install -y build-essential python3.8 python3.8-dev python3.8-venv - python3.8 -m venv py38-env - - name: Install dependencies run: | - if [ -f py38-env/bin/activate ]; then source py38-env/bin/activate; fi - python -m pip install --upgrade pip - pip install coverage tox + pip install tox python --version pip --version tox --version - coverage --version - - name: Run tests + - name: Setup test suite + run: | + tox run -vv --notest + + - name: Run test suite + run: | + tox run --skip-pkg-install + + - name: Check pep8 + if: matrix.pep8 + run: | + tox run -e pep8 + + - name: Check mypy + if: matrix.mypy run: | - if [ -f py38-env/bin/activate ]; then source py38-env/bin/activate; fi - tox -e ${{ matrix.toxenv }} + tox run -e mypy - name: Upload coverage to Codecov - uses: codecov/codecov-action@v1 - if: ${{ matrix.toxenv == 'py38' }} + uses: codecov/codecov-action@v3 + if: matrix.coverage with: - env_vars: PYTHON + env_vars: OS,PYTHON + files: ./coverage.xml + flags: unittests + name: codecov-umbrella fail_ci_if_error: true diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 31841ca..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,4 +0,0 @@ -include tests/* -include tools/* -include requirements/test -include tox.ini diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6e6ad94 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,71 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "token-bucket" +dynamic = ["version"] +description = "Very fast implementation of the token bucket algorithm." 
+readme = "README.rst" +license = "Apache-2.0" +requires-python = ">=3.7" +authors = [{ name = "kgriffs", email = "mail@kgriffs.com" }] +keywords = [ + "bucket", + "cloud", + "http", + "https", + "limiting", + "rate", + "throttling", + "token", + "web", +] +classifiers = [ + "Development Status :: 4 - Beta", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Operating System :: MacOS :: MacOS X", + "Operating System :: Microsoft :: Windows", + "Operating System :: POSIX", + "Programming Language :: Python", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Internet :: WWW/HTTP", + "Topic :: Software Development :: Libraries", +] +dependencies = [] + +[project.urls] +Homepage = "https://github.com/falconry/token-bucket" + +[tool.hatch.version] +path = "src/token_bucket/version.py" + +[tool.hatch.build] +source = ["src"] + +[tool.coverage.run] +branch = true +source = ["src"] +parallel = true + +[tool.coverage.report] +show_missing = true +exclude_lines = [ + "pragma: no cover", + "if __name__ == .__main__.:", + "@(abc\\.)?abstractmethod", +] + +[tool.black] +line-length = 88 +target-version = ['py37', 'py38'] diff --git a/requirements/tests b/requirements/tests index 7093b61..21d4862 100644 --- a/requirements/tests +++ b/requirements/tests @@ -1,2 +1,3 @@ -coverage pytest +pytest-cov +freezegun diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 4629a02..0000000 --- a/setup.cfg +++ /dev/null @@ -1,11 +0,0 @@ -[egg_info] -tag_build = dev1 - -[wheel] -universal = 1 - -[aliases] -test = pytest - 
-[tool:pytest] -addopts = tests diff --git a/setup.py b/setup.py deleted file mode 100644 index 892fd32..0000000 --- a/setup.py +++ /dev/null @@ -1,48 +0,0 @@ -import imp -import io -from os import path - -from setuptools import find_packages, setup - -VERSION = imp.load_source('version', path.join('.', 'token_bucket', 'version.py')) -VERSION = VERSION.__version__ - - -setup( - name='token_bucket', - version=VERSION, - description='Very fast implementation of the token bucket algorithm.', - long_description=io.open('README.rst', 'r', encoding='utf-8').read(), - classifiers=[ - 'Development Status :: 4 - Beta', - 'Environment :: Web Environment', - 'Natural Language :: English', - 'Intended Audience :: Developers', - 'Intended Audience :: System Administrators', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: MacOS :: MacOS X', - 'Operating System :: Microsoft :: Windows', - 'Operating System :: POSIX', - 'Topic :: Internet :: WWW/HTTP', - 'Topic :: Software Development :: Libraries', - 'Programming Language :: Python', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - ], - keywords='web http https cloud rate limiting token bucket throttling', - author='kgriffs', - author_email='mail@kgriffs.com', - url='https://github.com/falconry/token-bucket', - license='Apache 2.0', - packages=find_packages(exclude=['tests']), - python_requires='>=3.5', - install_requires=[], - setup_requires=['pytest-runner'], - tests_require=['pytest'], -) diff --git a/token_bucket/__init__.py b/src/token_bucket/__init__.py similarity index 56% rename from token_bucket/__init__.py rename to src/token_bucket/__init__.py index 
a68c1c3..04b672d 100644 --- a/token_bucket/__init__.py +++ b/src/token_bucket/__init__.py @@ -5,7 +5,9 @@ # not use this "front-door" module, but rather import using the # fully-qualified paths. -from .version import __version__ # NOQA -from .storage import MemoryStorage # NOQA -from .storage_base import StorageBase # NOQA -from .limiter import Limiter # NOQA +from .limiter import Limiter +from .storage import MemoryStorage +from .storage_base import StorageBase +from .version import __version__ + +__all__ = ["Limiter", "MemoryStorage", "StorageBase", "__version__"] diff --git a/token_bucket/limiter.py b/src/token_bucket/limiter.py similarity index 76% rename from token_bucket/limiter.py rename to src/token_bucket/limiter.py index ea1c631..49f058d 100644 --- a/token_bucket/limiter.py +++ b/src/token_bucket/limiter.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + +from .storage_base import KeyType from .storage_base import StorageBase -class Limiter(object): +class Limiter: """Limits demand for a finite resource via keyed token buckets. 
A limiter manages a set of token buckets that have an identical @@ -60,45 +63,36 @@ class Limiter(object): """ __slots__ = ( - '_rate', - '_capacity', - '_storage', + "_rate", + "_capacity", + "_storage", ) - def __init__(self, rate, capacity, storage): - if not isinstance(rate, (float, int)): - raise TypeError('rate must be an int or float') - + def __init__(self, rate: Union[float, int], capacity: int, storage: StorageBase): if rate <= 0: - raise ValueError('rate must be > 0') - - if not isinstance(capacity, int): - raise TypeError('capacity must be an int') + raise ValueError("rate must be > 0") if capacity < 1: - raise ValueError('capacity must be >= 1') - - if not isinstance(storage, StorageBase): - raise TypeError('storage must be a subclass of StorageBase') + raise ValueError("capacity must be >= 1") self._rate = rate self._capacity = capacity self._storage = storage - def consume(self, key, num_tokens=1): + def consume(self, key: KeyType, num_tokens: int = 1) -> bool: """Attempt to take one or more tokens from a bucket. If the specified token bucket does not yet exist, it will be created and initialized to full capacity before proceeding. Args: - key (bytes): A string or bytes object that specifies the + key: A string or bytes object that specifies the token bucket to consume from. If a global limit is desired for all consumers, the same key may be used for every call to consume(). Otherwise, a key based on consumer identity may be used to segregate limits. Keyword Args: - num_tokens (int): The number of tokens to attempt to + num_tokens: The number of tokens to attempt to consume, defaulting to 1 if not specified. It may be appropriate to ask for more than one token according to the proportion of the resource that a given request @@ -106,24 +100,15 @@ def consume(self, key, num_tokens=1): resource. 
Returns: - bool: True if the requested number of tokens were removed + True if the requested number of tokens were removed from the bucket (conforming), otherwise False (non- conforming). The entire number of tokens requested must be available in the bucket to be conforming. Otherwise, no tokens will be removed (it's all or nothing). """ - if not key: - if key is None: - raise TypeError('key may not be None') - - raise ValueError('key must not be a non-empty string or bytestring') - - if num_tokens is None: - raise TypeError('num_tokens may not be None') - if num_tokens < 1: - raise ValueError('num_tokens must be >= 1') + raise ValueError("num_tokens must be >= 1") self._storage.replenish(key, self._rate, self._capacity) return self._storage.consume(key, num_tokens) diff --git a/src/token_bucket/py.typed b/src/token_bucket/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/token_bucket/storage.py b/src/token_bucket/storage.py similarity index 92% rename from token_bucket/storage.py rename to src/token_bucket/storage.py index 9a93093..da9b555 100644 --- a/token_bucket/storage.py +++ b/src/token_bucket/storage.py @@ -13,9 +13,14 @@ # limitations under the License. import time +from typing import Dict, List +from .storage_base import KeyType from .storage_base import StorageBase +TOKEN_POS = 0 +REPLENISH_TIME_POS = 1 + class MemoryStorage(StorageBase): """In-memory token bucket storage engine. @@ -30,29 +35,29 @@ class MemoryStorage(StorageBase): """ def __init__(self): - self._buckets = {} + self._buckets: Dict[KeyType, List[float]] = {} - def get_token_count(self, key): + def get_token_count(self, key: KeyType) -> float: """Query the current token count for the given bucket. Note that the bucket is not replenished first, so the count will be what it was the last time replenish() was called. Args: - key (str): Name of the bucket to query. + key: Name of the bucket to query. 
Returns: - float: Number of tokens currently in the bucket (may be + Number of tokens currently in the bucket (may be fractional). """ try: - return self._buckets[key][0] + return self._buckets[key][TOKEN_POS] except KeyError: pass return 0 - def replenish(self, key, rate, capacity): + def replenish(self, key: KeyType, rate: float, capacity: int) -> None: """Add tokens to a bucket per the given rate. This method is exposed for use by the token_bucket.Limiter @@ -109,23 +114,21 @@ def replenish(self, key, rate, capacity): # Limit to capacity min( capacity, - # NOTE(kgriffs): The new value is the current number # of tokens in the bucket plus the number of # tokens generated since last time. Fractional # tokens are permitted in order to improve # accuracy (now is a float, and rate may be also). - tokens_in_bucket + (rate * (now - last_replenished_at)) + tokens_in_bucket + (rate * (now - last_replenished_at)), ), - # Update the timestamp for use next time - now + now, ] except KeyError: self._buckets[key] = [capacity, time.monotonic()] - def consume(self, key, num_tokens): + def consume(self, key: KeyType, num_tokens: int) -> bool: """Attempt to take one or more tokens from a bucket. This method is exposed for use by the token_bucket.Limiter @@ -134,7 +137,7 @@ def consume(self, key, num_tokens): # NOTE(kgriffs): Assume that the key will be present, since # replenish() will always be called before consume(). - tokens_in_bucket = self._buckets[key][0] + tokens_in_bucket = self._buckets[key][TOKEN_POS] if tokens_in_bucket < num_tokens: return False @@ -174,6 +177,5 @@ def consume(self, key, num_tokens): # much contention for the lock during such a short # time window, but we might as well remove the # possibility given the points above. 
- - self._buckets[key][0] -= num_tokens + self._buckets[key][TOKEN_POS] -= num_tokens return True diff --git a/token_bucket/storage_base.py b/src/token_bucket/storage_base.py similarity index 72% rename from token_bucket/storage_base.py rename to src/token_bucket/storage_base.py index e64abdf..80424e3 100644 --- a/token_bucket/storage_base.py +++ b/src/token_bucket/storage_base.py @@ -13,28 +13,29 @@ # limitations under the License. import abc +from typing import Union +KeyType = Union[str, bytes] -class StorageBase(object): - __metaclass__ = abc.ABCMeta +class StorageBase(abc.ABC): @abc.abstractmethod - def get_token_count(self, key): + def get_token_count(self, key: KeyType) -> float: """Query the current token count for the given bucket. Note that the bucket is not replenished first, so the count will be what it was the last time replenish() was called. Args: - key (str): Name of the bucket to query. + key: Name of the bucket to query. Returns: - float: Number of tokens currently in the bucket (may be + Number of tokens currently in the bucket (may be fractional). """ @abc.abstractmethod - def replenish(self, key, rate, capacity): + def replenish(self, key: KeyType, rate: float, capacity: int) -> None: """Add tokens to a bucket per the given rate. Conceptually, tokens are added to the bucket at a rate of one @@ -44,28 +45,28 @@ def replenish(self, key, rate, capacity): bucket was replenished. Args: - key (str): Name of the bucket to replenish. - rate (float): Number of tokens per second to add to the + key: Name of the bucket to replenish. + rate: Number of tokens per second to add to the bucket. Over time, the number of tokens that can be consumed is limited by this rate. - capacity (int): Maximum number of tokens that the bucket + capacity: Maximum number of tokens that the bucket can hold. Once the bucket if full, additional tokens are discarded. 
""" @abc.abstractmethod - def consume(self, key, num_tokens): + def consume(self, key: KeyType, num_tokens: int) -> bool: """Attempt to take one or more tokens from a bucket. Args: - key (str): Name of the bucket to replenish. - num_tokens (int): Number of tokens to try to consume from + key: Name of the bucket to replenish. + num_tokens: Number of tokens to try to consume from the bucket. If the bucket contains fewer than the requested number, no tokens are removed (i.e., it's all or nothing). Returns: - bool: True if the requested number of tokens were removed + True if the requested number of tokens were removed from the bucket (conforming), otherwise False (non- conforming). """ diff --git a/token_bucket/version.py b/src/token_bucket/version.py similarity index 74% rename from token_bucket/version.py rename to src/token_bucket/version.py index 9897f79..f0e7803 100644 --- a/token_bucket/version.py +++ b/src/token_bucket/version.py @@ -1,4 +1,4 @@ """Package version.""" -__version__ = '0.4.0' +__version__ = "0.4.0" """Current version of token_bucket.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..397c537 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,20 @@ +import platform + +from pytest import Config + + +def pytest_configure(config: Config): + # When testing with PyPy and coverage, tests become incredible slow and + # could break them. There are several issues reported with other plugins + # too. So you should check carefully if PyPy support it. 
+ # https://github.com/pytest-dev/pytest-cov/issues/418 + # https://github.com/pytest-dev/pytest/issues/7675 + if platform.python_implementation() == "PyPy": + cov = config.pluginmanager.get_plugin("_cov") + + # probably pytest_cov is not installed + if cov: + cov.options.no_cov = True + + if cov.cov_controller: + cov.cov_controller.pause() diff --git a/tests/test_limiter.py b/tests/test_limiter.py index aabaaed..1d020c5 100644 --- a/tests/test_limiter.py +++ b/tests/test_limiter.py @@ -1,25 +1,39 @@ +import datetime import functools -import time +from typing import Type import uuid +from freezegun import freeze_time +from freezegun.api import FrozenDateTimeFactory import pytest import token_bucket -@pytest.mark.parametrize('rate,capacity', [ - (0.3, 1), - (1, 1), - (2.5, 1), # Fractional rates are valid - (10, 100), # Long recovery time after bursting - (10, 10), - (10, 1), # Disallow bursting - (100, 100), - (100, 10), - (100, 1), # Disallow bursting -]) -def test_general_functionality(rate, capacity): - key = 'key' +@pytest.fixture +def frozen_time(): + with freeze_time() as ft: + yield ft + + +@pytest.mark.parametrize( + ("rate", "capacity"), + [ + (0.3, 1), + (1, 1), + (2.5, 1), # Fractional rates are valid + (10, 100), # Long recovery time after bursting + (10, 10), + (10, 1), # Disallow bursting + (100, 100), + (100, 10), + (100, 1), # Disallow bursting + ], +) +def test_general_functionality( + rate: int, capacity: int, frozen_time: FrozenDateTimeFactory +): + key = "key" storage = token_bucket.MemoryStorage() limiter = token_bucket.Limiter(rate, capacity, storage) @@ -27,56 +41,42 @@ def test_general_functionality(rate, capacity): consume_one = functools.partial(limiter.consume, key) - # NOTE(kgriffs) Trigger creation of the bucket and then - # sleep to ensure it is at full capacity before testing it. - consume_one() - time.sleep(float(capacity) / rate) + # NOTE(kgriffs) Trigger creation of the bucket. 
+ storage.replenish(key, rate, capacity) + assert storage.get_token_count(key) == capacity - # NOTE(kgriffs): This works because we can consume at a much - # higher rate relative to the replenishment rate, such that we - # easily consume the total capacity before a single token can - # be replenished. def consume_all(): - for i in range(capacity + 3): + conforming = limiter.consume(key, num_tokens=capacity) + assert conforming + for _ in range(3): conforming = consume_one() - - # NOTE(kgriffs): One past the end should be non-conforming, - # but sometimes an extra token or two can be generated, so - # only check a couple past the end for non-conforming. - if i < capacity: - assert conforming - elif i > capacity + 1: - assert not conforming + assert not conforming # Check non-conforming after consuming all of the tokens consume_all() # Let the bucket replenish 1 token - time.sleep(1.0 / rate) + frozen_time.tick(delta=datetime.timedelta(seconds=(1.5 / rate))) assert consume_one() - - # NOTE(kgriffs): Occasionally enough time will have elapsed to - # cause an additional token to be generated. Clear that one - # out if it is there. - consume_one() - assert storage.get_token_count(key) < 1.0 # NOTE(kgriffs): Let the bucket replenish all the tokens; do this # twice to verify that the bucket is limited to capacity. 
for __ in range(2): - time.sleep(float(capacity) / rate) + frozen_time.tick(delta=datetime.timedelta(seconds=((capacity + 0.5) / rate))) storage.replenish(key, rate, capacity) assert int(storage.get_token_count(key)) == capacity consume_all() -@pytest.mark.parametrize('capacity', [1, 2, 4, 10]) -def test_consume_multiple_tokens_at_a_time(capacity): +@pytest.mark.parametrize("capacity", [1, 2, 4, 10]) +def test_consume_multiple_tokens_at_a_time( + capacity: int, frozen_time: FrozenDateTimeFactory +): rate = 100 num_tokens = capacity - key = 'key' + key = "key" storage = token_bucket.MemoryStorage() limiter = token_bucket.Limiter(rate, capacity, storage) @@ -90,7 +90,9 @@ def test_consume_multiple_tokens_at_a_time(capacity): assert storage.get_token_count(key) < 1.0 # Sleep long enough to generate num_tokens - time.sleep(1.0 / rate * num_tokens) + frozen_time.tick( + delta=datetime.timedelta(seconds=(1.0 / rate * (num_tokens + 0.1))) + ) def test_different_keys(): @@ -102,69 +104,51 @@ def test_different_keys(): keys = [ uuid.uuid4().bytes, - u'3084"5tj jafsb: f', - b'77752098', - u'whiz:bang', - b'x' + '3084"5tj jafsb: f', + b"77752098", + "whiz:bang", + b"x", ] # The last two should be non-conforming - for i in range(capacity + 2): - for k in keys: - conforming = limiter.consume(k) - - if i < capacity: - assert conforming - else: - assert not conforming - - -def test_input_validation_storage_type(): - class DoesNotInheritFromStorageBase(object): - pass - - with pytest.raises(TypeError): - token_bucket.Limiter(1, 1, DoesNotInheritFromStorageBase()) - - -@pytest.mark.parametrize('rate,capacity,etype', [ - (0, 0, ValueError), - (0, 1, ValueError), - (1, 0, ValueError), - (-1, -1, ValueError), - (-1, 0, ValueError), - (0, -1, ValueError), - (-2, -2, ValueError), - (-2, 0, ValueError), - (0, -2, ValueError), - ('x', 'y', TypeError), - ('x', -1, (ValueError, TypeError)), # Params could be checked in any order - (-1, 'y', (ValueError, TypeError)), # ^^^ - ('x', 1, 
TypeError), - (1, 'y', TypeError), - ('x', None, TypeError), - (None, 'y', TypeError), - (None, None, TypeError), - (None, 1, TypeError), - (1, None, TypeError), -]) -def test_input_validation_rate_and_capacity(rate, capacity, etype): + for k in keys: + assert limiter.consume(k, capacity) + for _ in range(2): + assert not limiter.consume(k) + + +@pytest.mark.parametrize( + ("rate", "capacity", "etype"), + [ + (0, 0, ValueError), + (0, 1, ValueError), + (1, 0, ValueError), + (-1, -1, ValueError), + (-1, 0, ValueError), + (0, -1, ValueError), + (-2, -2, ValueError), + (-2, 0, ValueError), + (0, -2, ValueError), + ], +) +def test_input_validation_rate_and_capacity( + rate: float, capacity: int, etype: Type[Exception] +): with pytest.raises(etype): token_bucket.Limiter(rate, capacity, token_bucket.MemoryStorage()) -@pytest.mark.parametrize('key,num_tokens,etype', [ - ('', 1, ValueError), - ('', 0, ValueError), - ('x', 0, ValueError), - ('x', -1, ValueError), - ('x', -2, ValueError), - (-1, None, (ValueError, TypeError)), # Params could be checked in any order - (None, -1, (ValueError, TypeError)), # ^^^ - (None, 1, TypeError), - (1, None, TypeError), -]) -def test_input_validation_on_consume(key, num_tokens, etype): +@pytest.mark.parametrize( + ("key", "num_tokens", "etype"), + [ + ("x", 0, ValueError), + ("x", -1, ValueError), + ("x", -2, ValueError), + ], +) +def test_input_validation_on_consume( + key: bytes, num_tokens: int, etype: Type[Exception] +): limiter = token_bucket.Limiter(1, 1, token_bucket.MemoryStorage()) with pytest.raises(etype): limiter.consume(key, num_tokens) diff --git a/tests/test_multithreading.py b/tests/test_multithreading.py index a6a57dd..46eeb67 100644 --- a/tests/test_multithreading.py +++ b/tests/test_multithreading.py @@ -1,14 +1,32 @@ +from collections import Counter +import datetime +import os import random import threading import time +from typing import Any, Callable, List import uuid +from freezegun import freeze_time +from 
freezegun.api import FrozenDateTimeFactory import pytest import token_bucket -def _run_threaded(func, num_threads): +def patched_freeze_time(): + f = freeze_time() + f.ignore = tuple(set(f.ignore) - {"threading"}) # pyright: ignore + return f + + +@pytest.fixture +def frozen_time(): + with patched_freeze_time() as ft: + yield ft + + +def _run_threaded(func: Callable[..., Any], num_threads: int): threads = [threading.Thread(target=func) for __ in range(num_threads)] for t in threads: @@ -20,21 +38,31 @@ def _run_threaded(func, num_threads): # NOTE(kgriffs): Don't try to remove more tokens than could ever # be available according to the bucket capacity. -@pytest.mark.parametrize('rate,capacity,max_tokens_to_consume', [ - (10, 1, 1), - (100, 1, 1), - (100, 2, 2), - (10, 10, 1), - (10, 10, 2), - (100, 10, 1), - (100, 10, 10), - (100, 100, 5), - (100, 100, 10), - (1000, 10, 1), - (1000, 10, 5), - (1000, 10, 10), -]) -def test_negative_count(rate, capacity, max_tokens_to_consume): +# Test this only in the CI. It is incredibly slow and so +# unlikely that you may never see it. +@pytest.mark.skipif(os.getenv("CI") != "true", reason="slow test") +@pytest.mark.parametrize( + ("rate", "capacity", "max_tokens_to_consume"), + [ + (10, 1, 1), + (100, 1, 1), + (100, 2, 2), + (10, 10, 1), + (10, 10, 2), + (100, 10, 1), + (100, 10, 10), + (100, 100, 5), + (100, 100, 10), + (1000, 10, 1), + (1000, 10, 5), + (1000, 10, 10), + ], +) +def test_negative_count( + rate: int, + capacity: int, + max_tokens_to_consume: int, +): # NOTE(kgriffs): Usually there will be a much larger number of # keys in a production system, but keep to just five to increase # the likelihood of collisions. 
@@ -43,7 +71,7 @@ def test_negative_count(rate, capacity, max_tokens_to_consume): storage = token_bucket.MemoryStorage() limiter = token_bucket.Limiter(rate, capacity, storage) - token_counts = [] + token_counts: List[float] = [] def loop(): for __ in range(1000): @@ -73,7 +101,7 @@ def loop(): assert (max_tokens_to_consume * -2) < min(negative_counts) -def test_replenishment(): +def test_burst_replenishment(frozen_time: FrozenDateTimeFactory): capacity = 100 rate = 100 num_threads = 4 @@ -81,29 +109,28 @@ def test_replenishment(): storage = token_bucket.MemoryStorage() - def loop(): + def consume(): for i in range(trials): - key = str(i) + key = bytes(i) + storage.replenish(key, rate, capacity) - for __ in range(int(capacity / num_threads)): - storage.replenish(key, rate, capacity) - time.sleep(1.0 / rate) - - _run_threaded(loop, num_threads) + for __ in range(capacity // num_threads): + _run_threaded(consume, num_threads) + frozen_time.tick(1.0 / rate) # NOTE(kgriffs): Ensure that a race condition did not result in # not all the tokens being replenished for i in range(trials): - key = str(i) + key = bytes(i) assert storage.get_token_count(key) == capacity -def test_conforming_ratio(): +def test_burst_conforming_ratio(frozen_time: FrozenDateTimeFactory): rate = 100 capacity = 10 - key = 'key' + key = b"key" target_ratio = 0.5 - ratio_max = 0.62 + max_ratio = 0.55 num_threads = 4 storage = token_bucket.MemoryStorage() @@ -111,34 +138,26 @@ def test_conforming_ratio(): # NOTE(kgriffs): Rather than using a lock to protect some counters, # rely on the GIL and count things up after the fact. 
- conforming_states = [] + conforming_states: Counter[bool] = Counter() # NOTE(kgriffs): Start with an empty bucket while limiter.consume(key): pass - def loop(): - # NOTE(kgriffs): Run for 10 seconds - for __ in range(int(rate * 10 / target_ratio / num_threads)): - conforming_states.append(limiter.consume(key)) + def consume(): + conforming_states.update([limiter.consume(key)]) - # NOTE(kgriffs): Only generate some of the tokens needed, so - # that some requests will end up being non-conforming. - time.sleep(1.0 / rate * target_ratio * num_threads) - - _run_threaded(loop, num_threads) + for __ in range(int(rate * 10 / target_ratio / num_threads)): + # NOTE(kgriffs): Only generate some of the tokens needed, so + # that some requests will end up being non-conforming. + sleep_in_seconds = 1.0 / rate * target_ratio * num_threads + frozen_time.tick(delta=datetime.timedelta(seconds=sleep_in_seconds)) - total_conforming = 0 - for c in conforming_states: - if c: - total_conforming += 1 + _run_threaded(consume, num_threads) - actual_ratio = float(total_conforming) / len(conforming_states) + actual_ratio = conforming_states[True] / len(list(conforming_states.elements())) - # NOTE(kgriffs): We don't expect to be super precise due to - # the inprecision of time.sleep() and also having to take into - # account execution time of the other instructions in the - # loop. We do expect a few more conforming states vs. non- - # conforming since the sleep time + overall execution time - # makes the threads run a little behind the replenishment rate. - assert target_ratio < actual_ratio < ratio_max + # NOTE: With a frozen time we should hit exactly. However, due to a tiny gap between + # replenish, frozen_time.tick and consume, it is possible that we have a little bit + # more than expected. You may see this only with PyPy. 
+ assert target_ratio <= actual_ratio < max_ratio diff --git a/tests/test_version.py b/tests/test_version.py index 9ddf448..d33a659 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -6,7 +6,7 @@ def test_version(): assert isinstance(version, str) - numbers = version.split('.') + numbers = version.split(".") assert len(numbers) == 3 for n in numbers: # NOTE(kgriffs): Just check that these are ints by virtue diff --git a/tools/build.sh b/tools/build.sh index 1cabd0a..5b4ce26 100755 --- a/tools/build.sh +++ b/tools/build.sh @@ -33,7 +33,7 @@ _open_env() { pyenv shell $VENV_NAME pip install --upgrade pip - pip install --upgrade wheel twine + pip install --upgrade hatch } # Args: () @@ -77,9 +77,9 @@ pyenv uninstall -f $VENV_NAME #---------------------------------------------------------------------- _echo_task "Building source distribution" -_open_env 2.7.12 +_open_env 3.11.1 -python setup.py sdist -d $DIST_DIR +hatch build -t sdist $DIST_DIR _close_env @@ -88,8 +88,8 @@ _close_env #---------------------------------------------------------------------- _echo_task "Building universal wheel" -_open_env 2.7.12 +_open_env 3.11.1 -python setup.py bdist_wheel -d $DIST_DIR +hatch build -t wheel $DIST_DIR _close_env diff --git a/tools/publish.sh b/tools/publish.sh index 590e1a9..3e1ad06 100755 --- a/tools/publish.sh +++ b/tools/publish.sh @@ -3,5 +3,5 @@ DIST_DIR=./dist read -p "Sign and upload $DIST_DIR/* to PyPI? 
[y/N]: " CONTINUE if [[ $CONTINUE =~ ^[Yy]$ ]]; then - twine upload -s --skip-existing $DIST_DIR/* + hatch publish $DIST_DIR/* fi diff --git a/tox.ini b/tox.ini index 777799b..0d343ed 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,22 @@ -[tox] -envlist = pep8, - py38, - coverage-report - [testenv] deps = -r{toxinidir}/requirements/tests -commands = coverage run -m pytest {posargs:tests} +package = editable +passenv = CI +commands = pytest \ + --cov=src \ + --cov-report=xml \ + --cov-report=term \ + --cov-config=pyproject.toml \ + {posargs:tests} + +# -------------------------------------------------------------------- +# Typing +# -------------------------------------------------------------------- + +[testenv:mypy] +skip_install = true +deps = mypy +commands = mypy {posargs:src} # -------------------------------------------------------------------- # Style @@ -14,6 +25,7 @@ commands = coverage run -m pytest {posargs:tests} [flake8] ; But do please try to stick to 80 unless it makes the code ugly max-line-length = 99 +inline-quotes = double max-complexity = 10 import-order-style = google application-import-names = token_bucket @@ -26,17 +38,6 @@ deps = flake8 flake8-import-order commands = flake8 {posargs:.} -# -------------------------------------------------------------------- -# Coverage -# -------------------------------------------------------------------- - -[testenv:coverage-report] -skip_install = true -commands = - coverage combine - coverage html -d .coverage_html - coverage report - # -------------------------------------------------------------------- # Documentation # --------------------------------------------------------------------