diff --git a/sdks/python/src/opik/jsonable_encoder.py b/sdks/python/src/opik/jsonable_encoder.py index 8fb76d77ac..ffecd1ee29 100644 --- a/sdks/python/src/opik/jsonable_encoder.py +++ b/sdks/python/src/opik/jsonable_encoder.py @@ -1,12 +1,11 @@ -import logging import dataclasses import datetime as dt - -from typing import Callable, Any, Type, Set, Tuple - +import logging from enum import Enum from pathlib import PurePath from types import GeneratorType +from typing import Any, Callable, Optional, Set, Tuple, Type + import pydantic import opik.rest_api.core.datetime_utils as datetime_utils @@ -25,18 +24,29 @@ def register_encoder_extension(obj_type: Type, encoder: Callable[[Any], Any]) -> _ENCODER_EXTENSIONS.add((obj_type, encoder)) -def jsonable_encoder(obj: Any) -> Any: +def jsonable_encoder(obj: Any, seen: Optional[Set[int]] = None) -> Any: """ This is a modified version of the serializer generated by Fern in rest_api.core.jsonable_encoder. The code is simplified to serialize complex objects into a textual representation. + It also handles cyclic references to avoid infinite recursion. """ + if seen is None: + seen = set() + + if hasattr(obj, "__dict__"): + obj_id = id(obj) + if obj_id in seen: + LOGGER.debug(f"Found cyclic reference to {type(obj).__name__} id={obj_id}") + return f"" + seen.add(obj_id) + try: if dataclasses.is_dataclass(obj) or isinstance(obj, pydantic.BaseModel): obj_dict = obj.__dict__ - return jsonable_encoder(obj_dict) + return jsonable_encoder(obj_dict, seen) if isinstance(obj, Enum): - return jsonable_encoder(obj.value) + return jsonable_encoder(obj.value, seen) if isinstance(obj, PurePath): return str(obj) if isinstance(obj, (str, int, float, type(None))): @@ -50,26 +60,33 @@ def jsonable_encoder(obj: Any) -> Any: allowed_keys = set(obj.keys()) for key, value in obj.items(): if key in allowed_keys: - encoded_key = jsonable_encoder(key) - encoded_value = jsonable_encoder(value) + encoded_key = jsonable_encoder(key, seen) + encoded_value = jsonable_encoder(value, seen) encoded_dict[encoded_key] = encoded_value return encoded_dict if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): encoded_list = [] for item in obj: - encoded_list.append(jsonable_encoder(item)) + encoded_list.append(jsonable_encoder(item, seen)) return encoded_list for type_, encoder in _ENCODER_EXTENSIONS: if isinstance(obj, type_): - return jsonable_encoder(encoder(obj)) + return jsonable_encoder(encoder(obj), seen) if np is not None and isinstance(obj, np.ndarray): - return jsonable_encoder(obj.tolist()) + return jsonable_encoder(obj.tolist(), seen) except Exception: LOGGER.debug("Failed to serialize object.", exc_info=True) + finally: + # Once done encoding this object, remove from `seen`, + # so the same object can appear again at a sibling branch. + if hasattr(obj, "__dict__"): + obj_id = id(obj) + seen.remove(obj_id) + data = str(obj) return data diff --git a/sdks/python/tests/unit/message_processing/test_jsonable_encoder.py b/sdks/python/tests/unit/message_processing/test_jsonable_encoder.py index 73e0879b59..f436e4b976 100644 --- a/sdks/python/tests/unit/message_processing/test_jsonable_encoder.py +++ b/sdks/python/tests/unit/message_processing/test_jsonable_encoder.py @@ -1,14 +1,83 @@ -from typing import Any +import dataclasses from datetime import date, datetime, timezone from threading import Lock +from typing import Any, Optional import numpy as np import pytest -import dataclasses import opik.jsonable_encoder as jsonable_encoder +@dataclasses.dataclass +class Node: + value: int + child: Optional["Node"] = None + + +def test_jsonable_encoder__cyclic_reference(): + """ + Test that the encoder detects cyclic references and does not infinitely recurse. + """ + # Create a simple two-node cycle: A -> B -> A + node_a = Node(value=1) + node_b = Node(value=2) + node_a.child = node_b + node_b.child = node_a + + encoded = jsonable_encoder.jsonable_encoder(node_a) + # The exact format of the cycle marker can vary; we check that: + # 1. We get some structure for node_a (like a dict). + # 2. Inside node_a, there's a reference to node_b (a dict). + # 3. Inside node_b, there's a "cyclic reference" marker instead of a full node_a object. + print("=" * 150) + print(encoded) + assert isinstance(encoded, dict) + assert "value" in encoded + assert "child" in encoded + + # node_a.child (which is node_b) should be a dict + assert isinstance(encoded["child"], dict) + assert "value" in encoded["child"] + assert "child" in encoded["child"] + + # node_b.child should be the cycle marker + cycle_marker = encoded["child"]["child"] + print("=" * 150) + print(cycle_marker) + assert isinstance( + cycle_marker, str + ), "Expected a string marker for cyclic reference" + assert ( + "