diff --git a/samples/oauth.py b/samples/oauth.py index f67665e..9440737 100644 --- a/samples/oauth.py +++ b/samples/oauth.py @@ -1,7 +1,7 @@ import sys import os -import time sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) +import time import segment.analytics as analytics privatekey = '''-----BEGIN PRIVATE KEY----- @@ -48,5 +48,6 @@ def on_error(error, items): analytics.on_error = on_error analytics.track('AUser', 'track') +analytics.flush() time.sleep(3) \ No newline at end of file diff --git a/segment/analytics/consumer.py b/segment/analytics/consumer.py index c20b1a1..a78f2d3 100644 --- a/segment/analytics/consumer.py +++ b/segment/analytics/consumer.py @@ -14,6 +14,13 @@ # lower to leave space for extra data that will be added later, eg. "sentAt". BATCH_SIZE_LIMIT = 475000 +class FatalError(Exception): + def __init__(self, message): + self.message = message + def __str__(self): + msg = "[Segment] {0}" + return msg.format(self.message) + class Consumer(Thread): """Consumes the messages from the client's queue.""" @@ -119,6 +126,8 @@ def fatal_exception(exc): # with 429 status code (rate limited), # don't retry on other client errors return (400 <= exc.status < 500) and exc.status != 429 + elif isinstance(exc, FatalError): + return True else: # retry on all other errors (eg. 
network) return False diff --git a/segment/analytics/oauth_manager.py b/segment/analytics/oauth_manager.py index ab10350..453a23a 100644 --- a/segment/analytics/oauth_manager.py +++ b/segment/analytics/oauth_manager.py @@ -7,6 +7,8 @@ import jwt from segment.analytics import utils +from segment.analytics.request import APIError +from segment.analytics.consumer import FatalError _session = sessions.Session() @@ -17,7 +19,7 @@ def __init__(self, key_id, auth_server='https://oauth2.segment.io', scope='tracking_api:write', - timeout=5, + timeout=15, max_retries=3): self.client_id = client_id self.client_key = client_key @@ -27,6 +29,7 @@ def __init__(self, self.timeout = timeout self.max_retries = max_retries self.retry_count = 0 + self.clock_skew = 0 self.log = logging.getLogger('segment') self.thread = None @@ -40,7 +43,7 @@ def get_token(self): return self.token # No good token, start the loop self.log.debug("OAuth is enabled. No cached access token.") - # Make sure we're not waiting an excessively long time + # Make sure we're not waiting an excessively long time (this will not cancel 429 waits) if self.thread and self.thread.is_alive(): self.thread.cancel() self.thread = threading.Timer(0,self._poller_loop) @@ -55,10 +58,10 @@ def get_token(self): if self.error: error = self.error self.error = None - raise Exception(error) + raise error if self.thread: # Wait for a cycle, may not have an answer immediately - self.thread.join() + self.thread.join(1) def clear_token(self): self.log.debug("OAuth Token invalidated. 
Poller for new token is {}".format( @@ -71,8 +74,8 @@ def _request_token(self): "iss": self.client_id, "sub": self.client_id, "aud": utils.remove_trailing_slash(self.auth_server), - "iat": int(time.time())-1, - "exp": int(time.time()) + 59, + "iat": int(time.time())-5 - self.clock_skew, + "exp": int(time.time()) + 55 - self.clock_skew, "jti": str(uuid.uuid4()) } @@ -104,15 +107,21 @@ def _poller_loop(self): except Exception as e: self.retry_count += 1 if self.retry_count < self.max_retries: - self.log.error("OAuth token request encountered an error on attempt {}: {}".format(self.retry_count ,e)) + self.log.debug("OAuth token request encountered an error on attempt {}: {}".format(self.retry_count ,e)) self.thread = threading.Timer(refresh_timer_ms / 1000.0, self._poller_loop) self.thread.daemon = True self.thread.start() return # Too many retries, giving up self.log.error("OAuth token request encountered an error after {} attempts: {}".format(self.retry_count ,e)) - self.error = e + self.error = FatalError(str(e)) return + if response.headers.get("Date"): + try: + server_time = datetime.strptime(response.headers.get("Date"), "%a, %d %b %Y %H:%M:%S %Z") + self.clock_skew = int((datetime.utcnow() - server_time).total_seconds()) + except Exception as e: + self.log.error("OAuth token request received a response with an invalid Date header: {} | {}".format(response, e)) if response.status_code == 200: data = None @@ -157,12 +166,26 @@ def _poller_loop(self): # We want subsequent calls to get_token to be able to interrupt our # Timeout when it's waiting for e.g. a long normal expiration, but # not when we're waiting for a rate limit reset. Sleep instead. - time.sleep(refresh_timer_ms * 1000) + time.sleep(refresh_timer_ms / 1000.0) refresh_timer_ms = 0 elif response.status_code in [400, 401, 415]: - # unrecoverable errors + # unrecoverable errors (except for skew). 
APIError will be handled by request.py self.retry_count = 0 - self.error = Exception(f'[{response.status_code}] {response.reason}') + try: + payload = response.json() + + if (payload.get('error') == 'invalid_request' and + (payload.get('error_description') == 'Token is expired' or + payload.get('error_description') == 'Token used before issued')): + refresh_timer_ms = 0 # Retry immediately hopefully with a good skew value + self.thread = threading.Timer(refresh_timer_ms / 1000.0, self._poller_loop) + self.thread.daemon = True + self.thread.start() + return + + self.error = APIError(response.status_code, payload['error'], payload['error_description']) + except (ValueError, KeyError): + self.error = APIError(response.status_code, 'unknown', response.text) self.log.error("OAuth token request error was unrecoverable, possibly due to configuration: {}".format(self.error)) return else: @@ -172,7 +195,11 @@ def _poller_loop(self): if self.retry_count > 0 and self.retry_count % self.max_retries == 0: # every time we pass the retry count, put up an error to release any waiting token requests - self.error = Exception(f'[{response.status_code}] {response.reason}') + try: + payload = response.json() + self.error = APIError(response.status_code, payload['error'], payload['error_description']) + except (ValueError, KeyError): + self.error = APIError(response.status_code, 'unknown', response.text) self.log.error("OAuth token request error after {} attempts: {}".format(self.retry_count, self.error)) # loop diff --git a/segment/analytics/request.py b/segment/analytics/request.py index e66ced5..273247c 100644 --- a/segment/analytics/request.py +++ b/segment/analytics/request.py @@ -17,15 +17,13 @@ def post(write_key, host=None, gzip=False, timeout=15, proxies=None, oauth_manag """Post the `kwargs` to the API""" log = logging.getLogger('segment') body = kwargs - body["sentAt"] = datetime.utcnow().replace(tzinfo=tzutc()).isoformat() + if "sentAt" not in body: + body["sentAt"] = 
datetime.utcnow().replace(tzinfo=tzutc()).isoformat() body["writeKey"] = write_key url = remove_trailing_slash(host or 'https://api.segment.io') + '/v1/batch' auth = None if oauth_manager: - try: - auth = oauth_manager.get_token() - except Exception as e: - raise e + auth = oauth_manager.get_token() data = json.dumps(body, cls=DatetimeSerializer) log.debug('making request: %s', data) headers = { diff --git a/segment/analytics/test/oauth.py b/segment/analytics/test/oauth.py new file mode 100644 index 0000000..259342b --- /dev/null +++ b/segment/analytics/test/oauth.py @@ -0,0 +1,155 @@ +from datetime import datetime +import threading +import time +import unittest +import mock +import sys +import os +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../..")) +from segment.analytics.client import Client +import segment.analytics.oauth_manager +import requests + +privatekey = '''-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDVll7uJaH322IN +PQsH2aOXZJ2r1q+6hpVK1R5JV1p41PUzn8pOxyXFHWB+53dUd4B8qywKS36XQjp0 +VmhR1tQ22znQ9ZCM6y4LGeOJBjAZiFZLcGQNNrDFC0WGWTrK1ZTS2K7p5qy4fIXG +laNkMXiGGCawkgcHAdOvPTy8m1d9a6YSetYVmBP/tEYN95jPyZFIoHQfkQPBPr9W +cWPpdEBzasHV5d957akjurPpleDiD5as66UW4dkWXvS7Wu7teCLCyDApcyJKTb2Z +SXybmWjhIZuctZMAx3wT/GgW3FbkGaW5KLQgBUMzjpL0fCtMatlqckMD92ll1FuK +R+HnXu05AgMBAAECggEBAK4o2il4GDUh9zbyQo9ZIPLuwT6AZXRED3Igi3ykNQp4 +I6S/s9g+vQaY6LkyBnSiqOt/K/8NBiFSiJWaa5/n+8zrP56qzf6KOlYk+wsdN5Vq +PWtwLrUzljpl8YAWPEFunNa8hwwE42vfZbnDBKNLT4qQIOQzfnVxQOoQlfj49gM2 +iSrblvsnQTyucFy3UyTeioHbh8q2Xqcxry5WUCOrFDd3IIwATTpLZGw0IPeuFJbJ +NfBizLEcyJaM9hujQU8PRCRd16MWX+bbYM6Mh4dkT40QXWsVnHBHwsgPdQgDgseF +Na4ajtHoC0DlwYCXpCm3IzJfKfq/LR2q8NDUgKeF4AECgYEA9nD4czza3SRbzhpZ +bBoK77CSNqCcMAqyuHB0hp/XX3yB7flF9PIPb2ReO8wwmjbxn+bm7PPz2Uwd2SzO +pU+FXmyKJr53Jxw/hmDWZCoh42gsGDlVqpmytzsj74KlaYiMyZmEGbD7t/FGfNGV +LdLDJaHIYxEimFviOTXKCeKvPAECgYEA3d8tv4jdp1uAuRZiU9Z/tfw5mJOi3oXF +8AdFFDwaPzcTorEAxjrt9X6IjPbLIDJNJtuXYpe+dG6720KyuNnhLhWW9oZEJTwT 
+dUgqZ2fTCOS9uH0jSn+ZFlgTWI6UDQXRwE7z8avlhMIrQVmPsttGTo7V6sQVtGRx +bNj2RSVekTkCgYAJvy4UYLPHS0jWPfSLcfw8vp8JyhBjVgj7gncZW/kIrcP1xYYe +yfQSU8XmV40UjFfCGz/G318lmP0VOdByeVKtCV3talsMEPHyPqI8E+6DL/uOebYJ +qUqINK6XKnOgWOY4kvnGillqTQCcry1XQp61PlDOmj7kB75KxPXYrj6AAQKBgQDa ++ixCv6hURuEyy77cE/YT/Q4zYnL6wHjtP5+UKwWUop1EkwG6o+q7wtiul90+t6ah +1VUCP9X/QFM0Qg32l0PBohlO0pFrVnG17TW8vSHxwyDkds1f97N19BOT8ZR5jebI +sKPfP9LVRnY+l1BWLEilvB+xBzqMwh2YWkIlWI6PMQKBgGi6TBnxp81lOYrxVRDj +/3ycRnVDmBdlQKFunvfzUBmG1mG/G0YHeVSUKZJGX7w2l+jnDwIA383FcUeA8X6A +l9q+amhtkwD/6fbkAu/xoWNl+11IFoxd88y2ByBFoEKB6UVLuCTSKwXDqzEZet7x +mDyRxq7ohIzLkw8b8buDeuXZ +-----END PRIVATE KEY-----''' + +def mocked_requests_get(*args, **kwargs): + class MockResponse: + def __init__(self, data, status_code): + self.__dict__['headers'] = {'date': datetime.now().strftime("%a, %d %b %Y %H:%M:%S GMT")} + self.__dict__.update(data) + self.status_code = status_code + + def json(self): + return self.json_data + if 'url' not in kwargs: + kwargs['url'] = args[0] + if kwargs['url'] == 'http://127.0.0.1:80/token': + return MockResponse({"json_data" : {"access_token": "test_token", "expires_in": 4000}}, 200) + elif kwargs['url'] == 'http://127.0.0.1:400/token': + return MockResponse({"reason": "test_reason", "json_data" : {"error":"unrecoverable", "error_description":"nah"}}, 400) + elif kwargs['url'] == 'http://127.0.0.1:429/token': + return MockResponse({"reason": "test_reason", "headers" : {"X-RateLimit-Reset": time.time()*1000 + 2000}}, 429) + elif kwargs['url'] == 'http://127.0.0.1:500/token': + return MockResponse({"reason": "test_reason", "json_data" : {"error":"recoverable", "error_description":"nah"}}, 500) + elif kwargs['url'] == 'http://127.0.0.1:501/token': + if mocked_requests_get.error_count < 0 or mocked_requests_get.error_count > 0: + if mocked_requests_get.error_count > 0: + mocked_requests_get.error_count -= 1 + return MockResponse({"reason": "test_reason", "json_data" : {"error":"recoverable", "message":"nah"}}, 500) + 
else: # return the number of errors if set above 0 + mocked_requests_get.error_count = -1 + return MockResponse({"json_data" : {"access_token": "test_token", "expires_in": 4000}}, 200) + elif kwargs['url'] == 'https://api.segment.io/v1/batch': + return MockResponse({}, 200) + print("Unhandled mock URL") + return MockResponse({'text':'Unhandled mock URL error'}, 404) +mocked_requests_get.error_count = -1 + +class TestOauthManager(unittest.TestCase): + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_success(self, mock_post): + manager = segment.analytics.oauth_manager.OauthManager("id", privatekey, "keyid", "http://127.0.0.1:80") + self.assertEqual(manager.get_token(), "test_token") + self.assertEqual(manager.max_retries, 3) + self.assertEqual(manager.scope, "tracking_api:write") + self.assertEqual(manager.auth_server, "http://127.0.0.1:80") + self.assertEqual(manager.timeout, 15) + self.assertTrue(manager.thread.is_alive) + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_fail_unrecoverably(self, mock_post): + manager = segment.analytics.oauth_manager.OauthManager("id", privatekey, "keyid", "http://127.0.0.1:400") + with self.assertRaises(Exception) as context: + manager.get_token() + self.assertTrue(manager.thread.is_alive) + self.assertEqual(mock_post.call_count, 1) + manager.thread.cancel() + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_fail_with_retries(self, mock_post): + manager = segment.analytics.oauth_manager.OauthManager("id", privatekey, "keyid", "http://127.0.0.1:500") + with self.assertRaises(Exception) as context: + manager.get_token() + self.assertTrue(manager.thread.is_alive) + self.assertEqual(mock_post.call_count, 3) + manager.thread.cancel() + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + @mock.patch('time.sleep', spec=time.sleep) # 429 uses sleep so it won't be 
interrupted + def test_oauth_rate_limit_delay(self, mock_sleep, mock_post): + manager = segment.analytics.oauth_manager.OauthManager("id", privatekey, "keyid", "http://127.0.0.1:429") + manager._poller_loop() + self.assertTrue(mock_sleep.call_args[0][0] > 1.9 and mock_sleep.call_args[0][0] <= 2.0) + +class TestOauthIntegration(unittest.TestCase): + def fail(self, e, batch=[]): + self.failed = True + + def setUp(self): + self.failed = False + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_integration_success(self, mock_post): + client = Client("write_key", on_error=self.fail, oauth_auth_server="http://127.0.0.1:80", + oauth_client_id="id",oauth_client_key=privatekey, oauth_key_id="keyid") + client.track("user", "event") + client.flush() + self.assertFalse(self.failed) + self.assertEqual(mock_post.call_count, 2) + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_integration_failure(self, mock_post): + client = Client("write_key", on_error=self.fail, oauth_auth_server="http://127.0.0.1:400", + oauth_client_id="id",oauth_client_key=privatekey, oauth_key_id="keyid") + client.track("user", "event") + client.flush() + self.assertTrue(self.failed) + self.assertEqual(mock_post.call_count, 1) + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_integration_recovery(self, mock_post): + mocked_requests_get.error_count = 2 # 2 errors and then success + client = Client("write_key", on_error=self.fail, oauth_auth_server="http://127.0.0.1:501", + oauth_client_id="id",oauth_client_key=privatekey, oauth_key_id="keyid") + client.track("user", "event") + client.flush() + self.assertFalse(self.failed) + self.assertEqual(mock_post.call_count, 4) + + @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get) + def test_oauth_integration_fail_bad_key(self, mock_post): + client = Client("write_key", on_error=self.fail, 
oauth_auth_server="http://127.0.0.1:80", + oauth_client_id="id",oauth_client_key="badkey", oauth_key_id="keyid") + client.track("user", "event") + client.flush() + self.assertTrue(self.failed) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file