Skip to content

Commit

Permalink
Replace hardcoded session with http_client param
Browse files Browse the repository at this point in the history
Remove timeout parameter, requests and httpx behaviors are incompatible anyway
  • Loading branch information
rayluo committed Apr 23, 2020
1 parent abd1394 commit cb83fa8
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 19 deletions.
77 changes: 58 additions & 19 deletions oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This OAuth2 client implementation aims to be spec-compliant, and generic."""
# OAuth2 spec https://tools.ietf.org/html/rfc6749

import json
try:
from urllib.parse import urlencode, parse_qs
except ImportError:
Expand All @@ -11,6 +12,7 @@
import time
import base64
import sys
import functools

import requests

Expand All @@ -35,6 +37,7 @@ def __init__(
self,
server_configuration, # type: dict
client_id, # type: str
http_client=None, # We insert it here to match the upcoming async API
client_secret=None, # type: Optional[str]
client_assertion=None, # type: Union[bytes, callable, None]
client_assertion_type=None, # type: Optional[str]
Expand All @@ -57,6 +60,9 @@ def __init__(
or
https://example.com/.../.well-known/openid-configuration
client_id (str): The client's id, issued by the authorization server
http_client (http.HttpClient):
Your implementation of abstract class :class:`http.HttpClient`.
Defaults to a requests session instance.
client_secret (str): Triggers HTTP AUTH for Confidential Client
client_assertion (bytes, callable):
The client assertion to authenticate this client, per RFC 7521.
Expand All @@ -76,20 +82,51 @@ def __init__(
you could choose to set this as {"client_secret": "your secret"}
if your authorization server wants it to be in the request body
(rather than in the request header).
verify (boolean):
It will be passed to the
`verify parameter in the underlying requests library
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#ssl-cert-verification>`_
This does not apply if you have chosen to pass your own Http client.
proxies (dict):
It will be passed to the
`proxies parameter in the underlying requests library
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#proxies>`_
This does not apply if you have chosen to pass your own Http client.
timeout (object):
It will be passed to the
`timeout parameter in the underlying requests library
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#timeouts>`_
This does not apply if you have chosen to pass your own Http client.
There is no session-wide `timeout` parameter defined here.
The timeout behavior is determined by the actual http client you use.
If you happen to use Requests, it chose to not support session-wide timeout
(https://github.com/psf/requests/issues/3341), but you can patch that by:
s = requests.Session()
s.request = functools.partial(s.request, timeout=3)
and then feed that patched session instance to this class.
"""
self.configuration = server_configuration
self.client_id = client_id
self.client_secret = client_secret
self.client_assertion = client_assertion
self.default_headers = default_headers or {}
self.default_body = default_body or {}
if client_assertion_type is not None:
self.default_body["client_assertion_type"] = client_assertion_type
self.logger = logging.getLogger(__name__)
self.session = s = requests.Session()
s.headers.update(default_headers or {})
s.verify = verify
s.proxies = proxies or {}
self.timeout = timeout
if http_client:
self.http_client = http_client
else:
self.http_client = requests.Session()
self.http_client.verify = verify
self.http_client.proxies = proxies
self.http_client.request = functools.partial(
# A workaround for requests not supporting session-wide timeout
self.http_client.request, timeout=timeout)

def _build_auth_request_params(self, response_type, **kwargs):
# response_type is a string defined in
Expand All @@ -110,7 +147,6 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
params=None, # a dict to be sent as query string to the endpoint
data=None, # All relevant data, which will go into the http body
headers=None, # a dict to be sent as request headers
timeout=None,
post=None, # A callable to replace requests.post(), for testing.
# Such as: lambda url, **kwargs:
# Mock(status_code=200, json=Mock(return_value={}))
Expand All @@ -128,38 +164,40 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749

_data.update(self.default_body) # It may contain authen parameters
_data.update(data or {}) # So the content in data param prevails
# We don't have to clean up None values here, because requests lib will.
_data = {k: v for k, v in _data.items() if v} # Clean up None values

if _data.get('scope'):
_data['scope'] = self._stringify(_data['scope'])

_headers = {'Accept': 'application/json'}
_headers.update(self.default_headers)
_headers.update(headers or {})

# Quoted from https://tools.ietf.org/html/rfc6749#section-2.3.1
# Clients in possession of a client password MAY use the HTTP Basic
# authentication.
# Alternatively, (but NOT RECOMMENDED,)
# the authorization server MAY support including the
# client credentials in the request-body using the following
# parameters: client_id, client_secret.
auth = None
if self.client_secret and self.client_id:
auth = (self.client_id, self.client_secret) # for HTTP Basic Auth
_headers["Authorization"] = "Basic " + base64.b64encode(
"{}:{}".format(self.client_id, self.client_secret)
.encode("ascii")).decode("ascii")

if "token_endpoint" not in self.configuration:
raise ValueError("token_endpoint not found in configuration")
_headers = {'Accept': 'application/json'}
_headers.update(headers or {})
resp = (post or self.session.post)(
resp = (post or self.http_client.post)(
self.configuration["token_endpoint"],
headers=_headers, params=params, data=_data, auth=auth,
timeout=timeout or self.timeout,
headers=_headers, params=params, data=_data,
**kwargs)
if resp.status_code >= 500:
resp.raise_for_status() # TODO: Will probably retry here
try:
# The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says
# even an error response will be a valid json structure,
# so we simply return it here, without needing to invent an exception.
return resp.json()
return json.loads(resp.text)
except ValueError:
self.logger.exception(
"Token response is not in json format: %s", resp.text)
Expand Down Expand Up @@ -200,7 +238,7 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class
grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion}


def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
def initiate_device_flow(self, scope=None, **kwargs):
# type: (list, **dict) -> dict
# The naming of this method is following the wording of this specs
# https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.1
Expand All @@ -218,10 +256,11 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
DAE = "device_authorization_endpoint"
if not self.configuration.get(DAE):
raise ValueError("You need to provide device authorization endpoint")
flow = self.session.post(self.configuration[DAE],
resp = self.http_client.post(self.configuration[DAE],
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
timeout=timeout or self.timeout,
**kwargs).json()
headers=dict(self.default_headers, **kwargs.pop("headers", {})),
**kwargs)
flow = json.loads(resp.text)
flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string
flow["expires_in"] = int(flow.get("expires_in", 1800))
flow["expires_at"] = time.time() + flow["expires_in"] # We invent this
Expand Down
30 changes: 30 additions & 0 deletions tests/http_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import requests


class MinimalHttpClient:

def __init__(self, verify=True, proxies=None, timeout=None):
self.session = requests.Session()
self.session.verify = verify
self.session.proxies = proxies
self.timeout = timeout

def post(self, url, params=None, data=None, headers=None, **kwargs):
return MinimalResponse(requests_resp=self.session.post(
url, params=params, data=data, headers=headers,
timeout=self.timeout))

def get(self, url, params=None, headers=None, **kwargs):
return MinimalResponse(requests_resp=self.session.get(
url, params=params, headers=headers, timeout=self.timeout))


class MinimalResponse(object): # Not for production use
def __init__(self, requests_resp=None, status_code=None, text=None):
self.status_code = status_code or requests_resp.status_code
self.text = text or requests_resp.text
self._raw_resp = requests_resp

def raise_for_status(self):
if self._raw_resp:
self._raw_resp.raise_for_status()
4 changes: 4 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from oauth2cli.authcode import obtain_auth_code
from oauth2cli.assertion import JwtSigner
from tests import unittest, Oauth2TestCase
from tests.http_client import MinimalHttpClient


logging.basicConfig(level=logging.DEBUG)
Expand Down Expand Up @@ -83,13 +84,15 @@ class TestClient(Oauth2TestCase):

@classmethod
def setUpClass(cls):
http_client = MinimalHttpClient()
if "client_certificate" in CONFIG:
private_key_path = CONFIG["client_certificate"]["private_key_path"]
with open(os.path.join(THIS_FOLDER, private_key_path)) as f:
private_key = f.read() # Expecting PEM format
cls.client = Client(
CONFIG["openid_configuration"],
CONFIG['client_id'],
http_client=http_client,
client_assertion=JwtSigner(
private_key,
algorithm="RS256",
Expand All @@ -103,6 +106,7 @@ def setUpClass(cls):
else:
cls.client = Client(
CONFIG["openid_configuration"], CONFIG['client_id'],
http_client=http_client,
client_secret=CONFIG.get('client_secret'))

@unittest.skipIf(
Expand Down

0 comments on commit cb83fa8

Please sign in to comment.