Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python OAuth implementation #320

Merged
merged 12 commits into from
Oct 6, 2023
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ ptyprocess==0.7.0
pycodestyle==2.5.0
pyflakes==2.1.1
Pygments==2.16.1
PyJWT==2.8.0
pylint==1.9.3
pyparsing==3.1.1
python-dateutil==2.8.2
Expand Down
51 changes: 51 additions & 0 deletions samples/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
import segment.analytics as analytics

privatekey = '''-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDVll7uJaH322IN
PQsH2aOXZJ2r1q+6hpVK1R5JV1p41PUzn8pOxyXFHWB+53dUd4B8qywKS36XQjp0
VmhR1tQ22znQ9ZCM6y4LGeOJBjAZiFZLcGQNNrDFC0WGWTrK1ZTS2K7p5qy4fIXG
laNkMXiGGCawkgcHAdOvPTy8m1d9a6YSetYVmBP/tEYN95jPyZFIoHQfkQPBPr9W
cWPpdEBzasHV5d957akjurPpleDiD5as66UW4dkWXvS7Wu7teCLCyDApcyJKTb2Z
SXybmWjhIZuctZMAx3wT/GgW3FbkGaW5KLQgBUMzjpL0fCtMatlqckMD92ll1FuK
R+HnXu05AgMBAAECggEBAK4o2il4GDUh9zbyQo9ZIPLuwT6AZXRED3Igi3ykNQp4
I6S/s9g+vQaY6LkyBnSiqOt/K/8NBiFSiJWaa5/n+8zrP56qzf6KOlYk+wsdN5Vq
PWtwLrUzljpl8YAWPEFunNa8hwwE42vfZbnDBKNLT4qQIOQzfnVxQOoQlfj49gM2
iSrblvsnQTyucFy3UyTeioHbh8q2Xqcxry5WUCOrFDd3IIwATTpLZGw0IPeuFJbJ
NfBizLEcyJaM9hujQU8PRCRd16MWX+bbYM6Mh4dkT40QXWsVnHBHwsgPdQgDgseF
Na4ajtHoC0DlwYCXpCm3IzJfKfq/LR2q8NDUgKeF4AECgYEA9nD4czza3SRbzhpZ
bBoK77CSNqCcMAqyuHB0hp/XX3yB7flF9PIPb2ReO8wwmjbxn+bm7PPz2Uwd2SzO
pU+FXmyKJr53Jxw/hmDWZCoh42gsGDlVqpmytzsj74KlaYiMyZmEGbD7t/FGfNGV
LdLDJaHIYxEimFviOTXKCeKvPAECgYEA3d8tv4jdp1uAuRZiU9Z/tfw5mJOi3oXF
8AdFFDwaPzcTorEAxjrt9X6IjPbLIDJNJtuXYpe+dG6720KyuNnhLhWW9oZEJTwT
dUgqZ2fTCOS9uH0jSn+ZFlgTWI6UDQXRwE7z8avlhMIrQVmPsttGTo7V6sQVtGRx
bNj2RSVekTkCgYAJvy4UYLPHS0jWPfSLcfw8vp8JyhBjVgj7gncZW/kIrcP1xYYe
yfQSU8XmV40UjFfCGz/G318lmP0VOdByeVKtCV3talsMEPHyPqI8E+6DL/uOebYJ
qUqINK6XKnOgWOY4kvnGillqTQCcry1XQp61PlDOmj7kB75KxPXYrj6AAQKBgQDa
+ixCv6hURuEyy77cE/YT/Q4zYnL6wHjtP5+UKwWUop1EkwG6o+q7wtiul90+t6ah
1VUCP9X/QFM0Qg32l0PBohlO0pFrVnG17TW8vSHxwyDkds1f97N19BOT8ZR5jebI
sKPfP9LVRnY+l1BWLEilvB+xBzqMwh2YWkIlWI6PMQKBgGi6TBnxp81lOYrxVRDj
/3ycRnVDmBdlQKFunvfzUBmG1mG/G0YHeVSUKZJGX7w2l+jnDwIA383FcUeA8X6A
l9q+amhtkwD/6fbkAu/xoWNl+11IFoxd88y2ByBFoEKB6UVLuCTSKwXDqzEZet7x
mDyRxq7ohIzLkw8b8buDeuXZ
-----END PRIVATE KEY-----
''' # Should be read from a file on disk which can be rotated out

analytics.write_key = '9BWoGOi4lWVaQBP5NRheT7N0C1t4HTJM'
analytics.host = 'https://api.segment.build'

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the urls to the prod versions?

analytics.oauth_client_id = '2VRbiyEPtsDBjFqO9Tu7EXvQH5H' # OAuth application ID from segment dashboard
analytics.oauth_client_key = privatekey # generated as a public/private key pair in PEM format from OpenSSL
analytics.oauth_key_id = '2VRbiuUFSMo4AGyrYyxxyONzqiP' # From segment dashboard after uploading public key
analytics.oauth_auth_server = 'https://oauth2.segment.build'
analytics.oauth_scope = 'tracking_api:write' # 'public_api:read_write'

def on_error(error, items):
print("An error occurred: ", error)
analytics.debug = True
analytics.on_error = on_error

analytics.track('AUser', 'track')

input("Press ENTER to exit after receiving a response...")
15 changes: 14 additions & 1 deletion segment/analytics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
timeout = Client.DefaultConfig.timeout
max_retries = Client.DefaultConfig.max_retries

"""Oauth Settings."""
oauth_client_id = Client.DefaultConfig.oauth_client_id
oauth_client_key = Client.DefaultConfig.oauth_client_key
oauth_key_id = Client.DefaultConfig.oauth_key_id
oauth_auth_server = Client.DefaultConfig.oauth_auth_server
oauth_scope = Client.DefaultConfig.oauth_scope

default_client = None


Expand Down Expand Up @@ -73,7 +80,13 @@ def _proxy(method, *args, **kwargs):
max_queue_size=max_queue_size,
send=send, on_error=on_error,
gzip=gzip, max_retries=max_retries,
sync_mode=sync_mode, timeout=timeout)
sync_mode=sync_mode, timeout=timeout,
oauth_client_id=oauth_client_id,
oauth_client_key=oauth_client_key,
oauth_key_id=oauth_key_id,
oauth_auth_server=oauth_auth_server,
oauth_scope=oauth_scope,
)

fn = getattr(default_client, method)
return fn(*args, **kwargs)
22 changes: 19 additions & 3 deletions segment/analytics/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json

from dateutil.tz import tzutc
from segment.analytics.oauth_manager import OauthManager

from segment.analytics.utils import guess_timezone, clean
from segment.analytics.consumer import Consumer, MAX_MSG_SIZE
Expand Down Expand Up @@ -33,6 +34,12 @@ class DefaultConfig(object):
thread = 1
upload_interval = 0.5
upload_size = 100
oauth_client_id = None
oauth_client_key = None
oauth_key_id = None
oauth_auth_server = 'https://oauth2.segment.io'
oauth_scope = 'tracking_api:write'


"""Create a new Segment client."""
log = logging.getLogger('segment')
Expand All @@ -51,7 +58,12 @@ def __init__(self,
proxies=DefaultConfig.proxies,
thread=DefaultConfig.thread,
upload_size=DefaultConfig.upload_size,
upload_interval=DefaultConfig.upload_interval,):
upload_interval=DefaultConfig.upload_interval,
oauth_client_id=DefaultConfig.oauth_client_id,
oauth_client_key=DefaultConfig.oauth_client_key,
oauth_key_id=DefaultConfig.oauth_key_id,
oauth_auth_server=DefaultConfig.oauth_auth_server,
oauth_scope=DefaultConfig.oauth_scope,):
require('write_key', write_key, str)

self.queue = queue.Queue(max_queue_size)
Expand All @@ -64,6 +76,9 @@ def __init__(self,
self.gzip = gzip
self.timeout = timeout
self.proxies = proxies
if(oauth_client_id and oauth_client_key and oauth_key_id):
self.oauth_manager = OauthManager(oauth_client_id, oauth_client_key, oauth_key_id,
MichaelGHSeg marked this conversation as resolved.
Show resolved Hide resolved
oauth_auth_server, oauth_scope, timeout, max_retries)

if debug:
self.log.setLevel(logging.DEBUG)
Expand All @@ -85,7 +100,7 @@ def __init__(self,
self.queue, write_key, host=host, on_error=on_error,
upload_size=upload_size, upload_interval=upload_interval,
gzip=gzip, retries=max_retries, timeout=timeout,
proxies=proxies,
proxies=proxies, oauth_manager=self.oauth_manager,
)
self.consumers.append(consumer)

Expand Down Expand Up @@ -280,7 +295,8 @@ def _enqueue(self, msg):
if self.sync_mode:
self.log.debug('enqueued with blocking %s.', msg['type'])
post(self.write_key, self.host, gzip=self.gzip,
timeout=self.timeout, proxies=self.proxies, batch=[msg])
timeout=self.timeout, proxies=self.proxies,
oauth_manager=self.oauth_manager, batch=[msg])

return True, msg

Expand Down
6 changes: 4 additions & 2 deletions segment/analytics/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Consumer(Thread):

def __init__(self, queue, write_key, upload_size=100, host=None,
on_error=None, upload_interval=0.5, gzip=False, retries=10,
timeout=15, proxies=None):
timeout=15, proxies=None, oauth_manager=None):
"""Create a consumer thread."""
Thread.__init__(self)
# Make consumer a daemon thread so that it doesn't block program exit
Expand All @@ -41,6 +41,7 @@ def __init__(self, queue, write_key, upload_size=100, host=None,
self.retries = retries
self.timeout = timeout
self.proxies = proxies
self.oauth_manager = oauth_manager

def run(self):
"""Runs the consumer."""
Expand Down Expand Up @@ -129,6 +130,7 @@ def fatal_exception(exc):
giveup=fatal_exception)
def send_request():
post(self.write_key, self.host, gzip=self.gzip,
timeout=self.timeout, batch=batch, proxies=self.proxies)
timeout=self.timeout, batch=batch, proxies=self.proxies,
oauth_manager=self.oauth_manager)

send_request()
159 changes: 159 additions & 0 deletions segment/analytics/oauth_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from datetime import date, datetime
import logging
import threading
import time
import uuid
from requests import sessions
import jwt

_session = sessions.Session()

class OauthManager(object):
def __init__(self,
client_id,
client_key,
key_id,
auth_server,
scope,
timeout,
max_retries):
self.log = logging.getLogger('segment')
self.client_id = client_id
self.client_key = client_key
self.key_id = key_id
self.auth_server = auth_server
self.scope = scope
self.timeout = timeout
self.max_retries = max_retries
self.retry_count = 0
self.thread = None
self.token_mutex = threading.Lock()
self.token = None
self.error = None

def get_token(self):
with self.token_mutex:
if self.token:
return self.token
# No good token, start the loop
self.thread = threading.Thread(target=self._poller_loop)
self.thread.start()

while True:
# Wait for a token or error
with self.token_mutex:
if self.token:
return self.token
if self.error:
error = self.error
self.error = None
raise Exception(error)
if self.thread:
self.thread.join()

def clear_token(self):
with self.token_mutex:
self.token = None

def _request_token(self):
jwt_body = {
"iss": self.client_id,
"sub": self.client_id,
"aud": "https://oauth2.segment.io",
"iat": int(time.time())-1,
"exp": int(time.time()) + 59,
"jti": str(uuid.uuid4())
}

signed_jwt = jwt.encode(
jwt_body,
self.client_key,
algorithm="RS256",
headers={"kid": self.key_id},
)

request_body = 'grant_type=client_credentials&client_assertion_type='\
'urn:ietf:params:oauth:client-assertion-type:jwt-bearer&'\
'client_assertion={}&scope={}'.format(signed_jwt, self.scope)

token_endpoint = f'{self.auth_server}/token'

res = _session.post(url=token_endpoint, data=request_body, timeout=self.timeout,
headers={'Content-Type': 'application/x-www-form-urlencoded'})
return res

def _poller_loop(self):
refresh_timer_ms = 25
response = None

try:
response = self._request_token()
except Exception as e:
self.log.error(e)
self.retry_count += 1
if self.retry_count < self.max_retries:
self.thread = threading.Timer(refresh_timer_ms / 1000.0, self._poller_loop)
self.thread.setDaemon(True)
self.thread.start()
return
# Too many retries, giving up
self.error = e
return

if response.status_code == 200:
data = None
try:
data = response.json()
except Exception as e:
self.retry_count += 1
if self.retry_count < self.max_retries:
self.thread = threading.Timer(refresh_timer_ms / 1000.0, self._poller_loop)
self.thread.setDaemon(True)
self.thread.start()
return
# Too many retries, giving up
self.error = e
return
try:
with self.token_mutex:
self.token = data['access_token']
# success!
self.retry_count = 0
except Exception as e:
# No access token in response?
self.log.error(e)
try:
refresh_timer_ms = int(data['expires_in']) / 2 * 1000
except Exception as e:
refresh_timer_ms = 60 * 1000

elif response.status_code == 429:
self.retry_count += 1
rate_limit_reset_time = None
try:
rate_limit_reset_time = int(response.headers.get("X-RateLimit-Reset"))
except Exception as e:
self.log.error(e)
if rate_limit_reset_time:
refresh_timer_ms = rate_limit_reset_time - time.time() * 1000
else:
refresh_timer_ms = 5 * 1000
elif response.status_code in [400, 401, 415]:
# unrecoverable errors
self.retry_count = 0
self.error = Exception(f'[{response.status_code}] {response.reason}')
self.log.error(self.error)
return
else:
# any other error
self.log.error(f'[{response.status_code}] {response.reason}')
self.retry_count += 1

if self.retry_count % self.max_retries == 0:
# every time we pass the retry count, put up an error to release any waiting token requests
self.error = Exception(f'[{response.status_code}] {response.reason}')

# loop
self.thread = threading.Timer(refresh_timer_ms / 1000.0, self._poller_loop)
self.thread.setDaemon(True)
self.thread.start()
20 changes: 15 additions & 5 deletions segment/analytics/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,28 @@
_session = sessions.Session()


def post(write_key, host=None, gzip=False, timeout=15, proxies=None, **kwargs):
def post(write_key, host=None, gzip=False, timeout=15, proxies=None, oauth_manager=None, **kwargs):
"""Post the `kwargs` to the API"""
log = logging.getLogger('segment')
body = kwargs
body["sentAt"] = datetime.utcnow().replace(tzinfo=tzutc()).isoformat()
body["writeKey"] = write_key
url = remove_trailing_slash(host or 'https://api.segment.io') + '/v1/batch'
auth = HTTPBasicAuth(write_key, '')
auth = None
if oauth_manager:
try:
auth = oauth_manager.get_token()
except Exception as e:
raise e
data = json.dumps(body, cls=DatetimeSerializer)
log.debug('making request: %s', data)
headers = {
'Content-Type': 'application/json',
'User-Agent': 'analytics-python/' + VERSION
}
if auth:
headers['Authorization'] = 'Bearer {}'.format(auth)

if gzip:
headers['Content-Encoding'] = 'gzip'
buf = BytesIO()
Expand All @@ -37,21 +46,22 @@ def post(write_key, host=None, gzip=False, timeout=15, proxies=None, **kwargs):

kwargs = {
"data": data,
"auth": auth,
"headers": headers,
"timeout": 15,
}

if proxies:
kwargs['proxies'] = proxies

res = _session.post(url, data=data, auth=auth,
headers=headers, timeout=timeout)
res = _session.post(url, data=data, headers=headers, timeout=timeout)

if res.status_code == 200:
log.debug('data uploaded successfully')
return res

if oauth_manager and res.status_code in [400, 401, 403]:
oauth_manager.clear_token()

try:
payload = res.json()
log.debug('received response: %s', payload)
Expand Down