-
Notifications
You must be signed in to change notification settings - Fork 223
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[OPIK-410] improve batch mechanism in SDK (#648)
* Initial batch_splitter implementation * Refactor batch splitter * Add unit tests for batch splitter, refactor the implementation * Fix lint errors * Replace usage of list_to_batches with the new batch splitter method * Update modules and functions names * Update type hint * Add splitting by size to dataset items and feedbacks * Add truncation to logger formatter to prevent logger from writing very heavy payloads in debug mode
- Loading branch information
1 parent
c59c82c
commit 5dfcea4
Showing
9 changed files
with
206 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
58 changes: 58 additions & 0 deletions
58
sdks/python/src/opik/message_processing/batching/sequence_splitter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import json | ||
from typing import List, Optional, TypeVar, Sequence | ||
from opik import jsonable_encoder | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
def _get_expected_payload_size_MB(item: T) -> float:
    """Return the approximate size, in megabytes, of *item* once JSON-encoded.

    The item is first run through the project's JSON-able encoder, then
    serialized; the UTF-8 byte length of that serialization is the estimate.
    """
    encoded = jsonable_encoder.jsonable_encoder(item)
    payload_bytes = json.dumps(encoded).encode("utf-8")
    return len(payload_bytes) / (1024 ** 2)
|
||
|
||
def split_into_batches(
    items: Sequence[T],
    max_payload_size_MB: Optional[float] = None,
    max_length: Optional[int] = None,
) -> List[List[T]]:
    """Split ``items`` into consecutive batches bounded by length and/or payload size.

    Args:
        items: Sequence to split; original order is preserved.
        max_payload_size_MB: Maximum estimated JSON payload size of a batch, in MB.
            An item that alone meets or exceeds this limit is emitted as its own
            single-item batch, flushed immediately (ahead of the batch that is
            currently accumulating).
        max_length: Maximum number of items in a batch.

    Returns:
        A list of batches that together cover all of ``items``.

    Raises:
        AssertionError: If neither limit is provided.
    """
    assert (max_payload_size_MB is not None) or (
        max_length is not None
    ), "At least one limitation must be set for splitting"

    # Record whether a size limit was requested BEFORE normalizing it below.
    # The previous code tested ``max_payload_size_MB is None`` only after it
    # had already been replaced with infinity, so the check was always False
    # and every item was JSON-serialized even when only ``max_length`` was set.
    size_limit_set = max_payload_size_MB is not None

    if max_length is None:
        max_length = len(items)

    if max_payload_size_MB is None:
        max_payload_size_MB = float("inf")

    batches: List[List[T]] = []
    current_batch: List[T] = []
    current_batch_size_MB: float = 0.0

    for item in items:
        # Estimating the payload requires a full JSON encode of the item, so
        # skip that work entirely when no size limit applies.
        item_size_MB = _get_expected_payload_size_MB(item) if size_limit_set else 0.0

        # An item that alone exceeds the limit becomes its own batch and is
        # flushed immediately, before the batch currently accumulating.
        if item_size_MB >= max_payload_size_MB:
            batches.append([item])
            continue

        batch_is_already_full = len(current_batch) == max_length
        batch_will_exceed_memory_limit_after_adding = (
            current_batch_size_MB + item_size_MB > max_payload_size_MB
        )

        if batch_is_already_full or batch_will_exceed_memory_limit_after_adding:
            batches.append(current_batch)
            current_batch = [item]
            current_batch_size_MB = item_size_MB
        else:
            current_batch.append(item)
            current_batch_size_MB += item_size_MB

    if current_batch:
        batches.append(current_batch)

    return batches
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
77 changes: 77 additions & 0 deletions
77
sdks/python/tests/unit/message_processing/batching/test_sequence_splitter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import dataclasses | ||
from opik.message_processing.batching import sequence_splitter | ||
|
||
|
||
@dataclasses.dataclass
class LongStr:
    """Test helper: a large string that prints as ``<first>..<last>`` so
    assertion failures stay readable instead of dumping megabytes of text."""

    value: str

    def __str__(self) -> str:
        # Show the first and last characters. The previous code indexed
        # ``value[1]``, skipping the first character by mistake (invisible in
        # these tests because the payloads are uniform, e.g. "a" * N).
        return self.value[0] + ".." + self.value[-1]

    def __repr__(self) -> str:
        return str(self)
|
||
|
||
# One-megabyte payloads with distinct contents, used to exercise size-based splitting.
ONE_MEGABYTE_OBJECT_A = LongStr("a" * 2**20)
ONE_MEGABYTE_OBJECT_B = LongStr("b" * 2**20)
ONE_MEGABYTE_OBJECT_C = LongStr("c" * 2**20)
|
||
|
||
def test_split_list_into_batches__by_size_only():
    """Only a length limit: items are chunked into runs of at most 4."""
    numbers = list(range(1, 11))

    result = sequence_splitter.split_into_batches(numbers, max_length=4)

    assert result == [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]
|
||
|
||
def test_split_list_into_batches__by_memory_only():
    """Only a payload-size limit: a new batch starts once 3.5 MB would be exceeded."""
    payload = [ONE_MEGABYTE_OBJECT_A] * 2 + [ONE_MEGABYTE_OBJECT_B] * 2

    result = sequence_splitter.split_into_batches(payload, max_payload_size_MB=3.5)

    first_batch = [ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_B]
    assert result == [first_batch, [ONE_MEGABYTE_OBJECT_B]]
|
||
|
||
def test_split_list_into_batches__by_memory_and_by_size():
    """Both limits at once, including one item bigger than the payload limit."""
    oversized_item = [ONE_MEGABYTE_OBJECT_C] * 4  # ~4 MB, above the 3.5 MB limit
    payload = (
        [ONE_MEGABYTE_OBJECT_A] * 2
        + [oversized_item]
        + [ONE_MEGABYTE_OBJECT_B] * 2
    )

    result = sequence_splitter.split_into_batches(
        payload, max_length=2, max_payload_size_MB=3.5
    )

    # The oversized item appears before the A-objects: anything above the max
    # payload size is flushed to the result immediately, ahead of the batch
    # that is still accumulating.
    assert result == [
        [oversized_item],
        [ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_A],
        [ONE_MEGABYTE_OBJECT_B, ONE_MEGABYTE_OBJECT_B],
    ]
|
||
|
||
def test_split_list_into_batches__empty_list():
    """An empty input yields no batches at all."""
    result = sequence_splitter.split_into_batches(
        [], max_length=3, max_payload_size_MB=3.5
    )

    assert result == []
|
||
|
||
def test_split_list_into_batches__multiple_large_objects():
    """Every 1 MB item exceeds the 0.5 MB limit, so each lands in its own batch."""
    oversized = [ONE_MEGABYTE_OBJECT_A, ONE_MEGABYTE_OBJECT_B, ONE_MEGABYTE_OBJECT_C]

    result = sequence_splitter.split_into_batches(
        oversized, max_length=2, max_payload_size_MB=0.5
    )

    assert result == [[obj] for obj in oversized]