-
Notifications
You must be signed in to change notification settings - Fork 238
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add sagemaker auth integration * Update the implementation, no env vars are used except for AWS_OPIK_AUTH now * Remove error catching, rename functions * Remove config import * Update env variable to check * Fix url update * Convert httpx url to string before passing get_signed_request * Update hooks mechanism to support many hooks. Move all sagemaker hook logic into the hook so that it is executed when httpx client is created and not on import when env vars can still be not set * Fix lint errors, rename registered hooks list * Rename hooks list --------- Co-authored-by: Nimrod Lahav <[email protected]>
- Loading branch information
1 parent
217cfee
commit 0c3956f
Showing
5 changed files
with
60 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,13 @@ | ||
import httpx | ||
from typing import Callable, List | ||
|
||
_registered_httpx_client_hooks: List[Callable[[httpx.Client], None]] = [] | ||
|
||
def httpx_client_hook(client: httpx.Client) -> None: | ||
pass | ||
|
||
def register_httpx_client_hook(hook: Callable[[httpx.Client], httpx.Client]) -> None: | ||
_registered_httpx_client_hooks.append(hook) | ||
|
||
|
||
def run_httpx_client_hooks(client: httpx.Client) -> None: | ||
for hook in _registered_httpx_client_hooks: | ||
hook(client) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import logging | ||
import os | ||
import httpx | ||
|
||
from typing import Any | ||
|
||
import opik.hooks | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
def _in_aws_sagemaker() -> bool: | ||
return os.getenv("AWS_PARTNER_APP_AUTH") is not None | ||
|
||
|
||
class SagemakerAuth(httpx.Auth): | ||
def __init__(self, auth_provider: Any) -> None: | ||
self.auth_provider = auth_provider | ||
|
||
def auth_flow(self, request): # type: ignore | ||
if not _in_aws_sagemaker(): | ||
yield request | ||
|
||
url, signed_headers = self.auth_provider.get_signed_request( | ||
str(request.url), request.method, request.headers, request.content | ||
) | ||
|
||
request.url = httpx.URL(url) | ||
request.headers.update(signed_headers) | ||
|
||
yield request | ||
|
||
|
||
def setup_aws_sagemaker_session_hook() -> None: | ||
def sagemaker_auth_client_hook(client: httpx.Client) -> None: | ||
if not _in_aws_sagemaker(): | ||
return | ||
|
||
import sagemaker | ||
|
||
auth_provider = sagemaker.PartnerAppAuthProvider() | ||
sagemaker_auth = SagemakerAuth(auth_provider) | ||
|
||
client.auth = sagemaker_auth | ||
|
||
opik.hooks.register_httpx_client_hook(sagemaker_auth_client_hook) |