Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
fix client hashing in nested client params case (#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Jan 24, 2024
1 parent 897ab43 commit 04b9000
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 2 deletions.
4 changes: 3 additions & 1 deletion prefect_aws/client_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from botocore.client import Config
from pydantic import VERSION as PYDANTIC_VERSION

from prefect_aws.utilities import hash_collection

if PYDANTIC_VERSION.startswith("2."):
from pydantic.v1 import BaseModel, Field, FilePath, root_validator, validator
else:
Expand Down Expand Up @@ -78,7 +80,7 @@ def __hash__(self):
self.verify,
self.verify_cert_path,
self.endpoint_url,
self.config,
hash_collection(self.config),
)
)

Expand Down
2 changes: 1 addition & 1 deletion prefect_aws/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __hash__(self):
hash(self.aws_session_token),
hash(self.profile_name),
hash(self.region_name),
hash(frozenset(self.aws_client_parameters.dict().items())),
hash(self.aws_client_parameters),
)
return hash(field_hashes)

Expand Down
35 changes: 35 additions & 0 deletions prefect_aws/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Utilities for working with AWS services."""

from prefect.utilities.collections import visit_collection


def hash_collection(collection) -> int:
"""Use visit_collection to transform and hash a collection.
Args:
collection (Any): The collection to hash.
Returns:
int: The hash of the transformed collection.
Example:
```python
from prefect_aws.utilities import hash_collection
hash_collection({"a": 1, "b": 2})
```
"""

def make_hashable(item):
"""Make an item hashable by converting it to a tuple."""
if isinstance(item, dict):
return tuple(sorted((k, make_hashable(v)) for k, v in item.items()))
elif isinstance(item, list):
return tuple(make_hashable(v) for v in item)
return item

hashable_collection = visit_collection(
collection, visit_fn=make_hashable, return_data=True
)
return hash(hashable_collection)
25 changes: 25 additions & 0 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,28 @@ def test_aws_credentials_hash_changes(credentials_type, initial_field, new_field
new_hash = hash(credentials)

assert initial_hash != new_hash, "Hash should change when region_name changes"


def test_aws_credentials_nested_client_parameters_are_hashable():
"""
Test to ensure that nested client parameters are hashable.
"""

creds = AwsCredentials(
region_name="us-east-1",
aws_client_parameters=dict(
config=dict(
connect_timeout=5,
read_timeout=5,
retries=dict(max_attempts=10, mode="standard"),
)
),
)

assert hash(creds) is not None

client = creds.get_client("s3")

_client = creds.get_client("s3")

assert client is _client
34 changes: 34 additions & 0 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest

from prefect_aws.utilities import hash_collection


class TestHashCollection:
def test_simple_dict(self):
simple_dict = {"key1": "value1", "key2": "value2"}
assert hash_collection(simple_dict) == hash_collection(
simple_dict
), "Simple dictionary hashing failed"

def test_nested_dict(self):
nested_dict = {"key1": {"subkey1": "subvalue1"}, "key2": "value2"}
assert hash_collection(nested_dict) == hash_collection(
nested_dict
), "Nested dictionary hashing failed"

def test_complex_structure(self):
complex_structure = {
"key1": [1, 2, 3],
"key2": {"subkey1": {"subsubkey1": "value"}},
}
assert hash_collection(complex_structure) == hash_collection(
complex_structure
), "Complex structure hashing failed"

def test_unhashable_structure(self):
typically_unhashable_structure = dict(key=dict(subkey=[1, 2, 3]))
with pytest.raises(TypeError):
hash(typically_unhashable_structure)
assert hash_collection(typically_unhashable_structure) == hash_collection(
typically_unhashable_structure
), "Unhashable structure hashing failed after transformation"

0 comments on commit 04b9000

Please sign in to comment.