Skip to content

Commit

Permalink
Integrate auth with AuthBase from requests module
Browse files Browse the repository at this point in the history
  • Loading branch information
J535D165 committed Nov 18, 2023
1 parent 64070a0 commit 1305fcc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 11 deletions.
42 changes: 31 additions & 11 deletions pyalex/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from urllib.parse import quote_plus

import requests
from requests.auth import AuthBase
from urllib3.util import Retry

try:
Expand All @@ -22,6 +23,7 @@ def __setattr__(self, key, value):
config = AlexConfig(
email=None,
api_key=None,
user_agent="pyalex/" + __version__,
openalex_url="https://api.openalex.org",
max_retries=0,
retry_backoff_factor=0.1,
Expand Down Expand Up @@ -160,6 +162,32 @@ def __next__(self):
return results


class OpenAlexAuth(AuthBase):
"""OpenAlex auth class based on requests auth
Includes the email, api_key and user-agent headers.
arguments:
config: an AlexConfig object
"""

def __init__(self, config):
self.config = config

def __call__(self, r):
if self.config.api_key:
r.headers["Authorization"] = f"Bearer {self.config.api_key}"

if self.config.email:
r.headers["From"] = self.config.email

if self.config.user_agent:
r.headers["User-Agent"] = self.config.user_agent

return r


class BaseOpenAlex:
"""Base class for OpenAlex objects."""

Expand Down Expand Up @@ -222,13 +250,7 @@ def count(self):
return m["count"]

def _get_from_url(self, url, return_meta=False):
params = {"api_key": config.api_key} if config.api_key else {}

res = _get_requests_session().get(
url,
headers={"User-Agent": "pyalex/" + __version__, "email": config.email},
params=params,
)
res = _get_requests_session().get(url, auth=OpenAlexAuth(config))

# handle query errors
if res.status_code == 403:
Expand Down Expand Up @@ -334,11 +356,9 @@ def __getitem__(self, key):

def ngrams(self, return_meta=False):
openalex_id = self["id"].split("/")[-1]
n_gram_url = f"{config.openalex_url}/works/{openalex_id}/ngrams"

res = _get_requests_session().get(
f"{config.openalex_url}/works/{openalex_id}/ngrams",
headers={"User-Agent": "pyalex/" + __version__, "email": config.email},
)
res = _get_requests_session().get(n_gram_url, auth=OpenAlexAuth(config))
res.raise_for_status()
results = res.json()

Expand Down
13 changes: 13 additions & 0 deletions tests/test_pyalex.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,16 @@ def test_sample_seed():
def test_subset():
url = "https://api.openalex.org/works?select=id,doi,display_name"
assert url == Works().select(["id", "doi", "display_name"]).url


def test_auth():
w_no_auth = Works().get()
pyalex.config.email = "[email protected]"
pyalex.config.api_key = "my_api_key"

w_auth = Works().get()

pyalex.config.email = None
pyalex.config.api_key = None

assert len(w_no_auth) == len(w_auth)

0 comments on commit 1305fcc

Please sign in to comment.