Skip to content

Commit

Permalink
Moved CachedRecordRegistry
Browse files Browse the repository at this point in the history
  • Loading branch information
srfwx committed Dec 10, 2024
1 parent 2107c12 commit eaa75ad
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 54 deletions.
13 changes: 6 additions & 7 deletions pynetbox/core/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
limitations under the License.
"""

from collections import OrderedDict

from pynetbox.core.query import Request, RequestError
from pynetbox.core.response import Record, RecordSet

Expand All @@ -36,10 +34,13 @@ def get(self, object_type, key):
"""
Retrieves a record from the cache
"""
object_cache = self._cache.get(object_type)
if object_cache is None:
if not (object_cache := self._cache.get(object_type)):
return None
return object_cache.get(key, None)
if object := object_cache.get(key, None):
self._hit += 1
return object
self._miss += 1
return None

def set(self, object_type, key, value):
"""
Expand Down Expand Up @@ -83,8 +84,6 @@ def __init__(self, api, app, name, model=None):
endpoint=self.name,
)
self._choices = None
self._attribute_type_map = {}
self._attribute_endpoint_map = {}
self._init_cache()

def _init_cache(self):
Expand Down
55 changes: 10 additions & 45 deletions pynetbox/core/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,32 +47,6 @@ class JsonField:
_json_field = True


class CachedRecordRegistry:
"""
A cache for Record objects.
"""

def __init__(self, api):
self.api = api
self._cache = {}
self._hit = 0
self._miss = 0

def get(self, object_type, key):
"""
Retrieves a record from the cache
"""
return self._cache.get(object_type, {}).get(key, None)

def set(self, object_type, key, value):
"""
Stores a record in the cache
"""
if object_type not in self._cache:
self._cache[object_type] = {}
self._cache[object_type][key] = value


class RecordSet:
"""Iterator containing Record objects.
Expand Down Expand Up @@ -107,7 +81,6 @@ def __init__(self, endpoint, request, **kwargs):
self.request = request
self.response = self.request.get()
self._response_cache = []
self._init_endpoint_cache()

def __iter__(self):
return self
Expand All @@ -126,15 +99,9 @@ def __next__(self):
self.endpoint,
)
except StopIteration:
self._clear_endpoint_cache()
self.endpoint._init_cache()
raise

def _init_endpoint_cache(self):
self.endpoint._cache = CachedRecordRegistry(self.endpoint.api)

def _clear_endpoint_cache(self):
self.endpoint._cache = None

def __len__(self):
try:
return self.request.count
Expand Down Expand Up @@ -413,13 +380,11 @@ def _get_or_init(self, object_type, key, value, model):
Returns a record from the endpoint cache if it exists, otherwise
initializes a new record, store it in the cache, and return it.
"""
if key and self._endpoint and self._endpoint._cache:
if self._endpoint:
if cached := self._endpoint._cache.get(object_type, key):
self._endpoint._cache._hit += 1
return cached
record = model(value, self.api, None)
if key and self._endpoint and self._endpoint._cache:
self._endpoint._cache._miss += 1
if self._endpoint:
self._endpoint._cache.set(object_type, key, record)
return record

Expand Down Expand Up @@ -456,15 +421,15 @@ def dict_parser(key_name, value, model=None):

return value, deep_copy(value)

def mixed_list_parser(value):
def generic_list_parser(value):
from pynetbox.models.mapper import CONTENT_TYPE_MAPPER

parsed_list = []
for item in value:
lookup = item["object_type"]
if model := CONTENT_TYPE_MAPPER.get(lookup, None):
object_type = item["object_type"]
if model := CONTENT_TYPE_MAPPER.get(object_type, None):
item = self._get_or_init(
lookup, item["object"]["url"], item["object"], model
object_type, item["object"]["url"], item["object"], model
)
parsed_list.append(GenericListObject(item))
return parsed_list
Expand All @@ -480,9 +445,9 @@ def list_parser(key_name, value):
if not isinstance(sample_item, dict):
return value, [*value]

is_mixed_list = "object_type" in sample_item and "object" in sample_item
if is_mixed_list:
value = mixed_list_parser(value)
is_generic_list = "object_type" in sample_item and "object" in sample_item
if is_generic_list:
value = generic_list_parser(value)
else:
lookup = getattr(self.__class__, key_name, None)
model = lookup[0] if isinstance(lookup, list) else self.default_ret
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def test_nested_write(self):
endpoint = Mock()
endpoint.name = "test-endpoint"
endpoint.url = "http://localhost:8080/api/test-app/test-endpoint/"
endpoint._cache = None
endpoint._cache = Mock()
endpoint._cache.get = Mock(return_value=None)
api.test_app = Mock()
api.test_app.test_endpoint = endpoint
test = Record(
Expand All @@ -248,6 +249,8 @@ def test_nested_write(self):
)
test.child.name = "test321"
test.child.save()
print(api)
print(api.http_session)
self.assertEqual(
api.http_session.patch.call_args[0][0],
"http://localhost:8080/api/test-app/test-endpoint/321/",
Expand All @@ -260,7 +263,8 @@ def test_nested_write_with_directory_in_base_url(self):
endpoint = Mock()
endpoint.name = "test-endpoint"
endpoint.url = "http://localhost:8080/testing/api/test-app/test-endpoint/"
endpoint._cache = None
endpoint._cache = Mock()
endpoint._cache.get = Mock(return_value=None)
api.test_app = Mock()
api.test_app.test_endpoint = endpoint
test = Record(
Expand Down

0 comments on commit eaa75ad

Please sign in to comment.