Skip to content

Commit

Permalink
Various fixes, tests, and clock skew adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelGHSeg committed Oct 6, 2023
1 parent 1b231e0 commit 2f5945c
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 18 deletions.
3 changes: 2 additions & 1 deletion samples/oauth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import os
import time
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
import time
import segment.analytics as analytics

privatekey = '''-----BEGIN PRIVATE KEY-----
Expand Down Expand Up @@ -48,5 +48,6 @@ def on_error(error, items):
analytics.on_error = on_error

analytics.track('AUser', 'track')
analytics.flush()

time.sleep(3)
9 changes: 9 additions & 0 deletions segment/analytics/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
# lower to leave space for extra data that will be added later, eg. "sentAt".
BATCH_SIZE_LIMIT = 475000

class FatalError(Exception):
    """Error that the consumer must not retry.

    ``fatal_exception`` in this module returns True for instances of this
    class, so a request that raises it is given up on immediately.

    Args:
        message: Human-readable description of the failure; also exposed
            as the ``message`` attribute.
    """

    def __init__(self, message):
        # Forward to Exception so ``args``, ``repr()`` and pickling carry
        # the message instead of an empty tuple.
        super(FatalError, self).__init__(message)
        self.message = message

    def __str__(self):
        # The previous format string ended in a stray ")" that leaked into
        # every rendered message.
        return "[Segment] {0}".format(self.message)


class Consumer(Thread):
"""Consumes the messages from the client's queue."""
Expand Down Expand Up @@ -119,6 +126,8 @@ def fatal_exception(exc):
# with 429 status code (rate limited),
# don't retry on other client errors
return (400 <= exc.status < 500) and exc.status != 429
elif isinstance(exc, FatalError):
return True
else:
# retry on all other errors (eg. network)
return False
Expand Down
51 changes: 39 additions & 12 deletions segment/analytics/oauth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import jwt

from segment.analytics import utils
from segment.analytics.request import APIError
from segment.analytics.consumer import FatalError

_session = sessions.Session()

Expand All @@ -17,7 +19,7 @@ def __init__(self,
key_id,
auth_server='https://oauth2.segment.io',
scope='tracking_api:write',
timeout=5,
timeout=15,
max_retries=3):
self.client_id = client_id
self.client_key = client_key
Expand All @@ -27,6 +29,7 @@ def __init__(self,
self.timeout = timeout
self.max_retries = max_retries
self.retry_count = 0
self.clock_skew = 0

self.log = logging.getLogger('segment')
self.thread = None
Expand All @@ -40,7 +43,7 @@ def get_token(self):
return self.token
# No good token, start the loop
self.log.debug("OAuth is enabled. No cached access token.")
# Make sure we're not waiting an excessively long time
# Make sure we're not waiting an excessively long time (this will not cancel 429 waits)
if self.thread and self.thread.is_alive():
self.thread.cancel()
self.thread = threading.Timer(0,self._poller_loop)
Expand All @@ -55,10 +58,10 @@ def get_token(self):
if self.error:
error = self.error
self.error = None
raise Exception(error)
raise error
if self.thread:
# Wait for a cycle, may not have an answer immediately
self.thread.join()
self.thread.join(1)

def clear_token(self):
self.log.debug("OAuth Token invalidated. Poller for new token is {}".format(
Expand All @@ -71,8 +74,8 @@ def _request_token(self):
"iss": self.client_id,
"sub": self.client_id,
"aud": utils.remove_trailing_slash(self.auth_server),
"iat": int(time.time())-1,
"exp": int(time.time()) + 59,
"iat": int(time.time())-5 - self.clock_skew,
"exp": int(time.time()) + 55 - self.clock_skew,
"jti": str(uuid.uuid4())
}

Expand Down Expand Up @@ -104,15 +107,21 @@ def _poller_loop(self):
except Exception as e:
self.retry_count += 1
if self.retry_count < self.max_retries:
self.log.error("OAuth token request encountered an error on attempt {}: {}".format(self.retry_count ,e))
self.log.debug("OAuth token request encountered an error on attempt {}: {}".format(self.retry_count ,e))
self.thread = threading.Timer(refresh_timer_ms / 1000.0, self._poller_loop)
self.thread.daemon = True
self.thread.start()
return
# Too many retries, giving up
self.log.error("OAuth token request encountered an error after {} attempts: {}".format(self.retry_count ,e))
self.error = e
self.error = FatalError(str(e))
return
if response.headers.get("Date"):
try:
server_time = datetime.strptime(response.headers.get("Date"), "%a, %d %b %Y %H:%M:%S %Z")
self.clock_skew = int((datetime.utcnow() - server_time).total_seconds())
except Exception as e:
self.log.error("OAuth token request received a response with an invalid Date header: {} | {}".format(response, e))

if response.status_code == 200:
data = None
Expand Down Expand Up @@ -157,12 +166,26 @@ def _poller_loop(self):
# We want subsequent calls to get_token to be able to interrupt our
# Timeout when it's waiting for e.g. a long normal expiration, but
# not when we're waiting for a rate limit reset. Sleep instead.
time.sleep(refresh_timer_ms * 1000)
time.sleep(refresh_timer_ms / 1000.0)
refresh_timer_ms = 0
elif response.status_code in [400, 401, 415]:
# unrecoverable errors
# unrecoverable errors (except for skew). APIError will be handled by request.py
self.retry_count = 0
self.error = Exception(f'[{response.status_code}] {response.reason}')
try:
payload = response.json()

if (payload.get('error') == 'invalid_request' and
(payload.get('error_description') == 'Token is expired' or
payload.get('error_description') == 'Token used before issued')):
refresh_timer_ms = 0 # Retry immediately hopefully with a good skew value
self.thread = threading.Timer(refresh_timer_ms / 1000.0, self._poller_loop)
self.thread.daemon = True
self.thread.start()
return

self.error = APIError(response.status_code, payload['error'], payload['error_description'])
except ValueError:
self.error = APIError(response.status_code, 'unknown', response.text)
self.log.error("OAuth token request error was unrecoverable, possibly due to configuration: {}".format(self.error))
return
else:
Expand All @@ -172,7 +195,11 @@ def _poller_loop(self):

if self.retry_count > 0 and self.retry_count % self.max_retries == 0:
# every time we pass the retry count, put up an error to release any waiting token requests
self.error = Exception(f'[{response.status_code}] {response.reason}')
try:
payload = response.json()
self.error = APIError(response.status_code, payload['error'], payload['error_description'])
except ValueError:
self.error = APIError(response.status_code, 'unknown', response.text)
self.log.error("OAuth token request error after {} attempts: {}".format(self.retry_count, self.error))

# loop
Expand Down
8 changes: 3 additions & 5 deletions segment/analytics/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@ def post(write_key, host=None, gzip=False, timeout=15, proxies=None, oauth_manag
"""Post the `kwargs` to the API"""
log = logging.getLogger('segment')
body = kwargs
body["sentAt"] = datetime.utcnow().replace(tzinfo=tzutc()).isoformat()
if not "sentAt" in body.keys():
body["sentAt"] = datetime.utcnow().replace(tzinfo=tzutc()).isoformat()
body["writeKey"] = write_key
url = remove_trailing_slash(host or 'https://api.segment.io') + '/v1/batch'
auth = None
if oauth_manager:
try:
auth = oauth_manager.get_token()
except Exception as e:
raise e
auth = oauth_manager.get_token()
data = json.dumps(body, cls=DatetimeSerializer)
log.debug('making request: %s', data)
headers = {
Expand Down
155 changes: 155 additions & 0 deletions segment/analytics/test/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from datetime import datetime
import threading
import time
import unittest
import mock
import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../.."))
from segment.analytics.client import Client
import segment.analytics.oauth_manager
import requests

# RSA private key in PEM form, embedded so the tests need no key file.
# It is passed as the OAuth client key to OauthManager and Client below.
# NOTE(review): presumably a throwaway key generated only for these tests --
# confirm it is not used by any real Segment workspace.
privatekey = '''-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDVll7uJaH322IN
PQsH2aOXZJ2r1q+6hpVK1R5JV1p41PUzn8pOxyXFHWB+53dUd4B8qywKS36XQjp0
VmhR1tQ22znQ9ZCM6y4LGeOJBjAZiFZLcGQNNrDFC0WGWTrK1ZTS2K7p5qy4fIXG
laNkMXiGGCawkgcHAdOvPTy8m1d9a6YSetYVmBP/tEYN95jPyZFIoHQfkQPBPr9W
cWPpdEBzasHV5d957akjurPpleDiD5as66UW4dkWXvS7Wu7teCLCyDApcyJKTb2Z
SXybmWjhIZuctZMAx3wT/GgW3FbkGaW5KLQgBUMzjpL0fCtMatlqckMD92ll1FuK
R+HnXu05AgMBAAECggEBAK4o2il4GDUh9zbyQo9ZIPLuwT6AZXRED3Igi3ykNQp4
I6S/s9g+vQaY6LkyBnSiqOt/K/8NBiFSiJWaa5/n+8zrP56qzf6KOlYk+wsdN5Vq
PWtwLrUzljpl8YAWPEFunNa8hwwE42vfZbnDBKNLT4qQIOQzfnVxQOoQlfj49gM2
iSrblvsnQTyucFy3UyTeioHbh8q2Xqcxry5WUCOrFDd3IIwATTpLZGw0IPeuFJbJ
NfBizLEcyJaM9hujQU8PRCRd16MWX+bbYM6Mh4dkT40QXWsVnHBHwsgPdQgDgseF
Na4ajtHoC0DlwYCXpCm3IzJfKfq/LR2q8NDUgKeF4AECgYEA9nD4czza3SRbzhpZ
bBoK77CSNqCcMAqyuHB0hp/XX3yB7flF9PIPb2ReO8wwmjbxn+bm7PPz2Uwd2SzO
pU+FXmyKJr53Jxw/hmDWZCoh42gsGDlVqpmytzsj74KlaYiMyZmEGbD7t/FGfNGV
LdLDJaHIYxEimFviOTXKCeKvPAECgYEA3d8tv4jdp1uAuRZiU9Z/tfw5mJOi3oXF
8AdFFDwaPzcTorEAxjrt9X6IjPbLIDJNJtuXYpe+dG6720KyuNnhLhWW9oZEJTwT
dUgqZ2fTCOS9uH0jSn+ZFlgTWI6UDQXRwE7z8avlhMIrQVmPsttGTo7V6sQVtGRx
bNj2RSVekTkCgYAJvy4UYLPHS0jWPfSLcfw8vp8JyhBjVgj7gncZW/kIrcP1xYYe
yfQSU8XmV40UjFfCGz/G318lmP0VOdByeVKtCV3talsMEPHyPqI8E+6DL/uOebYJ
qUqINK6XKnOgWOY4kvnGillqTQCcry1XQp61PlDOmj7kB75KxPXYrj6AAQKBgQDa
+ixCv6hURuEyy77cE/YT/Q4zYnL6wHjtP5+UKwWUop1EkwG6o+q7wtiul90+t6ah
1VUCP9X/QFM0Qg32l0PBohlO0pFrVnG17TW8vSHxwyDkds1f97N19BOT8ZR5jebI
sKPfP9LVRnY+l1BWLEilvB+xBzqMwh2YWkIlWI6PMQKBgGi6TBnxp81lOYrxVRDj
/3ycRnVDmBdlQKFunvfzUBmG1mG/G0YHeVSUKZJGX7w2l+jnDwIA383FcUeA8X6A
l9q+amhtkwD/6fbkAu/xoWNl+11IFoxd88y2ByBFoEKB6UVLuCTSKwXDqzEZet7x
mDyRxq7ohIzLkw8b8buDeuXZ
-----END PRIVATE KEY-----'''

def mocked_requests_get(*args, **kwargs):
    """Stand-in for ``requests.Session.post`` that answers based on the URL.

    The URL is taken from the ``url`` keyword, or from the first positional
    argument when the keyword is absent. The port in the URL selects the
    canned response:

    * ``:80/token``  -- 200 with an access token.
    * ``:400/token`` -- 400 with an error payload.
    * ``:429/token`` -- 429 whose only header is ``X-RateLimit-Reset``
      (about two seconds from now, in milliseconds).
    * ``:500/token`` -- 500 with an error payload.
    * ``:501/token`` -- controlled by ``mocked_requests_get.error_count``:
      a positive count is decremented and yields a 500; a count of exactly
      zero yields one 200 and resets the count to -1; a negative count
      (the default) always yields a 500.
    * ``https://api.segment.io/v1/batch`` -- empty 200.

    Any other URL prints a notice and returns a 404.

    Returns:
        A ``MockResponse`` exposing ``headers``, ``status_code``, ``json()``
        and whatever extra attributes the canned data supplies.
    """
    class MockResponse:
        """Minimal response double built from a dict of attributes."""

        def __init__(self, data, status_code):
            # The code under test looks the header up as "Date" in what is a
            # plain, case-sensitive dict here, so the key must use that exact
            # spelling. The value is labelled GMT, so it has to be rendered
            # from UTC rather than from local time.
            self.headers = {
                'Date': time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()),
            }
            # Canned data may replace ``headers`` wholesale (the 429 case).
            self.__dict__.update(data)
            self.status_code = status_code

        def json(self):
            try:
                return self.json_data
            except AttributeError:
                # A body-less response should fail with ValueError, which is
                # what the callers of ``response.json()`` catch.
                raise ValueError("mock response has no JSON body")

    url = kwargs['url'] if 'url' in kwargs else args[0]

    if url == 'http://127.0.0.1:80/token':
        return MockResponse({"json_data": {"access_token": "test_token", "expires_in": 4000}}, 200)
    if url == 'http://127.0.0.1:400/token':
        return MockResponse({"reason": "test_reason",
                             "json_data": {"error": "unrecoverable", "error_description": "nah"}}, 400)
    if url == 'http://127.0.0.1:429/token':
        return MockResponse({"reason": "test_reason",
                             "headers": {"X-RateLimit-Reset": time.time() * 1000 + 2000}}, 429)
    if url == 'http://127.0.0.1:500/token':
        return MockResponse({"reason": "test_reason",
                             "json_data": {"error": "recoverable", "error_description": "nah"}}, 500)
    if url == 'http://127.0.0.1:501/token':
        if mocked_requests_get.error_count != 0:
            if mocked_requests_get.error_count > 0:
                mocked_requests_get.error_count -= 1
            # Same payload keys as the other error responses, so error
            # handling that reads 'error_description' keeps working.
            return MockResponse({"reason": "test_reason",
                                 "json_data": {"error": "recoverable", "error_description": "nah"}}, 500)
        # The requested number of failures has been served: succeed once and
        # fall back to the default always-failing state.
        mocked_requests_get.error_count = -1
        return MockResponse({"json_data": {"access_token": "test_token", "expires_in": 4000}}, 200)
    if url == 'https://api.segment.io/v1/batch':
        return MockResponse({}, 200)

    print("Unhandled mock URL")
    return MockResponse({'text': 'Unhandled mock URL error'}, 404)


# Number of 500s the :501 endpoint still has to serve before one success;
# negative means "fail forever".
mocked_requests_get.error_count = -1

class TestOauthManager(unittest.TestCase):
    """Tests that drive OauthManager directly, with Session.post mocked."""

    @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get)
    def test_oauth_success(self, mock_post):
        """A 200 token response yields the token and leaves the defaults intact."""
        oauth = segment.analytics.oauth_manager.OauthManager(
            "id", privatekey, "keyid", "http://127.0.0.1:80")
        self.assertEqual(oauth.get_token(), "test_token")
        self.assertEqual(oauth.max_retries, 3)
        self.assertEqual(oauth.scope, "tracking_api:write")
        self.assertEqual(oauth.auth_server, "http://127.0.0.1:80")
        self.assertEqual(oauth.timeout, 15)
        # NOTE(review): ``is_alive`` is not called here, so this asserts the
        # truthiness of the bound method and can never fail.
        self.assertTrue(oauth.thread.is_alive)

    @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get)
    def test_oauth_fail_unrecoverably(self, mock_post):
        """A 400 from the token endpoint raises after a single request."""
        oauth = segment.analytics.oauth_manager.OauthManager(
            "id", privatekey, "keyid", "http://127.0.0.1:400")
        with self.assertRaises(Exception):
            oauth.get_token()
        # NOTE(review): bound method, not a call -- always truthy.
        self.assertTrue(oauth.thread.is_alive)
        self.assertEqual(mock_post.call_count, 1)
        oauth.thread.cancel()

    @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get)
    def test_oauth_fail_with_retries(self, mock_post):
        """Repeated 500s raise once three requests have been made."""
        oauth = segment.analytics.oauth_manager.OauthManager(
            "id", privatekey, "keyid", "http://127.0.0.1:500")
        with self.assertRaises(Exception):
            oauth.get_token()
        # NOTE(review): bound method, not a call -- always truthy.
        self.assertTrue(oauth.thread.is_alive)
        self.assertEqual(mock_post.call_count, 3)
        oauth.thread.cancel()

    @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get)
    @mock.patch('time.sleep', spec=time.sleep)  # 429 uses sleep so it won't be interrupted
    def test_oauth_rate_limit_delay(self, mock_sleep, mock_post):
        """A 429 makes the poller sleep for the advertised reset delay (~2 s)."""
        oauth = segment.analytics.oauth_manager.OauthManager(
            "id", privatekey, "keyid", "http://127.0.0.1:429")
        oauth._poller_loop()
        slept_for = mock_sleep.call_args[0][0]
        self.assertTrue(1.9 < slept_for <= 2.0)

class TestOauthIntegration(unittest.TestCase):
    """Tests that run a Client configured for OAuth, with Session.post mocked.

    The error callback was previously a method named ``fail``, which
    replaced ``unittest.TestCase.fail``: any assertion helper that reports
    through ``self.fail`` would then have set a flag instead of failing the
    test. The callback now has its own private name.
    """

    def setUp(self):
        self.failed = False
        # Start every test from the mock's default state, even if an earlier
        # test aborted after arming the :501 endpoint.
        mocked_requests_get.error_count = -1

    def _record_failure(self, e, batch=None):
        """``on_error`` callback for Client; records that an error was reported."""
        self.failed = True

    def _make_client(self, auth_server, client_key=None):
        """Build a Client pointed at ``auth_server``; uses the test key by default."""
        return Client("write_key", on_error=self._record_failure,
                      oauth_auth_server=auth_server,
                      oauth_client_id="id",
                      oauth_client_key=privatekey if client_key is None else client_key,
                      oauth_key_id="keyid")

    @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get)
    def test_oauth_integration_success(self, mock_post):
        """Token fetch succeeds and the event is delivered without error."""
        client = self._make_client("http://127.0.0.1:80")
        client.track("user", "event")
        client.flush()
        self.assertFalse(self.failed)
        # Presumably one token request plus one batch upload.
        self.assertEqual(mock_post.call_count, 2)

    @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get)
    def test_oauth_integration_failure(self, mock_post):
        """A 400 from the token endpoint surfaces through on_error."""
        client = self._make_client("http://127.0.0.1:400")
        client.track("user", "event")
        client.flush()
        self.assertTrue(self.failed)
        # Presumably only the token request; the batch is never posted.
        self.assertEqual(mock_post.call_count, 1)

    @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get)
    def test_oauth_integration_recovery(self, mock_post):
        """Two failed token requests followed by a success still deliver the event."""
        mocked_requests_get.error_count = 2  # 2 errors and then success
        client = self._make_client("http://127.0.0.1:501")
        client.track("user", "event")
        client.flush()
        self.assertFalse(self.failed)
        # Presumably two failed token requests, one good one, one batch upload.
        self.assertEqual(mock_post.call_count, 4)

    @mock.patch.object(requests.Session, 'post', side_effect=mocked_requests_get)
    def test_oauth_integration_fail_bad_key(self, mock_post):
        """A client key that is not a valid private key is reported via on_error."""
        client = self._make_client("http://127.0.0.1:80", client_key="badkey")
        client.track("user", "event")
        client.flush()
        self.assertTrue(self.failed)

# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 2f5945c

Please sign in to comment.