diff --git a/mailjet_rest/client.py b/mailjet_rest/client.py index 175d885..7d1b8f2 100644 --- a/mailjet_rest/client.py +++ b/mailjet_rest/client.py @@ -5,8 +5,10 @@ import logging import requests +from requests.auth import HTTPBasicAuth from requests.compat import urljoin from .utils.version import get_version +from .utils.auth import HTTPBearerAuth requests.packages.urllib3.disable_warnings() @@ -27,14 +29,15 @@ def __getitem__(self, key): # Forward slash is ignored if present in self.version. url = urljoin(url, self.version + '/') headers = {'Content-type': 'application/json', 'User-agent': self.user_agent} - if key.lower() == 'contactslist_csvdata': - url = urljoin(url, 'DATA/') - headers['Content-type'] = 'text/plain' - elif key.lower() == 'batchjob_csverror': - url = urljoin(url, 'DATA/') - headers['Content-type'] = 'text/csv' - elif key.lower() != 'send': - url = urljoin(url, 'REST/') + if self.version == 'v3' or self.version == 'v3.1': + if key.lower() == 'contactslist_csvdata': + url = urljoin(url, 'DATA/') + headers['Content-type'] = 'text/plain' + elif key.lower() == 'batchjob_csverror': + url = urljoin(url, 'DATA/') + headers['Content-type'] = 'text/csv' + elif key.lower() != 'send': + url = urljoin(url, 'REST/') url = url + key.split('_')[0].lower() return url, headers @@ -73,7 +76,7 @@ def delete(self, id, **kwargs): class Client(object): def __init__(self, auth=None, **kwargs): - self.auth = auth + self.auth = get_auth(auth) version = kwargs.get('version', None) self.config = Config(version=version) @@ -91,6 +94,16 @@ def __getattr__(self, name): return type(fname, (Endpoint,), {})(url=url, headers=headers, action=action, auth=self.auth) +def get_auth(auth): + if isinstance(auth, tuple) and len(auth) == 2: + return HTTPBasicAuth(*auth) + elif isinstance(auth, str): + return HTTPBearerAuth(auth) + else: + raise AuthorizationError('Unsupported authorization format!') + + + def api_call(auth, method, url, headers, data=None, filters=None, resource_id=None, timeout=60, debug=False, action=None, action_id=None, **kwargs): url = build_url(url, method=method, action=action, resource_id=resource_id, action_id=action_id) diff --git a/mailjet_rest/utils/auth.py b/mailjet_rest/utils/auth.py new file mode 100644 index 0000000..d08a005 --- /dev/null +++ b/mailjet_rest/utils/auth.py @@ -0,0 +1,25 @@ +from requests.auth import AuthBase + + +def _bearer_token_str(token): + if not isinstance(token, str): + token = str(token) + authstr = 'Bearer {}'.format(token) + return authstr + + +class HTTPBearerAuth(AuthBase): + """Attaches HTTP Bearer Authentication to the given Request object.""" + + def __init__(self, token): + self.token = token + + def __eq__(self, other): + return self.token == getattr(other, 'token', None) + + def __ne__(self, other): + return not self == other + + def __call__(self, r): + r.headers['Authorization'] = _bearer_token_str(self.token) + return r diff --git a/mailjet_rest/utils/version.py b/mailjet_rest/utils/version.py index 93b500a..978167f 100644 --- a/mailjet_rest/utils/version.py +++ b/mailjet_rest/utils/version.py @@ -1,4 +1,4 @@ -VERSION = (1, 3, 0) +VERSION = (1, 4, 0) def get_version(version=None): diff --git a/test.py b/test.py index 69712e8..5359e1d 100644 --- a/test.py +++ b/test.py @@ -1,5 +1,8 @@ import unittest +from requests.auth import HTTPBasicAuth from mailjet_rest import Client +from mailjet_rest.client import get_auth, AuthorizationError +from mailjet_rest.utils.auth import HTTPBearerAuth import os import random import string @@ -74,7 +77,18 @@ def test_post_with_no_param(self): result = self.client.sender.create(data={}).json() self.assertTrue('StatusCode' in result and result['StatusCode'] is not 400) - def test_client_custom_version(self): + def test_client_v31(self): + self.client = Client( + auth=self.auth, + version='v3.1' + ) + self.assertEqual(self.client.config.version, 'v3.1') + self.assertEqual( + self.client.config['contact'][0], + 'https://api.mailjet.com/v3.1/REST/contact' + ) + + def test_client_v31_send(self): self.client = Client( auth=self.auth, version='v3.1' @@ -85,12 +99,57 @@ def test_client_custom_version(self): 'https://api.mailjet.com/v3.1/send' ) + def test_client_v3(self): + self.client = Client( + auth=self.auth, + version='v3' + ) + self.assertEqual(self.client.config.version, 'v3') + self.assertEqual( + self.client.config['contact'][0], + 'https://api.mailjet.com/v3/REST/contact' + ) + + def test_client_v3_send(self): + self.client = Client( + auth=self.auth, + version='v3' + ) + self.assertEqual(self.client.config.version, 'v3') + self.assertEqual( + self.client.config['send'][0], + 'https://api.mailjet.com/v3/send' + ) + + def test_client_v4(self): + self.client = Client( + auth=self.auth, + version='v4' + ) + self.assertEqual(self.client.config.version, 'v4') + self.assertEqual( + self.client.config['sms'][0], + 'https://api.mailjet.com/v4/sms' + ) + def test_user_agent(self): self.client = Client( auth=self.auth, version='v3.1' ) - self.assertEqual(self.client.config.user_agent, 'mailjet-apiv3-python/v1.3.0') + self.assertEqual(self.client.config.user_agent, 'mailjet-apiv3-python/v1.4.0') + + def test_get_auth_bearer(self): + auth = get_auth('bearer_token') + self.assertIsInstance(auth, HTTPBearerAuth) + + def test_get_auth_basic(self): + auth = get_auth(('api_key', 'api_secret')) + self.assertIsInstance(auth, HTTPBasicAuth) + + def test_get_auth_unrecognised(self): + with self.assertRaises(AuthorizationError): + get_auth(['api_key', 'api_secret']) if __name__ == '__main__':