Skip to content

Commit

Permalink
avoid infinite recursion in cyclic structures
Browse files Browse the repository at this point in the history
  • Loading branch information
japdubengsub committed Dec 30, 2024
1 parent 2bbddd5 commit 6510ae9
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 6 deletions.
16 changes: 12 additions & 4 deletions sdks/python/src/opik/jsonable_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ def jsonable_encoder(obj: Any, seen: Optional[Set[int]] = None) -> Any:
"""
This is a modified version of the serializer generated by Fern in rest_api.core.jsonable_encoder.
The code is simplified to serialize complex objects into a textual representation.
It also handles cyclic references to avoid infinite recursion.
"""
if seen is None:
seen = set()

obj_id = id(obj)
if obj_id in seen:
return None

if hasattr(obj, "__dict__"):
obj_id = id(obj)
if obj_id in seen:
LOGGER.debug(f"Found cyclic reference to {type(obj).__name__} id={obj_id}")
return f"<Cyclic reference to {type(obj).__name__} id={obj_id}>"
seen.add(obj_id)

try:
Expand Down Expand Up @@ -79,6 +80,13 @@ def jsonable_encoder(obj: Any, seen: Optional[Set[int]] = None) -> Any:
except Exception:
LOGGER.debug("Failed to serialize object.", exc_info=True)

finally:
# Once done encoding this object, remove from `seen`,
# so the same object can appear again at a sibling branch.
if hasattr(obj, "__dict__"):
obj_id = id(obj)
seen.remove(obj_id)

data = str(obj)

return data
73 changes: 71 additions & 2 deletions sdks/python/tests/unit/message_processing/test_jsonable_encoder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,83 @@
from typing import Any
import dataclasses
from datetime import date, datetime, timezone
from threading import Lock
from typing import Any, Optional

import numpy as np
import pytest
import dataclasses

import opik.jsonable_encoder as jsonable_encoder


@dataclasses.dataclass
class Node:
    """Minimal linked node used to build cyclic and repeated object graphs in tests."""

    # Payload carried by the node.
    value: int
    # Next node in the chain; pointing it back at an ancestor creates a cycle.
    child: Optional["Node"] = dataclasses.field(default=None)


def test_jsonable_encoder__cyclic_reference():
    """
    Test that the encoder detects cyclic references and does not infinitely recurse.

    A two-node cycle (A -> B -> A) is encoded starting from A. Both nodes must be
    expanded once, and the second visit to A must be replaced by a string marker.
    """
    # Create a simple two-node cycle: A -> B -> A
    node_a = Node(value=1)
    node_b = Node(value=2)
    node_a.child = node_b
    node_b.child = node_a

    encoded = jsonable_encoder.jsonable_encoder(node_a)

    # node_a is expanded into a dict of its fields.
    assert isinstance(encoded, dict)
    assert encoded.get("value") == 1
    assert "child" in encoded

    # node_a.child (which is node_b) is expanded as well, since it is seen for the first time.
    encoded_b = encoded["child"]
    assert isinstance(encoded_b, dict)
    assert encoded_b.get("value") == 2
    assert "child" in encoded_b

    # node_b.child points back to node_a, so it must be the cycle marker.
    # The exact marker format may vary; only its type and prefix are checked.
    cycle_marker = encoded_b["child"]
    assert isinstance(
        cycle_marker, str
    ), "Expected a string marker for cyclic reference"
    assert (
        "<Cyclic reference to " in cycle_marker
    ), "Should contain 'Cyclic reference' text"


def test_jsonable_encoder__repeated_objects_in_list():
    """
    Verify that one object referenced several times from the same list is encoded
    in full each time: repeated sibling references are not a cycle, so no cycle
    marker may appear in the output.
    """
    shared = Node(value=42)

    # Three references to the very same instance, side by side.
    payload = [shared, shared, shared]

    result = jsonable_encoder.jsonable_encoder(payload)

    # The list shape is preserved: one encoded entry per reference.
    assert isinstance(result, list)
    assert len(result) == 3

    # Every entry is a full dict with `value` = 42 and `child` = None.
    for entry in result:
        assert isinstance(entry, dict)
        assert entry.get("value") == 42
        assert entry.get("child") is None

    # No entry may carry the cycle marker text, because nothing here is cyclic.
    assert all("Cyclic reference" not in str(entry) for entry in result)


@pytest.mark.parametrize(
"obj",
[
Expand Down

0 comments on commit 6510ae9

Please sign in to comment.