diff --git a/multidb/tests.py b/multidb/tests.py index 843d503..a1d62e7 100644 --- a/multidb/tests.py +++ b/multidb/tests.py @@ -1,26 +1,35 @@ import warnings from threading import Lock, Thread + from django.http import HttpRequest, HttpResponse from django.test import TestCase from django.test.utils import override_settings + try: from unittest import mock except ImportError: import mock -from nose.tools import eq_ # For deprecation tests import multidb import multidb.pinning - -from multidb import (DEFAULT_DB_ALIAS, ReplicaRouter, PinningReplicaRouter, - get_replica) -from multidb.middleware import (pinning_cookie, pinning_cookie_httponly, - pinning_cookie_samesite, pinning_cookie_secure, - pinning_seconds, PinningRouterMiddleware) -from multidb.pinning import (this_thread_is_pinned, pin_this_thread, - unpin_this_thread, use_primary_db, db_write) +from multidb import DEFAULT_DB_ALIAS, PinningReplicaRouter, ReplicaRouter, get_replica +from multidb.middleware import ( + PinningRouterMiddleware, + pinning_cookie, + pinning_cookie_httponly, + pinning_cookie_samesite, + pinning_cookie_secure, + pinning_seconds, +) +from multidb.pinning import ( + db_write, + pin_this_thread, + this_thread_is_pinned, + unpin_this_thread, + use_primary_db, +) class UnpinningTestCase(TestCase): @@ -31,13 +40,12 @@ def tearDown(self): class ReplicaRouterTests(TestCase): - def test_db_for_read(self): - eq_(ReplicaRouter().db_for_read(None), get_replica()) + self.assertEqual(ReplicaRouter().db_for_read(None), get_replica()) # TODO: Test the round-robin functionality. def test_db_for_write(self): - eq_(ReplicaRouter().db_for_write(None), DEFAULT_DB_ALIAS) + self.assertEqual(ReplicaRouter().db_for_write(None), DEFAULT_DB_ALIAS) def test_allow_syncdb(self): router = ReplicaRouter() @@ -46,8 +54,8 @@ def test_allow_syncdb(self): def test_allow_migrate(self): router = ReplicaRouter() - assert router.allow_migrate(DEFAULT_DB_ALIAS, 'dummy') - assert not router.allow_migrate(get_replica(), 'dummy') + assert router.allow_migrate(DEFAULT_DB_ALIAS, "dummy") + assert not router.allow_migrate(get_replica(), "dummy") class SettingsTests(TestCase): @@ -55,11 +63,11 @@ class SettingsTests(TestCase): def test_defaults(self): """Check that the cookie name has the right default.""" - eq_(pinning_cookie(), 'multidb_pin_writes') - eq_(pinning_seconds(), 15) - eq_(pinning_cookie_secure(), False) - eq_(pinning_cookie_httponly(), False) - eq_(pinning_cookie_samesite(), 'Lax') + self.assertEqual(pinning_cookie(), "multidb_pin_writes") + self.assertEqual(pinning_seconds(), 15) + self.assertEqual(pinning_cookie_secure(), False) + self.assertEqual(pinning_cookie_httponly(), False) + self.assertEqual(pinning_cookie_samesite(), "Lax") @override_settings(MULTIDB_PINNING_COOKIE="override_pin_writes") @override_settings(MULTIDB_PINNING_SECONDS=60) @@ -67,11 +75,11 @@ def test_defaults(self): @override_settings(MULTIDB_PINNING_COOKIE_HTTPONLY=True) @override_settings(MULTIDB_PINNING_COOKIE_SAMESITE="Strict") def test_overrides(self): - eq_(pinning_cookie(), "override_pin_writes") - eq_(pinning_seconds(), 60) - eq_(pinning_cookie_secure(), True) - eq_(pinning_cookie_httponly(), True) - eq_(pinning_cookie_samesite(), "Strict") + self.assertEqual(pinning_cookie(), "override_pin_writes") + self.assertEqual(pinning_seconds(), 60) + self.assertEqual(pinning_cookie_secure(), True) + self.assertEqual(pinning_cookie_httponly(), True) + self.assertEqual(pinning_cookie_samesite(), "Strict") class PinningTests(UnpinningTestCase): @@ -80,40 +88,40 @@ class PinningTests(UnpinningTestCase): def test_pinning_encapsulation(self): """Check the pinning getters and setters.""" - assert not this_thread_is_pinned(), \ - "Thread started out pinned or this_thread_is_pinned() is broken." + assert ( + not this_thread_is_pinned() + ), "Thread started out pinned or this_thread_is_pinned() is broken." pin_this_thread() - assert this_thread_is_pinned(), \ - "pin_this_thread() didn't pin the thread." + assert this_thread_is_pinned(), "pin_this_thread() didn't pin the thread." unpin_this_thread() - assert not this_thread_is_pinned(), \ - "Thread remained pinned after unpin_this_thread()." + assert ( + not this_thread_is_pinned() + ), "Thread remained pinned after unpin_this_thread()." def test_pinned_reads(self): """Test PinningReplicaRouter.db_for_read() when pinned and when not.""" router = PinningReplicaRouter() - eq_(router.db_for_read(None), get_replica()) + self.assertEqual(router.db_for_read(None), get_replica()) pin_this_thread() - eq_(router.db_for_read(None), DEFAULT_DB_ALIAS) + self.assertEqual(router.db_for_read(None), DEFAULT_DB_ALIAS) def test_db_write_decorator(self): - def read_view(req): - eq_(router.db_for_read(None), get_replica()) + self.assertEqual(router.db_for_read(None), get_replica()) return HttpResponse() @db_write def write_view(req): - eq_(router.db_for_read(None), DEFAULT_DB_ALIAS) + self.assertEqual(router.db_for_read(None), DEFAULT_DB_ALIAS) return HttpResponse() router = PinningReplicaRouter() - eq_(router.db_for_read(None), get_replica()) + self.assertEqual(router.db_for_read(None), get_replica()) write_view(HttpRequest()) read_view(HttpRequest()) @@ -134,43 +142,46 @@ def setUp(self): def test_pin_on_cookie(self): """Thread should pin when the cookie is set.""" - self.request.COOKIES[pinning_cookie()] = 'y' + self.request.COOKIES[pinning_cookie()] = "y" self.middleware.process_request(self.request) assert this_thread_is_pinned() def test_unpin_on_no_cookie(self): """Thread should unpin when cookie is absent and method is GET.""" pin_this_thread() - self.request.method = 'GET' + self.request.method = "GET" self.middleware.process_request(self.request) assert not this_thread_is_pinned() def test_pin_on_post(self): """Thread should pin when method is POST.""" - self.request.method = 'POST' + self.request.method = "POST" self.middleware.process_request(self.request) assert this_thread_is_pinned() def test_process_response(self): """Make sure the cookie gets set on POSTs but not GETs.""" - self.request.method = 'GET' - response = self.middleware.process_response( - self.request, HttpResponse()) + self.request.method = "GET" + response = self.middleware.process_response(self.request, HttpResponse()) assert pinning_cookie() not in response.cookies - self.request.method = 'POST' - response = self.middleware.process_response( - self.request, HttpResponse()) + self.request.method = "POST" + response = self.middleware.process_response(self.request, HttpResponse()) assert pinning_cookie() in response.cookies - eq_(response.cookies[pinning_cookie()]['max-age'], - pinning_seconds()) - eq_(response.cookies[pinning_cookie()]['samesite'], - pinning_cookie_samesite()) - eq_(response.cookies[pinning_cookie()]['httponly'], - pinning_cookie_httponly() or '') - eq_(response.cookies[pinning_cookie()]['secure'], - pinning_cookie_secure() or '') + self.assertEqual( + response.cookies[pinning_cookie()]["max-age"], pinning_seconds() + ) + self.assertEqual( + response.cookies[pinning_cookie()]["samesite"], pinning_cookie_samesite() + ) + self.assertEqual( + response.cookies[pinning_cookie()]["httponly"], + pinning_cookie_httponly() or "", + ) + self.assertEqual( + response.cookies[pinning_cookie()]["secure"], pinning_cookie_secure() or "" + ) def test_attribute(self): """The cookie should get set if the _db_write attribute is True.""" @@ -182,16 +193,18 @@ def test_attribute(self): def test_db_write_decorator(self): """The @db_write decorator should make any view set the cookie.""" req = self.request - req.method = 'GET' + req.method = "GET" def view(req): return HttpResponse() + response = self.middleware.process_response(req, view(req)) assert pinning_cookie() not in response.cookies @db_write def write_view(req): return HttpResponse() + response = self.middleware.process_response(req, write_view(req)) assert pinning_cookie() in response.cookies @@ -201,6 +214,7 @@ def test_decorator(self): @use_primary_db def check(): assert this_thread_is_pinned() + unpin_this_thread() assert not this_thread_is_pinned() check() @@ -210,6 +224,7 @@ def test_decorator_resets(self): @use_primary_db def check(): assert this_thread_is_pinned() + pin_this_thread() assert this_thread_is_pinned() check() @@ -289,7 +304,7 @@ def thread2_worker(): class DeprecationTestCase(TestCase): def test_masterslaverouter(self): with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") router = multidb.MasterSlaveRouter() assert isinstance(router, ReplicaRouter) assert len(w) == 1 @@ -297,26 +312,25 @@ def test_masterslaverouter(self): def test_pinningmasterslaverouter(self): with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") router = multidb.PinningMasterSlaveRouter() assert isinstance(router, PinningReplicaRouter) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) - @mock.patch.object(multidb, 'get_replica') + @mock.patch.object(multidb, "get_replica") def test_get_slave(self, mock_get_replica): with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") multidb.get_slave() assert mock_get_replica.called assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) def test_use_master(self): - assert isinstance(multidb.pinning.use_master, - use_primary_db.__class__) + assert isinstance(multidb.pinning.use_master, use_primary_db.__class__) with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") with multidb.pinning.use_master: pass assert len(w) == 1 diff --git a/requirements.txt b/requirements.txt index d507e8c..5d20a81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ -django-nose==1.4.7 flake8==6.0.0 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..8dd399a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 88 +extend-ignore = E203 diff --git a/test_settings.py b/test_settings.py index caa8dc0..f146104 100644 --- a/test_settings.py +++ b/test_settings.py @@ -1,8 +1,6 @@ # A Django settings module to support the tests SECRET_KEY = 'dummy' -# TEST_RUNNER = 'django_nose.runner.NoseTestSuiteRunner' - MIDDLEWARE_CLASSES = ( 'django.middleware.common.CommonMiddleware', 'django.middleware.csrf.CsrfViewMiddleware'