[OPIK-410] improve batch mechanism in SDK (#648)
* Initial batch_splitter implementation

* Refactor batch splitter

* Add unit tests for batch splitter, refactor the implementation

* Fix lint errors

* Replace usage of list_to_batches with the new batch splitter method

* Update module and function names

* Update type hint

* Add splitting by size to dataset items and feedbacks

* Add truncation to the logger formatter to prevent the logger from writing very heavy payloads in debug mode
alexkuzmik authored Nov 19, 2024
1 parent c59c82c commit 5dfcea4
Showing 9 changed files with 206 additions and 29 deletions.
31 changes: 26 additions & 5 deletions sdks/python/src/opik/_logging.py
@@ -1,14 +1,36 @@
from typing import Callable, Any
from typing import Callable, Any, Optional
import functools
import logging


from . import config

CONSOLE_MSG_FORMAT = "OPIK: %(message)s"

FILE_MSG_FORMAT = "%(asctime)s OPIK %(levelname)s: %(message)s"

# 1 MB, to prevent the logger from frequently writing hundreds of megabytes in DEBUG mode
# when batches are big and payloads are heavy (e.g. base64-encoded data)
MAX_MESSAGE_LENGTH = 1024 * 1024


class TruncateFormatter(logging.Formatter):
    def __init__(
        self,
        fmt: str,
        datefmt: Optional[str] = None,
        max_length: int = MAX_MESSAGE_LENGTH,
    ) -> None:
        super().__init__(fmt, datefmt)
        self.max_length = max_length

    def format(self, record: logging.LogRecord) -> str:
        result = super().format(record)

        if len(result) > self.max_length:
            result = result[: self.max_length] + "... (truncated)."

        return result


def setup() -> None:
opik_root_logger = logging.getLogger("opik")
@@ -19,8 +41,7 @@ def setup() -> None:
console_handler = logging.StreamHandler()
console_level = config_.console_logging_level
console_handler.setLevel(console_level)
console_handler.setFormatter(logging.Formatter(CONSOLE_MSG_FORMAT))

console_handler.setFormatter(TruncateFormatter(CONSOLE_MSG_FORMAT))
opik_root_logger.addHandler(console_handler)

root_level = console_handler.level
@@ -29,7 +50,7 @@ def setup() -> None:
file_handler = logging.FileHandler(config_.logging_file)
file_level = config_.file_logging_level
file_handler.setLevel(file_level)
file_handler.setFormatter(logging.Formatter(FILE_MSG_FORMAT))
file_handler.setFormatter(TruncateFormatter(FILE_MSG_FORMAT))
opik_root_logger.addHandler(file_handler)

root_level = min(root_level, file_handler.level)
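For a sense of the new formatter's behavior, here is a minimal sketch (not part of the commit; the logger name and the max_length value are arbitrary):

import logging

from opik._logging import TruncateFormatter

handler = logging.StreamHandler()
handler.setFormatter(TruncateFormatter("OPIK: %(message)s", max_length=32))

demo_logger = logging.getLogger("truncate_demo")
demo_logger.addHandler(handler)

# Emits the first 32 formatted characters followed by "... (truncated).",
# instead of a megabyte-long line.
demo_logger.warning("x" * 1_000_000)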
23 changes: 14 additions & 9 deletions sdks/python/src/opik/api_objects/dataset/dataset.py
@@ -3,10 +3,11 @@
from typing import Optional, Any, List, Dict, Sequence, Set

from opik.rest_api import client as rest_api_client
from opik.rest_api.types import dataset_item as rest_dataset_item
from opik import exceptions
from opik.rest_api.types import dataset_item_write as rest_dataset_item
from opik.message_processing.batching import sequence_splitter
from opik import exceptions, config

from .. import helpers, constants
from .. import constants
from . import dataset_item, converters
import pandas

@@ -60,7 +61,7 @@ def __internal_api__insert_items_as_dataclasses__(
self._id_to_hash[item.id] = item_hash

rest_items = [
rest_dataset_item.DatasetItem(
rest_dataset_item.DatasetItemWrite(
id=item.id, # type: ignore
trace_id=item.trace_id, # type: ignore
span_id=item.span_id, # type: ignore
Expand All @@ -70,12 +71,16 @@ def __internal_api__insert_items_as_dataclasses__(
for item in deduplicated_items
]

batches = helpers.list_to_batches(
rest_items, batch_size=constants.DATASET_ITEMS_MAX_BATCH_SIZE
batches = sequence_splitter.split_into_batches(
rest_items,
max_payload_size_MB=config.MAX_BATCH_SIZE_MB,
max_length=constants.DATASET_ITEMS_MAX_BATCH_SIZE,
)

for batch in batches:
LOGGER.debug("Sending dataset items batch: %s", batch)
LOGGER.debug(
"Sending dataset items batch of size %d: %s", len(batch), batch
)
self._rest_client.datasets.create_or_update_dataset_items(
dataset_name=self._name, items=batch
)
@@ -134,8 +139,8 @@ def delete(self, items_ids: List[str]) -> None:
Args:
items_ids: List of item ids to delete.
"""
batches = helpers.list_to_batches(
items_ids, batch_size=constants.DATASET_ITEMS_MAX_BATCH_SIZE
batches = sequence_splitter.split_into_batches(
items_ids, max_length=constants.DATASET_ITEMS_MAX_BATCH_SIZE
)

for batch in batches:
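A hedged sketch of the effect on dataset inserts (illustrative sizes; the real DATASET_ITEMS_MAX_BATCH_SIZE value is not shown in this diff): the payload cap can now split a batch long before the length cap is reached, while delete() above still splits by length alone, since IDs are tiny.

from opik.message_processing.batching import sequence_splitter

big_items = [{"input": "x" * (10 * 1024 * 1024)} for _ in range(12)]  # ~10 MB each
batches = sequence_splitter.split_into_batches(
    big_items, max_payload_size_MB=50, max_length=1000
)
# The 50 MB payload cap binds first: four ~10 MB items per batch.
assert [len(b) for b in batches] == [4, 4, 4]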
6 changes: 4 additions & 2 deletions sdks/python/src/opik/api_objects/experiment/experiment.py
@@ -3,6 +3,8 @@

from opik.rest_api import client as rest_api_client
from opik.rest_api.types import experiment_item as rest_experiment_item
from opik.message_processing.batching import sequence_splitter

from . import experiment_item
from .. import helpers, constants
from ... import Prompt
@@ -36,8 +38,8 @@ def insert(self, experiment_items: List[experiment_item.ExperimentItem]) -> None
for item in experiment_items
]

batches = helpers.list_to_batches(
rest_experiment_items, batch_size=constants.EXPERIMENT_ITEMS_MAX_BATCH_SIZE
batches = sequence_splitter.split_into_batches(
rest_experiment_items, max_length=constants.EXPERIMENT_ITEMS_MAX_BATCH_SIZE
)

for batch in batches:
6 changes: 1 addition & 5 deletions sdks/python/src/opik/api_objects/helpers.py
@@ -1,6 +1,6 @@
import datetime
import logging
from typing import Any, List, Optional
from typing import Optional

import uuid_extensions

@@ -22,10 +22,6 @@ def datetime_to_iso8601_if_not_None(
return datetime_helpers.datetime_to_iso8601(value)


def list_to_batches(items: List[Any], batch_size: int) -> List[List[Any]]:
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


def resolve_child_span_project_name(
parent_project_name: Optional[str],
child_project_name: Optional[str],
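When only max_length is passed, the new splitter reproduces the removed list_to_batches helper exactly, which is why length-only call sites could swap it in directly. A quick sketch:

from opik.message_processing.batching import sequence_splitter

items = [1, 2, 3, 4, 5]
# Same result the removed helpers.list_to_batches(items, batch_size=2) produced:
assert sequence_splitter.split_into_batches(items, max_length=2) == [[1, 2], [3, 4], [5]]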
14 changes: 10 additions & 4 deletions sdks/python/src/opik/api_objects/opik_client.py
@@ -20,6 +20,8 @@
validation_helpers,
)
from ..message_processing import streamer_constructors, messages
from ..message_processing.batching import sequence_splitter

from ..rest_api import client as rest_api_client
from ..rest_api.types import dataset_public, trace_public, span_public, project_public
from ..rest_api.core.api_error import ApiError
@@ -309,8 +311,10 @@ def log_spans_feedback_scores(
for score_dict in valid_scores
]

for batch in helpers.list_to_batches(
score_messages, batch_size=constants.FEEDBACK_SCORES_MAX_BATCH_SIZE
for batch in sequence_splitter.split_into_batches(
score_messages,
max_payload_size_MB=config.MAX_BATCH_SIZE_MB,
max_length=constants.FEEDBACK_SCORES_MAX_BATCH_SIZE,
):
add_span_feedback_scores_batch_message = (
messages.AddSpanFeedbackScoresBatchMessage(batch=batch)
@@ -350,8 +354,10 @@ def log_traces_feedback_scores(
)
for score_dict in valid_scores
]
for batch in helpers.list_to_batches(
score_messages, batch_size=constants.FEEDBACK_SCORES_MAX_BATCH_SIZE
for batch in sequence_splitter.split_into_batches(
score_messages,
max_payload_size_MB=config.MAX_BATCH_SIZE_MB,
max_length=constants.FEEDBACK_SCORES_MAX_BATCH_SIZE,
):
add_span_feedback_scores_batch_message = (
messages.AddTraceFeedbackScoresBatchMessage(batch=batch)
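Feedback-score messages are small, so in practice the length cap does the splitting here. A sketch with a hypothetical cap of 1000 (the real FEEDBACK_SCORES_MAX_BATCH_SIZE value is not shown in this diff):

from opik.message_processing.batching import sequence_splitter

score_messages = [{"id": str(i), "name": "accuracy", "value": 1.0} for i in range(2500)]
batches = sequence_splitter.split_into_batches(
    score_messages, max_payload_size_MB=50, max_length=1000
)
assert [len(b) for b in batches] == [1000, 1000, 500]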
2 changes: 2 additions & 0 deletions sdks/python/src/opik/config.py
@@ -22,6 +22,8 @@

_SESSION_CACHE_DICT: Dict[str, Any] = {}

MAX_BATCH_SIZE_MB = 50

OPIK_URL_CLOUD: Final[str] = "https://www.comet.com/opik/api"
OPIK_URL_LOCAL: Final[str] = "http://localhost:5173/api"

58 changes: 58 additions & 0 deletions sdks/python/src/opik/message_processing/batching/sequence_splitter.py
@@ -0,0 +1,58 @@
import json
from typing import List, Optional, TypeVar, Sequence
from opik import jsonable_encoder

T = TypeVar("T")


def _get_expected_payload_size_MB(item: T) -> float:
    encoded_for_json = jsonable_encoder.jsonable_encoder(item)
    json_str = json.dumps(encoded_for_json)
    return len(json_str.encode("utf-8")) / (1024 * 1024)


def split_into_batches(
    items: Sequence[T],
    max_payload_size_MB: Optional[float] = None,
    max_length: Optional[int] = None,
) -> List[List[T]]:
    assert (max_payload_size_MB is not None) or (
        max_length is not None
    ), "At least one limitation must be set for splitting"

    if max_length is None:
        max_length = len(items)

    if max_payload_size_MB is None:
        max_payload_size_MB = float("inf")

    batches: List[List[T]] = []
    current_batch: List[T] = []
    current_batch_size_MB: float = 0.0

    for item in items:
        item_size_MB = (
            0.0 if max_payload_size_MB is None else _get_expected_payload_size_MB(item)
        )

        if item_size_MB >= max_payload_size_MB:
            batches.append([item])
            continue

        batch_is_already_full = len(current_batch) == max_length
        batch_will_exceed_memory_limit_after_adding = (
            current_batch_size_MB + item_size_MB > max_payload_size_MB
        )

        if batch_is_already_full or batch_will_exceed_memory_limit_after_adding:
            batches.append(current_batch)
            current_batch = [item]
            current_batch_size_MB = item_size_MB
        else:
            current_batch.append(item)
            current_batch_size_MB += item_size_MB

    if len(current_batch) > 0:
        batches.append(current_batch)

    return batches
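One behavior of the function above worth noting: an item that alone reaches max_payload_size_MB is flushed immediately as a single-item batch, ahead of the batch still accumulating, so input order is not always preserved across batches. A small sketch (illustrative sizes):

from opik.message_processing.batching import sequence_splitter

small, huge = "s" * 1024, "h" * (2 * 1024 * 1024)  # ~1 KB and ~2 MB payloads
batches = sequence_splitter.split_into_batches([small, huge, small], max_payload_size_MB=1.0)
# The ~2 MB item jumps ahead of the still-open batch of small items:
assert batches == [[huge], [small, small]]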
18 changes: 14 additions & 4 deletions sdks/python/src/opik/message_processing/message_processors.py
@@ -9,9 +9,12 @@
from ..rest_api import client as rest_api_client
from ..rest_api.types import feedback_score_batch_item
from ..rest_api.types import span_write
from .batching import sequence_splitter

LOGGER = logging.getLogger(__name__)

BATCH_MEMORY_LIMIT_MB = 50


class BaseMessageProcessor(abc.ABC):
@abc.abstractmethod
@@ -169,7 +172,8 @@ def _process_add_trace_feedback_scores_batch_message(
def _process_create_span_batch_message(
self, message: messages.CreateSpansBatchMessage
) -> None:
span_write_batch: List[span_write.SpanWrite] = []
rest_spans: List[span_write.SpanWrite] = []

for item in message.batch:
span_write_kwargs = {
"id": item.span_id,
@@ -190,7 +194,13 @@ def _process_create_span_batch_message(
span_write_kwargs
)
cleaned_span_write_kwargs = jsonable_encoder(cleaned_span_write_kwargs)
span_write_batch.append(span_write.SpanWrite(**cleaned_span_write_kwargs))
rest_spans.append(span_write.SpanWrite(**cleaned_span_write_kwargs))

memory_limited_batches = sequence_splitter.split_into_batches(
items=rest_spans,
max_payload_size_MB=BATCH_MEMORY_LIMIT_MB,
)

LOGGER.debug("Create spans batch request: %s", span_write_batch)
self._rest_client.spans.create_spans(spans=span_write_batch)
for batch in memory_limited_batches:
LOGGER.debug("Create spans batch request of size %d", len(batch), batch)
self._rest_client.spans.create_spans(spans=batch)
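The net effect on span logging, sketched with hypothetical payloads (the real code passes span_write.SpanWrite objects, not dicts): a batch whose combined JSON size exceeds BATCH_MEMORY_LIMIT_MB now goes out as several create_spans requests instead of one oversized request.

from opik.message_processing.batching import sequence_splitter

spans = [{"input": "x" * (20 * 1024 * 1024)} for _ in range(3)]  # ~20 MB each, e.g. base64 data
batches = sequence_splitter.split_into_batches(spans, max_payload_size_MB=50)
# Two requests instead of a single ~60 MB one:
assert [len(b) for b in batches] == [2, 1]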
77 changes: 77 additions & 0 deletions (new unit-test file for sequence_splitter)
@@ -0,0 +1,77 @@
import dataclasses
from opik.message_processing.batching import sequence_splitter


# A megabyte-sized string whose repr stays short, so failed assertions print compactly.
@dataclasses.dataclass
class LongStr:
    value: str

    def __str__(self) -> str:
        return self.value[1] + ".." + self.value[-1]

    def __repr__(self) -> str:
        return str(self)


ONE_MEGABYTE_OBJECT_A = LongStr("a" * 1024 * 1024)
ONE_MEGABYTE_OBJECT_B = LongStr("b" * 1024 * 1024)
ONE_MEGABYTE_OBJECT_C = LongStr("c" * 1024 * 1024)


def test_split_list_into_batches__by_size_only():
    items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    batches = sequence_splitter.split_into_batches(items, max_length=4)

    assert batches == [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]


def test_split_list_into_batches__by_memory_only():
    items = [ONE_MEGABYTE_OBJECT_A] * 2 + [ONE_MEGABYTE_OBJECT_B] * 2
    batches = sequence_splitter.split_into_batches(items, max_payload_size_MB=3.5)

    assert batches == [
        [ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_B],
        [ONE_MEGABYTE_OBJECT_B],
    ]


def test_split_list_into_batches__by_memory_and_by_size():
    FOUR_MEGABYTE_OBJECT_C = [ONE_MEGABYTE_OBJECT_C] * 4  # one ~4 MB item (a list of four 1 MB strings)
    items = (
        [ONE_MEGABYTE_OBJECT_A] * 2
        + [FOUR_MEGABYTE_OBJECT_C]
        + [ONE_MEGABYTE_OBJECT_B] * 2
    )
    batches = sequence_splitter.split_into_batches(
        items, max_length=2, max_payload_size_MB=3.5
    )

    # Object C comes before object A: an item bigger than the max payload size is
    # appended to the result immediately, as its own batch, ahead of the batch that
    # is still accumulating.
    assert batches == [
        [FOUR_MEGABYTE_OBJECT_C],
        [ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_A],
        [ONE_MEGABYTE_OBJECT_B, ONE_MEGABYTE_OBJECT_B],
    ]


def test_split_list_into_batches__empty_list():
    items = []
    batches = sequence_splitter.split_into_batches(
        items, max_length=3, max_payload_size_MB=3.5
    )

    assert batches == []


def test_split_list_into_batches__multiple_large_objects():
    items = [ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_B, ONE_MEGABYTE_OBJECT_C]
    batches = sequence_splitter.split_into_batches(
        items, max_length=2, max_payload_size_MB=0.5
    )

    assert batches == [
        [ONE_MEGABYTE_OBJECT_A],
        [ONE_MEGABYTE_OBJECT_B],
        [ONE_MEGABYTE_OBJECT_C],
    ]
