From b74ed1824cfd2b6a3bf1c42fba1a042da21f234f Mon Sep 17 00:00:00 2001 From: Michael Grosse Huelsewiesche Date: Fri, 6 Oct 2023 14:35:28 -0400 Subject: [PATCH] Python OAuth implementation (#320) * Working OAuth implementation * Logging improvements and change for 429 handling * Cleanup and restore unit tests to functioning order --- analytics/consumer.py | 133 ----------- analytics/test/client.py | 342 ----------------------------- requirements.txt | 1 + samples/oauth.py | 52 +++++ segment/analytics/__init__.py | 16 +- segment/analytics/client.py | 31 ++- segment/analytics/consumer.py | 15 +- segment/analytics/oauth_manager.py | 208 ++++++++++++++++++ segment/analytics/request.py | 29 ++- segment/analytics/test/__init__.py | 17 +- segment/analytics/test/client.py | 9 +- segment/analytics/test/consumer.py | 12 +- segment/analytics/test/module.py | 2 +- segment/analytics/test/oauth.py | 155 +++++++++++++ segment/analytics/test/request.py | 2 +- segment/analytics/test/utils.py | 2 +- setup.py | 4 +- 17 files changed, 520 insertions(+), 510 deletions(-) delete mode 100644 analytics/consumer.py delete mode 100644 analytics/test/client.py create mode 100644 samples/oauth.py create mode 100644 segment/analytics/oauth_manager.py create mode 100644 segment/analytics/test/oauth.py diff --git a/analytics/consumer.py b/analytics/consumer.py deleted file mode 100644 index 33b9c26c..00000000 --- a/analytics/consumer.py +++ /dev/null @@ -1,133 +0,0 @@ -import logging -from threading import Thread -import json -import monotonic -import backoff - - -from analytics.request import post, APIError, DatetimeSerializer - -from queue import Empty - -MAX_MSG_SIZE = 32 << 10 - -# Our servers only accept batches less than 500KB. Here limit is set slightly -# lower to leave space for extra data that will be added later, eg. "sentAt". -BATCH_SIZE_LIMIT = 475000 - - -class Consumer(Thread): - """Consumes the messages from the client's queue.""" - log = logging.getLogger('segment') - - def __init__(self, queue, write_key, upload_size=100, host=None, - on_error=None, upload_interval=0.5, gzip=False, retries=10, - timeout=15, proxies=None): - """Create a consumer thread.""" - Thread.__init__(self) - # Make consumer a daemon thread so that it doesn't block program exit - self.daemon = True - self.upload_size = upload_size - self.upload_interval = upload_interval - self.write_key = write_key - self.host = host - self.on_error = on_error - self.queue = queue - self.gzip = gzip - # It's important to set running in the constructor: if we are asked to - # pause immediately after construction, we might set running to True in - # run() *after* we set it to False in pause... and keep running - # forever. 
- self.running = True - self.retries = retries - self.timeout = timeout - self.proxies = proxies - - def run(self): - """Runs the consumer.""" - self.log.debug('consumer is running...') - while self.running: - self.upload() - - self.log.debug('consumer exited.') - - def pause(self): - """Pause the consumer.""" - self.running = False - - def upload(self): - """Upload the next batch of items, return whether successful.""" - success = False - batch = self.next() - if len(batch) == 0: - return False - - try: - self.request(batch) - success = True - except Exception as e: - self.log.error('error uploading: %s', e) - success = False - if self.on_error: - self.on_error(e, batch) - finally: - # mark items as acknowledged from queue - for _ in batch: - self.queue.task_done() - return success - - def next(self): - """Return the next batch of items to upload.""" - queue = self.queue - items = [] - - start_time = monotonic.monotonic() - total_size = 0 - - while len(items) < self.upload_size: - elapsed = monotonic.monotonic() - start_time - if elapsed >= self.upload_interval: - break - try: - item = queue.get( - block=True, timeout=self.upload_interval - elapsed) - item_size = len(json.dumps( - item, cls=DatetimeSerializer).encode()) - if item_size > MAX_MSG_SIZE: - self.log.error( - 'Item exceeds 32kb limit, dropping. (%s)', str(item)) - continue - items.append(item) - total_size += item_size - if total_size >= BATCH_SIZE_LIMIT: - self.log.debug( - 'hit batch size limit (size: %d)', total_size) - break - except Empty: - break - - return items - - def request(self, batch): - """Attempt to upload the batch and retry before raising an error """ - - def fatal_exception(exc): - if isinstance(exc, APIError): - # retry on server errors and client errors - # with 429 status code (rate limited), - # don't retry on other client errors - return (400 <= exc.status < 500) and exc.status != 429 - else: - # retry on all other errors (eg. 
network) - return False - - @backoff.on_exception( - backoff.expo, - Exception, - max_tries=self.retries + 1, - giveup=fatal_exception) - def send_request(): - post(self.write_key, self.host, gzip=self.gzip, - timeout=self.timeout, batch=batch, proxies=self.proxies) - - send_request() diff --git a/analytics/test/client.py b/analytics/test/client.py deleted file mode 100644 index fcfbc0eb..00000000 --- a/analytics/test/client.py +++ /dev/null @@ -1,342 +0,0 @@ -from datetime import date, datetime -import unittest -import time -import mock - -from analytics.version import VERSION -from analytics.client import Client - - -class TestClient(unittest.TestCase): - - def fail(self): - """Mark the failure handler""" - self.failed = True - - def setUp(self): - self.failed = False - self.client = Client('testsecret', on_error=self.fail) - - def test_requires_write_key(self): - self.assertRaises(AssertionError, Client) - - def test_empty_flush(self): - self.client.flush() - - def test_basic_track(self): - client = self.client - success, msg = client.track('userId', 'python test event') - client.flush() - self.assertTrue(success) - self.assertFalse(self.failed) - - self.assertEqual(msg['event'], 'python test event') - self.assertTrue(isinstance(msg['timestamp'], str)) - self.assertTrue(isinstance(msg['messageId'], str)) - self.assertEqual(msg['userId'], 'userId') - self.assertEqual(msg['properties'], {}) - self.assertEqual(msg['type'], 'track') - - def test_stringifies_user_id(self): - # A large number that loses precision in node: - # node -e "console.log(157963456373623802 + 1)" > 157963456373623800 - client = self.client - success, msg = client.track( - user_id=157963456373623802, event='python test event') - client.flush() - self.assertTrue(success) - self.assertFalse(self.failed) - - self.assertEqual(msg['userId'], '157963456373623802') - self.assertEqual(msg['anonymousId'], None) - - def test_stringifies_anonymous_id(self): - # A large number that loses precision in node: - # node -e "console.log(157963456373623803 + 1)" > 157963456373623800 - client = self.client - success, msg = client.track( - anonymous_id=157963456373623803, event='python test event') - client.flush() - self.assertTrue(success) - self.assertFalse(self.failed) - - self.assertEqual(msg['userId'], None) - self.assertEqual(msg['anonymousId'], '157963456373623803') - - def test_advanced_track(self): - client = self.client - success, msg = client.track( - 'userId', 'python test event', {'property': 'value'}, - {'ip': '192.168.0.1'}, datetime(2014, 9, 3), 'anonymousId', - {'Amplitude': True}, 'messageId') - - self.assertTrue(success) - - self.assertEqual(msg['timestamp'], '2014-09-03T00:00:00+00:00') - self.assertEqual(msg['properties'], {'property': 'value'}) - self.assertEqual(msg['integrations'], {'Amplitude': True}) - self.assertEqual(msg['context']['ip'], '192.168.0.1') - self.assertEqual(msg['event'], 'python test event') - self.assertEqual(msg['anonymousId'], 'anonymousId') - self.assertEqual(msg['context']['library'], { - 'name': 'analytics-python', - 'version': VERSION - }) - self.assertEqual(msg['messageId'], 'messageId') - self.assertEqual(msg['userId'], 'userId') - self.assertEqual(msg['type'], 'track') - - def test_basic_identify(self): - client = self.client - success, msg = client.identify('userId', {'trait': 'value'}) - client.flush() - self.assertTrue(success) - self.assertFalse(self.failed) - - self.assertEqual(msg['traits'], {'trait': 'value'}) - self.assertTrue(isinstance(msg['timestamp'], str)) - 
self.assertTrue(isinstance(msg['messageId'], str)) - self.assertEqual(msg['userId'], 'userId') - self.assertEqual(msg['type'], 'identify') - - def test_advanced_identify(self): - client = self.client - success, msg = client.identify( - 'userId', {'trait': 'value'}, {'ip': '192.168.0.1'}, - datetime(2014, 9, 3), 'anonymousId', {'Amplitude': True}, - 'messageId') - - self.assertTrue(success) - - self.assertEqual(msg['timestamp'], '2014-09-03T00:00:00+00:00') - self.assertEqual(msg['integrations'], {'Amplitude': True}) - self.assertEqual(msg['context']['ip'], '192.168.0.1') - self.assertEqual(msg['traits'], {'trait': 'value'}) - self.assertEqual(msg['anonymousId'], 'anonymousId') - self.assertEqual(msg['context']['library'], { - 'name': 'analytics-python', - 'version': VERSION - }) - self.assertTrue(isinstance(msg['timestamp'], str)) - self.assertEqual(msg['messageId'], 'messageId') - self.assertEqual(msg['userId'], 'userId') - self.assertEqual(msg['type'], 'identify') - - def test_basic_group(self): - client = self.client - success, msg = client.group('userId', 'groupId') - client.flush() - self.assertTrue(success) - self.assertFalse(self.failed) - - self.assertEqual(msg['groupId'], 'groupId') - self.assertEqual(msg['userId'], 'userId') - self.assertEqual(msg['type'], 'group') - - def test_advanced_group(self): - client = self.client - success, msg = client.group( - 'userId', 'groupId', {'trait': 'value'}, {'ip': '192.168.0.1'}, - datetime(2014, 9, 3), 'anonymousId', {'Amplitude': True}, - 'messageId') - - self.assertTrue(success) - - self.assertEqual(msg['timestamp'], '2014-09-03T00:00:00+00:00') - self.assertEqual(msg['integrations'], {'Amplitude': True}) - self.assertEqual(msg['context']['ip'], '192.168.0.1') - self.assertEqual(msg['traits'], {'trait': 'value'}) - self.assertEqual(msg['anonymousId'], 'anonymousId') - self.assertEqual(msg['context']['library'], { - 'name': 'analytics-python', - 'version': VERSION - }) - self.assertTrue(isinstance(msg['timestamp'], str)) - self.assertEqual(msg['messageId'], 'messageId') - self.assertEqual(msg['userId'], 'userId') - self.assertEqual(msg['type'], 'group') - - def test_basic_alias(self): - client = self.client - success, msg = client.alias('previousId', 'userId') - client.flush() - self.assertTrue(success) - self.assertFalse(self.failed) - self.assertEqual(msg['previousId'], 'previousId') - self.assertEqual(msg['userId'], 'userId') - - def test_basic_page(self): - client = self.client - success, msg = client.page('userId', name='name') - self.assertFalse(self.failed) - client.flush() - self.assertTrue(success) - self.assertEqual(msg['userId'], 'userId') - self.assertEqual(msg['type'], 'page') - self.assertEqual(msg['name'], 'name') - - def test_advanced_page(self): - client = self.client - success, msg = client.page( - 'userId', 'category', 'name', {'property': 'value'}, - {'ip': '192.168.0.1'}, datetime(2014, 9, 3), 'anonymousId', - {'Amplitude': True}, 'messageId') - - self.assertTrue(success) - - self.assertEqual(msg['timestamp'], '2014-09-03T00:00:00+00:00') - self.assertEqual(msg['integrations'], {'Amplitude': True}) - self.assertEqual(msg['context']['ip'], '192.168.0.1') - self.assertEqual(msg['properties'], {'property': 'value'}) - self.assertEqual(msg['anonymousId'], 'anonymousId') - self.assertEqual(msg['context']['library'], { - 'name': 'analytics-python', - 'version': VERSION - }) - self.assertEqual(msg['category'], 'category') - self.assertTrue(isinstance(msg['timestamp'], str)) - self.assertEqual(msg['messageId'], 'messageId') - 
self.assertEqual(msg['userId'], 'userId') - self.assertEqual(msg['type'], 'page') - self.assertEqual(msg['name'], 'name') - - def test_basic_screen(self): - client = self.client - success, msg = client.screen('userId', name='name') - client.flush() - self.assertTrue(success) - self.assertEqual(msg['userId'], 'userId') - self.assertEqual(msg['type'], 'screen') - self.assertEqual(msg['name'], 'name') - - def test_advanced_screen(self): - client = self.client - success, msg = client.screen( - 'userId', 'category', 'name', {'property': 'value'}, - {'ip': '192.168.0.1'}, datetime(2014, 9, 3), 'anonymousId', - {'Amplitude': True}, 'messageId') - - self.assertTrue(success) - - self.assertEqual(msg['timestamp'], '2014-09-03T00:00:00+00:00') - self.assertEqual(msg['integrations'], {'Amplitude': True}) - self.assertEqual(msg['context']['ip'], '192.168.0.1') - self.assertEqual(msg['properties'], {'property': 'value'}) - self.assertEqual(msg['anonymousId'], 'anonymousId') - self.assertEqual(msg['context']['library'], { - 'name': 'analytics-python', - 'version': VERSION - }) - self.assertTrue(isinstance(msg['timestamp'], str)) - self.assertEqual(msg['messageId'], 'messageId') - self.assertEqual(msg['category'], 'category') - self.assertEqual(msg['userId'], 'userId') - self.assertEqual(msg['type'], 'screen') - self.assertEqual(msg['name'], 'name') - - def test_flush(self): - client = self.client - # set up the consumer with more requests than a single batch will allow - for _ in range(1000): - _, _ = client.identify('userId', {'trait': 'value'}) - # We can't reliably assert that the queue is non-empty here; that's - # a race condition. We do our best to load it up though. - client.flush() - # Make sure that the client queue is empty after flushing - self.assertTrue(client.queue.empty()) - - def test_shutdown(self): - client = self.client - # set up the consumer with more requests than a single batch will allow - for _ in range(1000): - _, _ = client.identify('userId', {'trait': 'value'}) - client.shutdown() - # we expect two things after shutdown: - # 1. client queue is empty - # 2. 
consumer thread has stopped - self.assertTrue(client.queue.empty()) - for consumer in client.consumers: - self.assertFalse(consumer.is_alive()) - - def test_synchronous(self): - client = Client('testsecret', sync_mode=True) - - success, _ = client.identify('userId') - self.assertFalse(client.consumers) - self.assertTrue(client.queue.empty()) - self.assertTrue(success) - - def test_overflow(self): - client = Client('testsecret', max_queue_size=1) - # Ensure consumer thread is no longer uploading - client.join() - - for _ in range(10): - client.identify('userId') - - success, _ = client.identify('userId') - # Make sure we are informed that the queue is at capacity - self.assertFalse(success) - - def test_success_on_invalid_write_key(self): - client = Client('bad_key', on_error=self.fail) - client.track('userId', 'event') - client.flush() - self.assertFalse(self.failed) - - def test_numeric_user_id(self): - self.client.track(1234, 'python event') - self.client.flush() - self.assertFalse(self.failed) - - def test_identify_with_date_object(self): - client = self.client - success, msg = client.identify( - 'userId', - { - 'birthdate': date(1981, 2, 2), - }, - ) - client.flush() - self.assertTrue(success) - self.assertFalse(self.failed) - - self.assertEqual(msg['traits'], {'birthdate': date(1981, 2, 2)}) - - def test_gzip(self): - client = Client('testsecret', on_error=self.fail, gzip=True) - for _ in range(10): - client.identify('userId', {'trait': 'value'}) - client.flush() - self.assertFalse(self.failed) - - def test_user_defined_upload_size(self): - client = Client('testsecret', on_error=self.fail, - upload_size=10, upload_interval=3) - - def mock_post_fn(**kwargs): - self.assertEqual(len(kwargs['batch']), 10) - - # the post function should be called 2 times, with a batch size of 10 - # each time. 
- with mock.patch('analytics.consumer.post', side_effect=mock_post_fn) \ - as mock_post: - for _ in range(20): - client.identify('userId', {'trait': 'value'}) - time.sleep(1) - self.assertEqual(mock_post.call_count, 2) - - def test_user_defined_timeout(self): - client = Client('testsecret', timeout=10) - for consumer in client.consumers: - self.assertEqual(consumer.timeout, 10) - - def test_default_timeout_15(self): - client = Client('testsecret') - for consumer in client.consumers: - self.assertEqual(consumer.timeout, 15) - - def test_proxies(self): - client = Client('testsecret', proxies='203.243.63.16:80') - success, msg = client.identify('userId', {'trait': 'value'}) - self.assertTrue(success) diff --git a/requirements.txt b/requirements.txt index 4888d8ed..5eaf3fcf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,6 +52,7 @@ ptyprocess==0.7.0 pycodestyle==2.5.0 pyflakes==2.1.1 Pygments==2.16.1 +PyJWT==2.8.0 pylint==1.9.3 pyparsing==3.1.1 python-dateutil==2.8.2 diff --git a/samples/oauth.py b/samples/oauth.py new file mode 100644 index 00000000..bf1cc70d --- /dev/null +++ b/samples/oauth.py @@ -0,0 +1,52 @@ +import sys +import os +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) +import time +import segment.analytics as analytics + +privatekey = '''-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDVll7uJaH322IN +PQsH2aOXZJ2r1q+6hpVK1R5JV1p41PUzn8pOxyXFHWB+53dUd4B8qywKS36XQjp0 +VmhR1tQ22znQ9ZCM6y4LGeOJBjAZiFZLcGQNNrDFC0WGWTrK1ZTS2K7p5qy4fIXG +laNkMXiGGCawkgcHAdOvPTy8m1d9a6YSetYVmBP/tEYN95jPyZFIoHQfkQPBPr9W +cWPpdEBzasHV5d957akjurPpleDiD5as66UW4dkWXvS7Wu7teCLCyDApcyJKTb2Z +SXybmWjhIZuctZMAx3wT/GgW3FbkGaW5KLQgBUMzjpL0fCtMatlqckMD92ll1FuK +R+HnXu05AgMBAAECggEBAK4o2il4GDUh9zbyQo9ZIPLuwT6AZXRED3Igi3ykNQp4 +I6S/s9g+vQaY6LkyBnSiqOt/K/8NBiFSiJWaa5/n+8zrP56qzf6KOlYk+wsdN5Vq +PWtwLrUzljpl8YAWPEFunNa8hwwE42vfZbnDBKNLT4qQIOQzfnVxQOoQlfj49gM2 +iSrblvsnQTyucFy3UyTeioHbh8q2Xqcxry5WUCOrFDd3IIwATTpLZGw0IPeuFJbJ +NfBizLEcyJaM9hujQU8PRCRd16MWX+bbYM6Mh4dkT40QXWsVnHBHwsgPdQgDgseF +Na4ajtHoC0DlwYCXpCm3IzJfKfq/LR2q8NDUgKeF4AECgYEA9nD4czza3SRbzhpZ +bBoK77CSNqCcMAqyuHB0hp/XX3yB7flF9PIPb2ReO8wwmjbxn+bm7PPz2Uwd2SzO +pU+FXmyKJr53Jxw/hmDWZCoh42gsGDlVqpmytzsj74KlaYiMyZmEGbD7t/FGfNGV +LdLDJaHIYxEimFviOTXKCeKvPAECgYEA3d8tv4jdp1uAuRZiU9Z/tfw5mJOi3oXF +8AdFFDwaPzcTorEAxjrt9X6IjPbLIDJNJtuXYpe+dG6720KyuNnhLhWW9oZEJTwT +dUgqZ2fTCOS9uH0jSn+ZFlgTWI6UDQXRwE7z8avlhMIrQVmPsttGTo7V6sQVtGRx +bNj2RSVekTkCgYAJvy4UYLPHS0jWPfSLcfw8vp8JyhBjVgj7gncZW/kIrcP1xYYe +yfQSU8XmV40UjFfCGz/G318lmP0VOdByeVKtCV3talsMEPHyPqI8E+6DL/uOebYJ +qUqINK6XKnOgWOY4kvnGillqTQCcry1XQp61PlDOmj7kB75KxPXYrj6AAQKBgQDa ++ixCv6hURuEyy77cE/YT/Q4zYnL6wHjtP5+UKwWUop1EkwG6o+q7wtiul90+t6ah +1VUCP9X/QFM0Qg32l0PBohlO0pFrVnG17TW8vSHxwyDkds1f97N19BOT8ZR5jebI +sKPfP9LVRnY+l1BWLEilvB+xBzqMwh2YWkIlWI6PMQKBgGi6TBnxp81lOYrxVRDj +/3ycRnVDmBdlQKFunvfzUBmG1mG/G0YHeVSUKZJGX7w2l+jnDwIA383FcUeA8X6A +l9q+amhtkwD/6fbkAu/xoWNl+11IFoxd88y2ByBFoEKB6UVLuCTSKwXDqzEZet7x +mDyRxq7ohIzLkw8b8buDeuXZ +-----END PRIVATE KEY----- +''' # Should be read from a file on disk which can be rotated out + +analytics.write_key = '' + +analytics.oauth_client_id = 'CLIENT_ID' # OAuth application ID from segment dashboard +analytics.oauth_client_key = privatekey # generated as a public/private key pair in PEM format from OpenSSL +analytics.oauth_key_id = 'KEY_ID' # From segment dashboard after uploading public key +analytics.oauth_scope = 'tracking_api:write' #'public_api:read_write' + +def on_error(error, items): + print("An error occurred: 
", error) +analytics.debug = True +analytics.on_error = on_error + +analytics.track('AUser', 'track') +analytics.flush() + +time.sleep(3) \ No newline at end of file diff --git a/segment/analytics/__init__.py b/segment/analytics/__init__.py index 230769b5..ef35f67f 100644 --- a/segment/analytics/__init__.py +++ b/segment/analytics/__init__.py @@ -9,6 +9,7 @@ host = Client.DefaultConfig.host on_error = Client.DefaultConfig.on_error debug = Client.DefaultConfig.debug +log_handler = Client.DefaultConfig.log_handler send = Client.DefaultConfig.send sync_mode = Client.DefaultConfig.sync_mode max_queue_size = Client.DefaultConfig.max_queue_size @@ -16,6 +17,13 @@ timeout = Client.DefaultConfig.timeout max_retries = Client.DefaultConfig.max_retries +"""Oauth Settings.""" +oauth_client_id = Client.DefaultConfig.oauth_client_id +oauth_client_key = Client.DefaultConfig.oauth_client_key +oauth_key_id = Client.DefaultConfig.oauth_key_id +oauth_auth_server = Client.DefaultConfig.oauth_auth_server +oauth_scope = Client.DefaultConfig.oauth_scope + default_client = None @@ -73,7 +81,13 @@ def _proxy(method, *args, **kwargs): max_queue_size=max_queue_size, send=send, on_error=on_error, gzip=gzip, max_retries=max_retries, - sync_mode=sync_mode, timeout=timeout) + sync_mode=sync_mode, timeout=timeout, + oauth_client_id=oauth_client_id, + oauth_client_key=oauth_client_key, + oauth_key_id=oauth_key_id, + oauth_auth_server=oauth_auth_server, + oauth_scope=oauth_scope, + ) fn = getattr(default_client, method) return fn(*args, **kwargs) diff --git a/segment/analytics/client.py b/segment/analytics/client.py index 515da899..1d4e35da 100644 --- a/segment/analytics/client.py +++ b/segment/analytics/client.py @@ -6,6 +6,7 @@ import json from dateutil.tz import tzutc +from segment.analytics.oauth_manager import OauthManager from segment.analytics.utils import guess_timezone, clean from segment.analytics.consumer import Consumer, MAX_MSG_SIZE @@ -23,6 +24,7 @@ class DefaultConfig(object): host = None on_error = None debug = False + log_handler = None send = True sync_mode = False max_queue_size = 10000 @@ -33,6 +35,12 @@ class DefaultConfig(object): thread = 1 upload_interval = 0.5 upload_size = 100 + oauth_client_id = None + oauth_client_key = None + oauth_key_id = None + oauth_auth_server = 'https://oauth2.segment.io' + oauth_scope = 'tracking_api:write' + """Create a new Segment client.""" log = logging.getLogger('segment') @@ -51,7 +59,13 @@ def __init__(self, proxies=DefaultConfig.proxies, thread=DefaultConfig.thread, upload_size=DefaultConfig.upload_size, - upload_interval=DefaultConfig.upload_interval,): + upload_interval=DefaultConfig.upload_interval, + log_handler=DefaultConfig.log_handler, + oauth_client_id=DefaultConfig.oauth_client_id, + oauth_client_key=DefaultConfig.oauth_client_key, + oauth_key_id=DefaultConfig.oauth_key_id, + oauth_auth_server=DefaultConfig.oauth_auth_server, + oauth_scope=DefaultConfig.oauth_scope,): require('write_key', write_key, str) self.queue = queue.Queue(max_queue_size) @@ -64,9 +78,19 @@ def __init__(self, self.gzip = gzip self.timeout = timeout self.proxies = proxies + self.oauth_manager = None + if(oauth_client_id and oauth_client_key and oauth_key_id): + self.oauth_manager = OauthManager(oauth_client_id, oauth_client_key, oauth_key_id, + oauth_auth_server, oauth_scope, timeout, max_retries) + + if log_handler: + self.log.addHandler(log_handler) if debug: self.log.setLevel(logging.DEBUG) + if not log_handler: + # default log handler does not print debug or info + 
self.log.addHandler(logging.StreamHandler()) if sync_mode: self.consumers = None @@ -85,7 +109,7 @@ def __init__(self, self.queue, write_key, host=host, on_error=on_error, upload_size=upload_size, upload_interval=upload_interval, gzip=gzip, retries=max_retries, timeout=timeout, - proxies=proxies, + proxies=proxies, oauth_manager=self.oauth_manager, ) self.consumers.append(consumer) @@ -280,7 +304,8 @@ def _enqueue(self, msg): if self.sync_mode: self.log.debug('enqueued with blocking %s.', msg['type']) post(self.write_key, self.host, gzip=self.gzip, - timeout=self.timeout, proxies=self.proxies, batch=[msg]) + timeout=self.timeout, proxies=self.proxies, + oauth_manager=self.oauth_manager, batch=[msg]) return True, msg diff --git a/segment/analytics/consumer.py b/segment/analytics/consumer.py index 27586284..a78f2d34 100644 --- a/segment/analytics/consumer.py +++ b/segment/analytics/consumer.py @@ -14,6 +14,13 @@ # lower to leave space for extra data that will be added later, eg. "sentAt". BATCH_SIZE_LIMIT = 475000 +class FatalError(Exception): + def __init__(self, message): + self.message = message + def __str__(self): + msg = "[Segment] {0})" + return msg.format(self.message) + class Consumer(Thread): """Consumes the messages from the client's queue.""" @@ -21,7 +28,7 @@ class Consumer(Thread): def __init__(self, queue, write_key, upload_size=100, host=None, on_error=None, upload_interval=0.5, gzip=False, retries=10, - timeout=15, proxies=None): + timeout=15, proxies=None, oauth_manager=None): """Create a consumer thread.""" Thread.__init__(self) # Make consumer a daemon thread so that it doesn't block program exit @@ -41,6 +48,7 @@ def __init__(self, queue, write_key, upload_size=100, host=None, self.retries = retries self.timeout = timeout self.proxies = proxies + self.oauth_manager = oauth_manager def run(self): """Runs the consumer.""" @@ -118,6 +126,8 @@ def fatal_exception(exc): # with 429 status code (rate limited), # don't retry on other client errors return (400 <= exc.status < 500) and exc.status != 429 + elif isinstance(exc, FatalError): + return True else: # retry on all other errors (eg. 
network) return False @@ -129,6 +139,7 @@ def fatal_exception(exc): giveup=fatal_exception) def send_request(): post(self.write_key, self.host, gzip=self.gzip, - timeout=self.timeout, batch=batch, proxies=self.proxies) + timeout=self.timeout, batch=batch, proxies=self.proxies, + oauth_manager=self.oauth_manager) send_request() diff --git a/segment/analytics/oauth_manager.py b/segment/analytics/oauth_manager.py new file mode 100644 index 00000000..453a23a0 --- /dev/null +++ b/segment/analytics/oauth_manager.py @@ -0,0 +1,208 @@ +from datetime import date, datetime +import logging +import threading +import time +import uuid +from requests import sessions +import jwt + +from segment.analytics import utils +from segment.analytics.request import APIError +from segment.analytics.consumer import FatalError + +_session = sessions.Session() + +class OauthManager(object): + def __init__(self, + client_id, + client_key, + key_id, + auth_server='https://oauth2.segment.io', + scope='tracking_api:write', + timeout=15, + max_retries=3): + self.client_id = client_id + self.client_key = client_key + self.key_id = key_id + self.auth_server = auth_server + self.scope = scope + self.timeout = timeout + self.max_retries = max_retries + self.retry_count = 0 + self.clock_skew = 0 + + self.log = logging.getLogger('segment') + self.thread = None + self.token_mutex = threading.Lock() + self.token = None + self.error = None + + def get_token(self): + with self.token_mutex: + if self.token: + return self.token + # No good token, start the loop + self.log.debug("OAuth is enabled. No cached access token.") + # Make sure we're not waiting an excessively long time (this will not cancel 429 waits) + if self.thread and self.thread.is_alive(): + self.thread.cancel() + self.thread = threading.Timer(0,self._poller_loop) + self.thread.daemon = True + self.thread.start() + + while True: + # Wait for a token or error + with self.token_mutex: + if self.token: + return self.token + if self.error: + error = self.error + self.error = None + raise error + if self.thread: + # Wait for a cycle, may not have an answer immediately + self.thread.join(1) + + def clear_token(self): + self.log.debug("OAuth Token invalidated. 
Poller for new token is {}".format( + "active" if self.thread and self.thread.is_alive() else "stopped" )) + with self.token_mutex: + self.token = None + + def _request_token(self): + jwt_body = { + "iss": self.client_id, + "sub": self.client_id, + "aud": utils.remove_trailing_slash(self.auth_server), + "iat": int(time.time())-5 - self.clock_skew, + "exp": int(time.time()) + 55 - self.clock_skew, + "jti": str(uuid.uuid4()) + } + + signed_jwt = jwt.encode( + jwt_body, + self.client_key, + algorithm="RS256", + headers={"kid": self.key_id}, + ) + + request_body = 'grant_type=client_credentials&client_assertion_type='\ + 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer&'\ + 'client_assertion={}&scope={}'.format(signed_jwt, self.scope) + + token_endpoint = f'{utils.remove_trailing_slash(self.auth_server)}/token' + + self.log.debug("OAuth token requested from {} with size {}".format(token_endpoint, len(request_body))) + + res = _session.post(url=token_endpoint, data=request_body, timeout=self.timeout, + headers={'Content-Type': 'application/x-www-form-urlencoded'}) + return res + + def _poller_loop(self): + refresh_timer_ms = 25 + response = None + + try: + response = self._request_token() + except Exception as e: + self.retry_count += 1 + if self.retry_count < self.max_retries: + self.log.debug("OAuth token request encountered an error on attempt {}: {}".format(self.retry_count ,e)) + self.thread = threading.Timer(refresh_timer_ms / 1000.0, self._poller_loop) + self.thread.daemon = True + self.thread.start() + return + # Too many retries, giving up + self.log.error("OAuth token request encountered an error after {} attempts: {}".format(self.retry_count ,e)) + self.error = FatalError(str(e)) + return + if response.headers.get("Date"): + try: + server_time = datetime.strptime(response.headers.get("Date"), "%a, %d %b %Y %H:%M:%S %Z") + self.clock_skew = int((datetime.utcnow() - server_time).total_seconds()) + except Exception as e: + self.log.error("OAuth token request received a response with an invalid Date header: {} | {}".format(response, e)) + + if response.status_code == 200: + data = None + try: + data = response.json() + except Exception as e: + self.retry_count += 1 + if self.retry_count < self.max_retries: + self.thread = threading.Timer(refresh_timer_ms / 1000.0, self._poller_loop) + self.thread.daemon = True + self.thread.start() + return + # Too many retries, giving up + self.error = e + return + try: + with self.token_mutex: + self.token = data['access_token'] + # success! + self.retry_count = 0 + except Exception as e: + # No access token in response? + self.log.error("OAuth token request received a successful response with a missing token: {}".format(response)) + try: + refresh_timer_ms = int(data['expires_in']) / 2 * 1000 + except Exception as e: + refresh_timer_ms = 60 * 1000 + + elif response.status_code == 429: + self.retry_count += 1 + rate_limit_reset_timestamp = None + try: + rate_limit_reset_timestamp = int(response.headers.get("X-RateLimit-Reset")) + except Exception as e: + self.log.error("OAuth rate limit response did not have a valid rest time: {} | {}".format(response, e)) + if rate_limit_reset_timestamp: + refresh_timer_ms = rate_limit_reset_timestamp - time.time() * 1000 + else: + refresh_timer_ms = 5 * 1000 + + self.log.debug("OAuth token request encountered a rate limit response, waiting {} ms".format(refresh_timer_ms)) + # We want subsequent calls to get_token to be able to interrupt our + # Timeout when it's waiting for e.g. 
a long normal expiration, but + # not when we're waiting for a rate limit reset. Sleep instead. + time.sleep(refresh_timer_ms / 1000.0) + refresh_timer_ms = 0 + elif response.status_code in [400, 401, 415]: + # unrecoverable errors (except for skew). APIError will be handled by request.py + self.retry_count = 0 + try: + payload = response.json() + + if (payload.get('error') == 'invalid_request' and + (payload.get('error_description') == 'Token is expired' or + payload.get('error_description') == 'Token used before issued')): + refresh_timer_ms = 0 # Retry immediately hopefully with a good skew value + self.thread = threading.Timer(refresh_timer_ms / 1000.0, self._poller_loop) + self.thread.daemon = True + self.thread.start() + return + + self.error = APIError(response.status_code, payload['error'], payload['error_description']) + except ValueError: + self.error = APIError(response.status_code, 'unknown', response.text) + self.log.error("OAuth token request error was unrecoverable, possibly due to configuration: {}".format(self.error)) + return + else: + # any other error + self.log.debug("OAuth token request error, attempt {}: [{}] {}".format(self.retry_count, response.status_code, response.reason)) + self.retry_count += 1 + + if self.retry_count > 0 and self.retry_count % self.max_retries == 0: + # every time we pass the retry count, put up an error to release any waiting token requests + try: + payload = response.json() + self.error = APIError(response.status_code, payload['error'], payload['error_description']) + except ValueError: + self.error = APIError(response.status_code, 'unknown', response.text) + self.log.error("OAuth token request error after {} attempts: {}".format(self.retry_count, self.error)) + + # loop + self.thread = threading.Timer(refresh_timer_ms / 1000.0, self._poller_loop) + self.thread.daemon = True + self.thread.start() diff --git a/segment/analytics/request.py b/segment/analytics/request.py index d1901f79..273247ce 100644 --- a/segment/analytics/request.py +++ b/segment/analytics/request.py @@ -13,19 +13,26 @@ _session = sessions.Session() -def post(write_key, host=None, gzip=False, timeout=15, proxies=None, **kwargs): +def post(write_key, host=None, gzip=False, timeout=15, proxies=None, oauth_manager=None, **kwargs): """Post the `kwargs` to the API""" log = logging.getLogger('segment') body = kwargs - body["sentAt"] = datetime.utcnow().replace(tzinfo=tzutc()).isoformat() + if not "sentAt" in body.keys(): + body["sentAt"] = datetime.utcnow().replace(tzinfo=tzutc()).isoformat() + body["writeKey"] = write_key url = remove_trailing_slash(host or 'https://api.segment.io') + '/v1/batch' - auth = HTTPBasicAuth(write_key, '') + auth = None + if oauth_manager: + auth = oauth_manager.get_token() data = json.dumps(body, cls=DatetimeSerializer) log.debug('making request: %s', data) headers = { 'Content-Type': 'application/json', 'User-Agent': 'analytics-python/' + VERSION } + if auth: + headers['Authorization'] = 'Bearer {}'.format(auth) + if gzip: headers['Content-Encoding'] = 'gzip' buf = BytesIO() @@ -37,26 +44,32 @@ def post(write_key, host=None, gzip=False, timeout=15, proxies=None, **kwargs): kwargs = { "data": data, - "auth": auth, "headers": headers, "timeout": 15, } if proxies: kwargs['proxies'] = proxies - - res = _session.post(url, data=data, auth=auth, - headers=headers, timeout=timeout) - + res = None + try: + res = _session.post(url, data=data, headers=headers, timeout=timeout) + except Exception as e: + log.error(e) + raise e + if res.status_code == 200: 
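+        # 2xx: the batch was accepted; log and return the raw response.
+        # (Non-2xx falls through to OAuth token invalidation and APIError below.)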
log.debug('data uploaded successfully') return res + if oauth_manager and res.status_code in [400, 401, 403]: + oauth_manager.clear_token() + try: payload = res.json() log.debug('received response: %s', payload) raise APIError(res.status_code, payload['code'], payload['message']) except ValueError: + log.error('Unknown error: [%s] %s', res.status_code, res.reason) raise APIError(res.status_code, 'unknown', res.text) diff --git a/segment/analytics/test/__init__.py b/segment/analytics/test/__init__.py index 09bf9b63..98ad6aa3 100644 --- a/segment/analytics/test/__init__.py +++ b/segment/analytics/test/__init__.py @@ -2,14 +2,13 @@ import pkgutil import logging import sys -import analytics - -from analytics.client import Client +import segment.analytics as analytics +from segment.analytics.client import Client def all_names(): for _, modname, _ in pkgutil.iter_modules(__path__): - yield 'analytics.test.' + modname + yield 'segment.analytics.test.' + modname def all(): @@ -32,6 +31,7 @@ def test_debug(self): analytics.debug = False analytics.flush() self.assertFalse(analytics.default_client.debug) + analytics.default_client.log.setLevel(0) # reset log level after debug enable def test_gzip(self): self.assertIsNone(analytics.default_client) @@ -45,9 +45,11 @@ def test_gzip(self): def test_host(self): self.assertIsNone(analytics.default_client) - analytics.host = 'test-host' + analytics.host = 'http://test-host' analytics.flush() - self.assertEqual(analytics.default_client.host, 'test-host') + self.assertEqual(analytics.default_client.host, 'http://test-host') + analytics.host = None + analytics.default_client = None def test_max_queue_size(self): self.assertIsNone(analytics.default_client) @@ -80,3 +82,6 @@ def test_timeout(self): def setUp(self): analytics.write_key = 'test-init' analytics.default_client = None + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/segment/analytics/test/client.py b/segment/analytics/test/client.py index f01dc2ce..bffdb1e5 100644 --- a/segment/analytics/test/client.py +++ b/segment/analytics/test/client.py @@ -3,13 +3,13 @@ import time import mock -from analytics.version import VERSION -from analytics.client import Client +from segment.analytics.version import VERSION +from segment.analytics.client import Client class TestClient(unittest.TestCase): - def fail(self, e, batch): + def fail(self, e, batch=[]): """Mark the failure handler""" self.failed = True @@ -294,6 +294,7 @@ def test_numeric_user_id(self): def test_debug(self): Client('bad_key', debug=True) + self.client.log.setLevel(0) # reset log level after debug enable def test_identify_with_date_object(self): client = self.client @@ -325,7 +326,7 @@ def mock_post_fn(*args, **kwargs): # the post function should be called 2 times, with a batch size of 10 # each time. 
- with mock.patch('analytics.consumer.post', side_effect=mock_post_fn) \ + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn) \ as mock_post: for _ in range(20): client.identify('userId', {'trait': 'value'}) diff --git a/segment/analytics/test/consumer.py b/segment/analytics/test/consumer.py index 16d0b213..4a2a1e24 100644 --- a/segment/analytics/test/consumer.py +++ b/segment/analytics/test/consumer.py @@ -8,8 +8,8 @@ except ImportError: from Queue import Queue -from analytics.consumer import Consumer, MAX_MSG_SIZE -from analytics.request import APIError +from segment.analytics.consumer import Consumer, MAX_MSG_SIZE +from segment.analytics.request import APIError class TestConsumer(unittest.TestCase): @@ -59,7 +59,7 @@ def test_upload_interval(self): upload_interval = 0.3 consumer = Consumer(q, 'testsecret', upload_size=10, upload_interval=upload_interval) - with mock.patch('analytics.consumer.post') as mock_post: + with mock.patch('segment.analytics.consumer.post') as mock_post: consumer.start() for i in range(0, 3): track = { @@ -79,7 +79,7 @@ def test_multiple_uploads_per_interval(self): upload_size = 10 consumer = Consumer(q, 'testsecret', upload_size=upload_size, upload_interval=upload_interval) - with mock.patch('analytics.consumer.post') as mock_post: + with mock.patch('segment.analytics.consumer.post') as mock_post: consumer.start() for i in range(0, upload_size * 2): track = { @@ -110,7 +110,7 @@ def mock_post(*args, **kwargs): raise expected_exception mock_post.call_count = 0 - with mock.patch('analytics.consumer.post', + with mock.patch('segment.analytics.consumer.post', mock.Mock(side_effect=mock_post)): track = { 'type': 'track', @@ -190,7 +190,7 @@ def mock_post_fn(_, data, **kwargs): % len(data.encode())) return res - with mock.patch('analytics.request._session.post', + with mock.patch('segment.analytics.request._session.post', side_effect=mock_post_fn) as mock_post: consumer.start() for _ in range(0, n_msgs + 2): diff --git a/segment/analytics/test/module.py b/segment/analytics/test/module.py index 3901b1c7..e5fe598c 100644 --- a/segment/analytics/test/module.py +++ b/segment/analytics/test/module.py @@ -1,6 +1,6 @@ import unittest -import analytics +import segment.analytics as analytics class TestModule(unittest.TestCase): diff --git a/segment/analytics/test/oauth.py b/segment/analytics/test/oauth.py new file mode 100644 index 00000000..259342bc --- /dev/null +++ b/segment/analytics/test/oauth.py @@ -0,0 +1,155 @@ +from datetime import datetime +import threading +import time +import unittest +import mock +import sys +import os +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../..")) +from segment.analytics.client import Client +import segment.analytics.oauth_manager +import requests + +privatekey = '''-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDVll7uJaH322IN +PQsH2aOXZJ2r1q+6hpVK1R5JV1p41PUzn8pOxyXFHWB+53dUd4B8qywKS36XQjp0 +VmhR1tQ22znQ9ZCM6y4LGeOJBjAZiFZLcGQNNrDFC0WGWTrK1ZTS2K7p5qy4fIXG +laNkMXiGGCawkgcHAdOvPTy8m1d9a6YSetYVmBP/tEYN95jPyZFIoHQfkQPBPr9W +cWPpdEBzasHV5d957akjurPpleDiD5as66UW4dkWXvS7Wu7teCLCyDApcyJKTb2Z +SXybmWjhIZuctZMAx3wT/GgW3FbkGaW5KLQgBUMzjpL0fCtMatlqckMD92ll1FuK +R+HnXu05AgMBAAECggEBAK4o2il4GDUh9zbyQo9ZIPLuwT6AZXRED3Igi3ykNQp4 +I6S/s9g+vQaY6LkyBnSiqOt/K/8NBiFSiJWaa5/n+8zrP56qzf6KOlYk+wsdN5Vq +PWtwLrUzljpl8YAWPEFunNa8hwwE42vfZbnDBKNLT4qQIOQzfnVxQOoQlfj49gM2 +iSrblvsnQTyucFy3UyTeioHbh8q2Xqcxry5WUCOrFDd3IIwATTpLZGw0IPeuFJbJ 
+NfBizLEcyJaM9hujQU8PRCRd16MWX+bbYM6Mh4dkT40QXWsVnHBHwsgPdQgDgseF +Na4ajtHoC0DlwYCXpCm3IzJfKfq/LR2q8NDUgKeF4AECgYEA9nD4czza3SRbzhpZ +bBoK77CSNqCcMAqyuHB0hp/XX3yB7flF9PIPb2ReO8wwmjbxn+bm7PPz2Uwd2SzO +pU+FXmyKJr53Jxw/hmDWZCoh42gsGDlVqpmytzsj74KlaYiMyZmEGbD7t/FGfNGV +LdLDJaHIYxEimFviOTXKCeKvPAECgYEA3d8tv4jdp1uAuRZiU9Z/tfw5mJOi3oXF +8AdFFDwaPzcTorEAxjrt9X6IjPbLIDJNJtuXYpe+dG6720KyuNnhLhWW9oZEJTwT +dUgqZ2fTCOS9uH0jSn+ZFlgTWI6UDQXRwE7z8avlhMIrQVmPsttGTo7V6sQVtGRx +bNj2RSVekTkCgYAJvy4UYLPHS0jWPfSLcfw8vp8JyhBjVgj7gncZW/kIrcP1xYYe +yfQSU8XmV40UjFfCGz/G318lmP0VOdByeVKtCV3talsMEPHyPqI8E+6DL/uOebYJ +qUqINK6XKnOgWOY4kvnGillqTQCcry1XQp61PlDOmj7kB75KxPXYrj6AAQKBgQDa ++ixCv6hURuEyy77cE/YT/Q4zYnL6wHjtP5+UKwWUop1EkwG6o+q7wtiul90+t6ah +1VUCP9X/QFM0Qg32l0PBohlO0pFrVnG17TW8vSHxwyDkds1f97N19BOT8ZR5jebI +sKPfP9LVRnY+l1BWLEilvB+xBzqMwh2YWkIlWI6PMQKBgGi6TBnxp81lOYrxVRDj +/3ycRnVDmBdlQKFunvfzUBmG1mG/G0YHeVSUKZJGX7w2l+jnDwIA383FcUeA8X6A +l9q+amhtkwD/6fbkAu/xoWNl+11IFoxd88y2ByBFoEKB6UVLuCTSKwXDqzEZet7x +mDyRxq7ohIzLkw8b8buDeuXZ +-----END PRIVATE KEY-----''' + +def mocked_requests_get(*args, **kwargs): + class MockResponse: + def __init__(self, data, status_code): + self.__dict__['headers'] = {'date': datetime.now().strftime("%a, %d %b %Y %H:%M:%S GMT")} + self.__dict__.update(data) + self.status_code = status_code + + def json(self): + return self.json_data + if 'url' not in kwargs: + kwargs['url'] = args[0] + if kwargs['url'] == 'http://127.0.0.1:80/token': + return MockResponse({"json_data" : {"access_token": "test_token", "expires_in": 4000}}, 200) + elif kwargs['url'] == 'http://127.0.0.1:400/token': + return MockResponse({"reason": "test_reason", "json_data" : {"error":"unrecoverable", "error_description":"nah"}}, 400) + elif kwargs['url'] == 'http://127.0.0.1:429/token': + return MockResponse({"reason": "test_reason", "headers" : {"X-RateLimit-Reset": time.time()*1000 + 2000}}, 429) + elif kwargs['url'] == 'http://127.0.0.1:500/token': + return MockResponse({"reason": "test_reason", "json_data" : {"error":"recoverable", "error_description":"nah"}}, 500) + elif kwargs['url'] == 'http://127.0.0.1:501/token': + if mocked_requests_get.error_count < 0 or mocked_requests_get.error_count > 0: + if mocked_requests_get.error_count > 0: + mocked_requests_get.error_count -= 1 + return MockResponse({"reason": "test_reason", "json_data" : {"error":"recoverable", "message":"nah"}}, 500) + else: # return the number of errors if set above 0 + mocked_requests_get.error_count = -1 + return MockResponse({"json_data" : {"access_token": "test_token", "expires_in": 4000}}, 200) + elif kwargs['url'] == 'https://api.segment.io/v1/batch': + return MockResponse({}, 200) + print("Unhandled mock URL") + return MockResponse({'text':'Unhandled mock URL error'}, 404) +mocked_requests_get.error_count = -1 + +class TestOauthManager(unittest.TestCase): + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_success(self, mock_post): + manager = segment.analytics.oauth_manager.OauthManager("id", privatekey, "keyid", "http://127.0.0.1:80") + self.assertEqual(manager.get_token(), "test_token") + self.assertEqual(manager.max_retries, 3) + self.assertEqual(manager.scope, "tracking_api:write") + self.assertEqual(manager.auth_server, "http://127.0.0.1:80") + self.assertEqual(manager.timeout, 15) + self.assertTrue(manager.thread.is_alive) + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_fail_unrecoverably(self, mock_post): + manager = 
segment.analytics.oauth_manager.OauthManager("id", privatekey, "keyid", "http://127.0.0.1:400") + with self.assertRaises(Exception) as context: + manager.get_token() + self.assertTrue(manager.thread.is_alive) + self.assertEqual(mock_post.call_count, 1) + manager.thread.cancel() + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_fail_with_retries(self, mock_post): + manager = segment.analytics.oauth_manager.OauthManager("id", privatekey, "keyid", "http://127.0.0.1:500") + with self.assertRaises(Exception) as context: + manager.get_token() + self.assertTrue(manager.thread.is_alive) + self.assertEqual(mock_post.call_count, 3) + manager.thread.cancel() + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + @mock.patch('time.sleep', spec=time.sleep) # 429 uses sleep so it won't be interrupted + def test_oauth_rate_limit_delay(self, mock_sleep, mock_post): + manager = segment.analytics.oauth_manager.OauthManager("id", privatekey, "keyid", "http://127.0.0.1:429") + manager._poller_loop() + self.assertTrue(mock_sleep.call_args[0][0] > 1.9 and mock_sleep.call_args[0][0] <= 2.0) + +class TestOauthIntegration(unittest.TestCase): + def fail(self, e, batch=[]): + self.failed = True + + def setUp(self): + self.failed = False + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_integration_success(self, mock_post): + client = Client("write_key", on_error=self.fail, oauth_auth_server="http://127.0.0.1:80", + oauth_client_id="id",oauth_client_key=privatekey, oauth_key_id="keyid") + client.track("user", "event") + client.flush() + self.assertFalse(self.failed) + self.assertEqual(mock_post.call_count, 2) + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_integration_failure(self, mock_post): + client = Client("write_key", on_error=self.fail, oauth_auth_server="http://127.0.0.1:400", + oauth_client_id="id",oauth_client_key=privatekey, oauth_key_id="keyid") + client.track("user", "event") + client.flush() + self.assertTrue(self.failed) + self.assertEqual(mock_post.call_count, 1) + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_integration_recovery(self, mock_post): + mocked_requests_get.error_count = 2 # 2 errors and then success + client = Client("write_key", on_error=self.fail, oauth_auth_server="http://127.0.0.1:501", + oauth_client_id="id",oauth_client_key=privatekey, oauth_key_id="keyid") + client.track("user", "event") + client.flush() + self.assertFalse(self.failed) + self.assertEqual(mock_post.call_count, 4) + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_integration_fail_bad_key(self, mock_post): + client = Client("write_key", on_error=self.fail, oauth_auth_server="http://127.0.0.1:80", + oauth_client_id="id",oauth_client_key="badkey", oauth_key_id="keyid") + client.track("user", "event") + client.flush() + self.assertTrue(self.failed) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/segment/analytics/test/request.py b/segment/analytics/test/request.py index 3420deca..3f40c497 100644 --- a/segment/analytics/test/request.py +++ b/segment/analytics/test/request.py @@ -3,7 +3,7 @@ import json import requests -from analytics.request import post, DatetimeSerializer +from segment.analytics.request import post, DatetimeSerializer class TestRequests(unittest.TestCase): diff --git 
a/segment/analytics/test/utils.py b/segment/analytics/test/utils.py index 6995e799..43a6fd4b 100644 --- a/segment/analytics/test/utils.py +++ b/segment/analytics/test/utils.py @@ -4,7 +4,7 @@ from dateutil.tz import tzutc -from analytics import utils +from segment.analytics import utils class TestUtils(unittest.TestCase): diff --git a/setup.py b/setup.py index fbc1d066..0f3da9ed 100644 --- a/setup.py +++ b/setup.py @@ -40,8 +40,8 @@ author_email='friends@segment.com', maintainer='Segment', maintainer_email='friends@segment.com', - test_suite='analytics.test.all', - packages=['segment.analytics', 'analytics.test'], + test_suite='segment.analytics.test.all', + packages=['segment.analytics', 'segment.analytics.test'], python_requires='>=3.6.0', license='MIT License', install_requires=install_requires,
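
Usage sketch (not part of the diff above): a minimal example of how the new OAuth settings are wired up, mirroring samples/oauth.py from this patch. The client id, key id, and PEM path below are placeholders for values taken from the Segment dashboard; as the sample's own comment notes, the private key should be read from a file on disk so it can be rotated.

    import segment.analytics as analytics

    # Credentials created for the OAuth application in the Segment dashboard
    # (all values below are placeholders).
    analytics.write_key = 'WRITE_KEY'
    analytics.oauth_client_id = 'CLIENT_ID'
    analytics.oauth_key_id = 'KEY_ID'            # shown after uploading the public key
    analytics.oauth_scope = 'tracking_api:write'

    # PEM-encoded RS256 private key; load it from disk so it can be rotated
    # without a code change ('oauth_key.pem' is a hypothetical path).
    with open('oauth_key.pem') as f:
        analytics.oauth_client_key = f.read()

    analytics.track('some-user-id', 'OAuth Test Event')
    analytics.flush()

When all three of oauth_client_id, oauth_client_key, and oauth_key_id are set, the Client constructs an OauthManager and each batch is posted with a Bearer token; with or without OAuth, the write key now travels in the request body as writeKey rather than via HTTP basic auth.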