Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/bearer auth #17

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions mailjet_rest/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import logging

import requests
from requests.auth import HTTPBasicAuth
from requests.compat import urljoin
from .utils.version import get_version
from .utils.auth import HTTPBearerAuth

requests.packages.urllib3.disable_warnings()

Expand All @@ -27,14 +29,15 @@ def __getitem__(self, key):
# Forward slash is ignored if present in self.version.
url = urljoin(url, self.version + '/')
headers = {'Content-type': 'application/json', 'User-agent': self.user_agent}
if key.lower() == 'contactslist_csvdata':
url = urljoin(url, 'DATA/')
headers['Content-type'] = 'text/plain'
elif key.lower() == 'batchjob_csverror':
url = urljoin(url, 'DATA/')
headers['Content-type'] = 'text/csv'
elif key.lower() != 'send':
url = urljoin(url, 'REST/')
if self.version == 'v3' or self.version == 'v3.1':
if key.lower() == 'contactslist_csvdata':
url = urljoin(url, 'DATA/')
headers['Content-type'] = 'text/plain'
elif key.lower() == 'batchjob_csverror':
url = urljoin(url, 'DATA/')
headers['Content-type'] = 'text/csv'
elif key.lower() != 'send':
url = urljoin(url, 'REST/')
url = url + key.split('_')[0].lower()
return url, headers

Expand Down Expand Up @@ -73,7 +76,7 @@ def delete(self, id, **kwargs):
class Client(object):

def __init__(self, auth=None, **kwargs):
self.auth = auth
self.auth = get_auth(auth)
version = kwargs.get('version', None)
self.config = Config(version=version)

Expand All @@ -91,6 +94,16 @@ def __getattr__(self, name):
return type(fname, (Endpoint,), {})(url=url, headers=headers, action=action, auth=self.auth)


def get_auth(auth):
if isinstance(auth, tuple) and len(auth) == 2:
return HTTPBasicAuth(*auth)
elif isinstance(auth, str):
return HTTPBearerAuth(auth)
else:
raise AuthorizationError('Unsupported authorization format!')



def api_call(auth, method, url, headers, data=None, filters=None, resource_id=None,
timeout=60, debug=False, action=None, action_id=None, **kwargs):
url = build_url(url, method=method, action=action, resource_id=resource_id, action_id=action_id)
Expand Down
25 changes: 25 additions & 0 deletions mailjet_rest/utils/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from requests.auth import AuthBase


def _bearer_token_str(token):
if not isinstance(token, str):
token = str(token)
authstr = 'Bearer {}'.format(token)
return authstr


class HTTPBearerAuth(AuthBase):
"""Attaches HTTP Bearer Authentication to the given Request object."""

def __init__(self, token):
self.token = token

def __eq__(self, other):
return self.token == getattr(other, 'token', None)

def __ne__(self, other):
return not self == other

def __call__(self, r):
r.headers['Authorization'] = _bearer_token_str(self.token)
return r
2 changes: 1 addition & 1 deletion mailjet_rest/utils/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
VERSION = (1, 3, 0)
VERSION = (1, 4, 0)


def get_version(version=None):
Expand Down
63 changes: 61 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import unittest
from requests.auth import HTTPBasicAuth
from mailjet_rest import Client
from mailjet_rest.client import get_auth, AuthorizationError
from mailjet_rest.utils.auth import HTTPBearerAuth
import os
import random
import string
Expand Down Expand Up @@ -74,7 +77,18 @@ def test_post_with_no_param(self):
result = self.client.sender.create(data={}).json()
self.assertTrue('StatusCode' in result and result['StatusCode'] is not 400)

def test_client_custom_version(self):
def test_client_v31(self):
self.client = Client(
auth=self.auth,
version='v3.1'
)
self.assertEqual(self.client.config.version, 'v3.1')
self.assertEqual(
self.client.config['contact'][0],
'https://api.mailjet.com/v3.1/REST/contact'
)

def test_client_v31_send(self):
self.client = Client(
auth=self.auth,
version='v3.1'
Expand All @@ -85,12 +99,57 @@ def test_client_custom_version(self):
'https://api.mailjet.com/v3.1/send'
)

def test_client_v3(self):
self.client = Client(
auth=self.auth,
version='v3'
)
self.assertEqual(self.client.config.version, 'v3')
self.assertEqual(
self.client.config['contact'][0],
'https://api.mailjet.com/v3/REST/contact'
)

def test_client_v3_send(self):
self.client = Client(
auth=self.auth,
version='v3'
)
self.assertEqual(self.client.config.version, 'v3')
self.assertEqual(
self.client.config['send'][0],
'https://api.mailjet.com/v3/send'
)

def test_client_v4(self):
self.client = Client(
auth=self.auth,
version='v4'
)
self.assertEqual(self.client.config.version, 'v4')
self.assertEqual(
self.client.config['sms'][0],
'https://api.mailjet.com/v4/sms'
)

def test_user_agent(self):
self.client = Client(
auth=self.auth,
version='v3.1'
)
self.assertEqual(self.client.config.user_agent, 'mailjet-apiv3-python/v1.3.0')
self.assertEqual(self.client.config.user_agent, 'mailjet-apiv3-python/v1.4.0')

def test_get_auth_bearer(self):
auth = get_auth('bearer_token')
self.assertIsInstance(auth, HTTPBearerAuth)

def test_get_auth_basic(self):
auth = get_auth(('api_key', 'api_secret'))
self.assertIsInstance(auth, HTTPBasicAuth)

def test_get_auth_unrecognised(self):
with self.assertRaises(AuthorizationError):
get_auth(['api_key', 'api_secret'])


if __name__ == '__main__':
Expand Down