diff --git a/sdks/python/src/opik/_logging.py b/sdks/python/src/opik/_logging.py index 97870b272d..1d64c0d3a3 100644 --- a/sdks/python/src/opik/_logging.py +++ b/sdks/python/src/opik/_logging.py @@ -1,4 +1,4 @@ -from typing import Callable, Any +from typing import Callable, Any, Optional import functools import logging @@ -6,9 +6,31 @@ from . import config CONSOLE_MSG_FORMAT = "OPIK: %(message)s" - FILE_MSG_FORMAT = "%(asctime)s OPIK %(levelname)s: %(message)s" +# 1MB, to prevent the logger from frequently writing hundreds of megabytes in DEBUG mode +# when batches are big and payloads are heavy (e.g. base64 encoded data) +MAX_MESSAGE_LENGTH = 1024 * 1024 + + +class TruncateFormatter(logging.Formatter): + def __init__( + self, + fmt: str, + datefmt: Optional[str] = None, + max_length: int = MAX_MESSAGE_LENGTH, + ) -> None: + super().__init__(fmt, datefmt) + self.max_length = max_length + + def format(self, record: logging.LogRecord) -> str: + result = super().format(record) + + if len(result) > self.max_length: + result = result[: self.max_length] + "... (truncated)." 
+ + return result + def setup() -> None: opik_root_logger = logging.getLogger("opik") @@ -19,8 +41,7 @@ def setup() -> None: console_handler = logging.StreamHandler() console_level = config_.console_logging_level console_handler.setLevel(console_level) - console_handler.setFormatter(logging.Formatter(CONSOLE_MSG_FORMAT)) - + console_handler.setFormatter(TruncateFormatter(CONSOLE_MSG_FORMAT)) opik_root_logger.addHandler(console_handler) root_level = console_handler.level @@ -29,7 +50,7 @@ def setup() -> None: file_handler = logging.FileHandler(config_.logging_file) file_level = config_.file_logging_level file_handler.setLevel(file_level) - file_handler.setFormatter(logging.Formatter(FILE_MSG_FORMAT)) + file_handler.setFormatter(TruncateFormatter(FILE_MSG_FORMAT)) opik_root_logger.addHandler(file_handler) root_level = min(root_level, file_handler.level) diff --git a/sdks/python/src/opik/api_objects/dataset/dataset.py b/sdks/python/src/opik/api_objects/dataset/dataset.py index fb528f6cd3..68ef044a95 100644 --- a/sdks/python/src/opik/api_objects/dataset/dataset.py +++ b/sdks/python/src/opik/api_objects/dataset/dataset.py @@ -3,10 +3,11 @@ from typing import Optional, Any, List, Dict, Sequence, Set from opik.rest_api import client as rest_api_client -from opik.rest_api.types import dataset_item as rest_dataset_item -from opik import exceptions +from opik.rest_api.types import dataset_item_write as rest_dataset_item +from opik.message_processing.batching import sequence_splitter +from opik import exceptions, config -from .. import helpers, constants +from .. import constants from . 
import dataset_item, converters import pandas @@ -60,7 +61,7 @@ def __internal_api__insert_items_as_dataclasses__( self._id_to_hash[item.id] = item_hash rest_items = [ - rest_dataset_item.DatasetItem( + rest_dataset_item.DatasetItemWrite( id=item.id, # type: ignore trace_id=item.trace_id, # type: ignore span_id=item.span_id, # type: ignore @@ -70,12 +71,16 @@ def __internal_api__insert_items_as_dataclasses__( for item in deduplicated_items ] - batches = helpers.list_to_batches( - rest_items, batch_size=constants.DATASET_ITEMS_MAX_BATCH_SIZE + batches = sequence_splitter.split_into_batches( + rest_items, + max_payload_size_MB=config.MAX_BATCH_SIZE_MB, + max_length=constants.DATASET_ITEMS_MAX_BATCH_SIZE, ) for batch in batches: - LOGGER.debug("Sending dataset items batch: %s", batch) + LOGGER.debug( + "Sending dataset items batch of size %d: %s", len(batch), batch + ) self._rest_client.datasets.create_or_update_dataset_items( dataset_name=self._name, items=batch ) @@ -134,8 +139,8 @@ def delete(self, items_ids: List[str]) -> None: Args: items_ids: List of item ids to delete. """ - batches = helpers.list_to_batches( - items_ids, batch_size=constants.DATASET_ITEMS_MAX_BATCH_SIZE + batches = sequence_splitter.split_into_batches( + items_ids, max_length=constants.DATASET_ITEMS_MAX_BATCH_SIZE ) for batch in batches: diff --git a/sdks/python/src/opik/api_objects/experiment/experiment.py b/sdks/python/src/opik/api_objects/experiment/experiment.py index e3bcf42525..b308363bb6 100644 --- a/sdks/python/src/opik/api_objects/experiment/experiment.py +++ b/sdks/python/src/opik/api_objects/experiment/experiment.py @@ -3,6 +3,8 @@ from opik.rest_api import client as rest_api_client from opik.rest_api.types import experiment_item as rest_experiment_item +from opik.message_processing.batching import sequence_splitter + from . import experiment_item from .. import helpers, constants from ... 
import Prompt @@ -36,8 +38,8 @@ def insert(self, experiment_items: List[experiment_item.ExperimentItem]) -> None for item in experiment_items ] - batches = helpers.list_to_batches( - rest_experiment_items, batch_size=constants.EXPERIMENT_ITEMS_MAX_BATCH_SIZE + batches = sequence_splitter.split_into_batches( + rest_experiment_items, max_length=constants.EXPERIMENT_ITEMS_MAX_BATCH_SIZE ) for batch in batches: diff --git a/sdks/python/src/opik/api_objects/helpers.py b/sdks/python/src/opik/api_objects/helpers.py index 11160eb63a..e6e20708af 100644 --- a/sdks/python/src/opik/api_objects/helpers.py +++ b/sdks/python/src/opik/api_objects/helpers.py @@ -1,6 +1,6 @@ import datetime import logging -from typing import Any, List, Optional +from typing import Optional import uuid_extensions @@ -22,10 +22,6 @@ def datetime_to_iso8601_if_not_None( return datetime_helpers.datetime_to_iso8601(value) -def list_to_batches(items: List[Any], batch_size: int) -> List[List[Any]]: - return [items[i : i + batch_size] for i in range(0, len(items), batch_size)] - - def resolve_child_span_project_name( parent_project_name: Optional[str], child_project_name: Optional[str], diff --git a/sdks/python/src/opik/api_objects/opik_client.py b/sdks/python/src/opik/api_objects/opik_client.py index 7f8d16ed2e..ddf5a33efb 100644 --- a/sdks/python/src/opik/api_objects/opik_client.py +++ b/sdks/python/src/opik/api_objects/opik_client.py @@ -20,6 +20,8 @@ validation_helpers, ) from ..message_processing import streamer_constructors, messages +from ..message_processing.batching import sequence_splitter + from ..rest_api import client as rest_api_client from ..rest_api.types import dataset_public, trace_public, span_public, project_public from ..rest_api.core.api_error import ApiError @@ -309,8 +311,10 @@ def log_spans_feedback_scores( for score_dict in valid_scores ] - for batch in helpers.list_to_batches( - score_messages, batch_size=constants.FEEDBACK_SCORES_MAX_BATCH_SIZE + for batch in 
sequence_splitter.split_into_batches( + score_messages, + max_payload_size_MB=config.MAX_BATCH_SIZE_MB, + max_length=constants.FEEDBACK_SCORES_MAX_BATCH_SIZE, ): add_span_feedback_scores_batch_message = ( messages.AddSpanFeedbackScoresBatchMessage(batch=batch) @@ -350,8 +354,10 @@ def log_traces_feedback_scores( ) for score_dict in valid_scores ] - for batch in helpers.list_to_batches( - score_messages, batch_size=constants.FEEDBACK_SCORES_MAX_BATCH_SIZE + for batch in sequence_splitter.split_into_batches( + score_messages, + max_payload_size_MB=config.MAX_BATCH_SIZE_MB, + max_length=constants.FEEDBACK_SCORES_MAX_BATCH_SIZE, ): add_span_feedback_scores_batch_message = ( messages.AddTraceFeedbackScoresBatchMessage(batch=batch) diff --git a/sdks/python/src/opik/config.py b/sdks/python/src/opik/config.py index 654133ea6e..8521e8cdaa 100644 --- a/sdks/python/src/opik/config.py +++ b/sdks/python/src/opik/config.py @@ -22,6 +22,8 @@ _SESSION_CACHE_DICT: Dict[str, Any] = {} +MAX_BATCH_SIZE_MB = 50 + OPIK_URL_CLOUD: Final[str] = "https://www.comet.com/opik/api" OPIK_URL_LOCAL: Final[str] = "http://localhost:5173/api" diff --git a/sdks/python/src/opik/message_processing/batching/sequence_splitter.py b/sdks/python/src/opik/message_processing/batching/sequence_splitter.py new file mode 100644 index 0000000000..dd93b31a27 --- /dev/null +++ b/sdks/python/src/opik/message_processing/batching/sequence_splitter.py @@ -0,0 +1,58 @@ +import json +from typing import List, Optional, TypeVar, Sequence +from opik import jsonable_encoder + +T = TypeVar("T") + + +def _get_expected_payload_size_MB(item: T) -> float: + encoded_for_json = jsonable_encoder.jsonable_encoder(item) + json_str = json.dumps(encoded_for_json) + return len(json_str.encode("utf-8")) / (1024 * 1024) + + +def split_into_batches( + items: Sequence[T], + max_payload_size_MB: Optional[float] = None, + max_length: Optional[int] = None, +) -> List[List[T]]: + assert (max_payload_size_MB is not None) or ( + max_length is not 
None + ), "At least one limitation must be set for splitting" + + if max_length is None: + max_length = len(items) + + if max_payload_size_MB is None: + max_payload_size_MB = float("inf") + + batches: List[List[T]] = [] + current_batch: List[T] = [] + current_batch_size_MB: float = 0.0 + + for item in items: + item_size_MB = ( + 0.0 if max_payload_size_MB == float("inf") else _get_expected_payload_size_MB(item) + ) + + if item_size_MB >= max_payload_size_MB: + batches.append([item]) + continue + + batch_is_already_full = len(current_batch) == max_length + batch_will_exceed_memory_limit_after_adding = ( + current_batch_size_MB + item_size_MB > max_payload_size_MB + ) + + if batch_is_already_full or batch_will_exceed_memory_limit_after_adding: + batches.append(current_batch) + current_batch = [item] + current_batch_size_MB = item_size_MB + else: + current_batch.append(item) + current_batch_size_MB += item_size_MB + + if len(current_batch) > 0: + batches.append(current_batch) + + return batches diff --git a/sdks/python/src/opik/message_processing/message_processors.py b/sdks/python/src/opik/message_processing/message_processors.py index 9b3afaf7b1..b7b6cfb9c1 100644 --- a/sdks/python/src/opik/message_processing/message_processors.py +++ b/sdks/python/src/opik/message_processing/message_processors.py @@ -9,9 +9,12 @@ from ..rest_api import client as rest_api_client from ..rest_api.types import feedback_score_batch_item from ..rest_api.types import span_write +from .batching import sequence_splitter LOGGER = logging.getLogger(__name__) +BATCH_MEMORY_LIMIT_MB = 50 + class BaseMessageProcessor(abc.ABC): @abc.abstractmethod @@ -169,7 +172,8 @@ def _process_add_trace_feedback_scores_batch_message( def _process_create_span_batch_message( self, message: messages.CreateSpansBatchMessage ) -> None: - span_write_batch: List[span_write.SpanWrite] = [] + rest_spans: List[span_write.SpanWrite] = [] + for item in message.batch: span_write_kwargs = { "id": item.span_id, @@ -190,7 +194,13 @@ 
def _process_create_span_batch_message( span_write_kwargs ) cleaned_span_write_kwargs = jsonable_encoder(cleaned_span_write_kwargs) - span_write_batch.append(span_write.SpanWrite(**cleaned_span_write_kwargs)) + rest_spans.append(span_write.SpanWrite(**cleaned_span_write_kwargs)) + + memory_limited_batches = sequence_splitter.split_into_batches( + items=rest_spans, + max_payload_size_MB=BATCH_MEMORY_LIMIT_MB, + ) - LOGGER.debug("Create spans batch request: %s", span_write_batch) - self._rest_client.spans.create_spans(spans=span_write_batch) + for batch in memory_limited_batches: + LOGGER.debug("Create spans batch request of size %d: %s", len(batch), batch) + self._rest_client.spans.create_spans(spans=batch) diff --git a/sdks/python/tests/unit/message_processing/batching/test_sequence_splitter.py b/sdks/python/tests/unit/message_processing/batching/test_sequence_splitter.py new file mode 100644 index 0000000000..af837aba6a --- /dev/null +++ b/sdks/python/tests/unit/message_processing/batching/test_sequence_splitter.py @@ -0,0 +1,77 @@ +import dataclasses +from opik.message_processing.batching import sequence_splitter + + +@dataclasses.dataclass +class LongStr: + value: str + + def __str__(self) -> str: + return self.value[0] + ".." 
+ self.value[-1] + + def __repr__(self) -> str: + return str(self) + + +ONE_MEGABYTE_OBJECT_A = LongStr("a" * 1024 * 1024) +ONE_MEGABYTE_OBJECT_B = LongStr("b" * 1024 * 1024) +ONE_MEGABYTE_OBJECT_C = LongStr("c" * 1024 * 1024) + + +def test_split_list_into_batches__by_size_only(): + items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + batches = sequence_splitter.split_into_batches(items, max_length=4) + + assert batches == [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]] + + +def test_split_list_into_batches__by_memory_only(): + items = [ONE_MEGABYTE_OBJECT_A] * 2 + [ONE_MEGABYTE_OBJECT_B] * 2 + batches = sequence_splitter.split_into_batches(items, max_payload_size_MB=3.5) + + assert batches == [ + [ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_B], + [ONE_MEGABYTE_OBJECT_B], + ] + + +def test_split_list_into_batches__by_memory_and_by_size(): + FOUR_MEGABYTE_OBJECT_C = [ONE_MEGABYTE_OBJECT_C] * 4 + items = ( + [ONE_MEGABYTE_OBJECT_A] * 2 + + [FOUR_MEGABYTE_OBJECT_C] + + [ONE_MEGABYTE_OBJECT_B] * 2 + ) + batches = sequence_splitter.split_into_batches( + items, max_length=2, max_payload_size_MB=3.5 + ) + + # Object C comes before object A because if item is bigger than the max payload size + # it is immediately added to the result batches list before batch which is currently accumulating + assert batches == [ + [FOUR_MEGABYTE_OBJECT_C], + [ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_A], + [ONE_MEGABYTE_OBJECT_B, ONE_MEGABYTE_OBJECT_B], + ] + + +def test_split_list_into_batches__empty_list(): + items = [] + batches = sequence_splitter.split_into_batches( + items, max_length=3, max_payload_size_MB=3.5 + ) + + assert batches == [] + + +def test_split_list_into_batches__multiple_large_objects(): + items = [ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_B, ONE_MEGABYTE_OBJECT_C] + batches = sequence_splitter.split_into_batches( + items, max_length=2, max_payload_size_MB=0.5 + ) + + assert batches == [ + [ONE_MEGABYTE_OBJECT_A], + [ONE_MEGABYTE_OBJECT_B], + 
[ONE_MEGABYTE_OBJECT_C], + ]