[OPIK-38] Deduplicate items before inserting them in a dataset (#340)
* Implemented deduplication

* Added documentation

* Added update unit test

* Add support for delete method

* Fix linter

* Fix python 3.8 tests
jverre authored and Douglas Blank committed Oct 4, 2024
1 parent d642f00 commit eca042e
Showing 6 changed files with 273 additions and 5 deletions.
@@ -48,6 +48,9 @@ dataset.insert([
```

:::tip
When using the Python SDK, Opik automatically deduplicates items inserted into a dataset, so inserting the same item multiple times will not create duplicates.
:::
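
For example (a minimal sketch, assuming the `dataset` object created above), inserting the same item twice leaves a single copy in the dataset:

```python
item = {"input": {"user_question": "What is the capital of France?"}}

dataset.insert([item])
dataset.insert([item])  # detected as a duplicate and skipped client-side
```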

Instead of using the `DatasetItem` class, you can also use a dictionary to insert items into a dataset. The dictionary must have an `input` key; `expected_output` and `metadata` are optional:

```python
@@ -56,8 +59,6 @@ dataset.insert([
{"input": {"user_question": "What is the capital of France?"}, "expected_output": {"assistant_answer": "Paris"}},
])
```


You can also insert items from a JSONL file:

46 changes: 43 additions & 3 deletions sdks/python/src/opik/api_objects/dataset/dataset.py
@@ -7,8 +7,7 @@
from opik import exceptions

from .. import helpers, constants
from . import dataset_item, converters

from . import dataset_item, converters, utils
import pandas

LOGGER = logging.getLogger(__name__)
@@ -27,6 +26,8 @@ def __init__(
self._name = name
self._description = description
self._rest_client = rest_client
self._hash_to_id: Dict[str, str] = {}
self._id_to_hash: Dict[str, str] = {}

@property
def name(self) -> str:
@@ -53,6 +54,23 @@ def insert(
for item in items
]

# Remove duplicates if they already exist
deduplicated_items = []
for item in items:
item_hash = utils.compute_content_hash(item)

if item_hash in self._hash_to_id:
if item.id is None or self._hash_to_id[item_hash] == item.id: # type: ignore
LOGGER.debug(
"Duplicate item found with hash: %s - ignored the event",
item_hash,
)
continue

deduplicated_items.append(item)
self._hash_to_id[item_hash] = item.id # type: ignore
self._id_to_hash[item.id] = item_hash # type: ignore

rest_items = [
rest_dataset_item.DatasetItem(
id=item.id if item.id is not None else helpers.generate_id(), # type: ignore
@@ -63,7 +81,7 @@
span_id=item.span_id, # type: ignore
source=item.source, # type: ignore
)
for item in items
for item in deduplicated_items
]

batches = helpers.list_to_batches(
@@ -76,6 +94,21 @@
dataset_name=self._name, items=batch
)

def _sync_hashes(self) -> None:
"""Updates all the hashes in the dataset"""
LOGGER.debug("Start hash sync in dataset")
all_items = self.get_all_items()

self._hash_to_id = {}
self._id_to_hash = {}

for item in all_items:
item_hash = utils.compute_content_hash(item)
self._hash_to_id[item_hash] = item.id # type: ignore
self._id_to_hash[item.id] = item_hash # type: ignore

LOGGER.debug("Finish hash sync in dataset")

def update(self, items: List[dataset_item.DatasetItem]) -> None:
"""
Update existing items in the dataset.
@@ -109,12 +142,19 @@ def delete(self, items_ids: List[str]) -> None:
LOGGER.debug("Deleting dataset items batch: %s", batch)
self._rest_client.datasets.delete_dataset_items(item_ids=batch)

for item_id in batch:
if item_id in self._id_to_hash:
item_hash = self._id_to_hash[item_id]
del self._id_to_hash[item_id]
del self._hash_to_id[item_hash]

def clear(self) -> None:
"""
Delete all items from the given dataset.
"""
all_items = self.get_all_items()
item_ids = [item.id for item in all_items if item.id is not None]

self.delete(item_ids)

def to_pandas(self) -> pandas.DataFrame:
24 changes: 24 additions & 0 deletions sdks/python/src/opik/api_objects/dataset/utils.py
@@ -0,0 +1,24 @@
import json
import hashlib
from typing import Dict, Any, Union
from . import dataset_item


def compute_content_hash(item: Union[dataset_item.DatasetItem, Dict[str, Any]]) -> str:
if isinstance(item, dataset_item.DatasetItem):
content = {
"input": item.input,
"expected_output": item.expected_output,
"metadata": item.metadata,
}
else:
content = item

# Convert the dictionary to a JSON string with sorted keys for consistency
json_string = json.dumps(content, sort_keys=True)

# Compute the SHA256 hash of the JSON string
hash_object = hashlib.sha256(json_string.encode())

# Return the hexadecimal representation of the hash
return hash_object.hexdigest()
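
A quick way to see the effect of `sort_keys=True` (a hypothetical usage sketch, not part of this commit): two dictionaries with the same content but different key order produce the same hash:

```python
from opik.api_objects.dataset import utils

a = {"input": {"q": "capital of France?"}, "expected_output": {"a": "Paris"}}
b = {"expected_output": {"a": "Paris"}, "input": {"q": "capital of France?"}}

# Serializing with sorted keys makes the hash independent of key order
assert utils.compute_content_hash(a) == utils.compute_content_hash(b)
```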
2 changes: 2 additions & 0 deletions sdks/python/src/opik/api_objects/opik_client.py
@@ -328,6 +328,8 @@ def get_dataset(self, name: str) -> dataset.Dataset:
rest_client=self._rest_client,
)

dataset_._sync_hashes()

return dataset_

def delete_dataset(self, name: str) -> None:
27 changes: 27 additions & 0 deletions sdks/python/tests/e2e/test_dataset.py
@@ -48,3 +48,30 @@ def test_create_and_populate_dataset__happyflow(
description=DESCRIPTION,
dataset_items=EXPECTED_DATASET_ITEMS,
)


def test_deduplication(opik_client: opik.Opik, dataset_name: str):
DESCRIPTION = "E2E test dataset"

item = {
"input": {"question": "What is the of capital of France?"},
"expected_output": {"output": "Paris"},
}

# Write the dataset
dataset = opik_client.create_dataset(dataset_name, description=DESCRIPTION)
dataset.insert([item])

# Read the dataset and insert the same item
new_dataset = opik_client.get_dataset(dataset_name)
new_dataset.insert([item])

# Verify the dataset
verifiers.verify_dataset(
opik_client=opik_client,
name=dataset_name,
description=DESCRIPTION,
dataset_items=[
dataset_item.DatasetItem(**item),
],
)
174 changes: 174 additions & 0 deletions sdks/python/tests/unit/api_objects/dataset/test_deduplication.py
@@ -0,0 +1,174 @@
import unittest
from unittest.mock import Mock
from opik.api_objects.dataset.dataset import Dataset
from opik.api_objects.dataset.dataset_item import DatasetItem


def test_insert_deduplication():
# Create a mock REST client
mock_rest_client = Mock()

# Create a Dataset instance
dataset = Dataset("test_dataset", "Test description", mock_rest_client)

# Create two identical dictionaries
item_dict = {
"input": {"key": "value", "key2": "value2"},
"expected_output": {"key": "value", "key2": "value2"},
"metadata": {"key": "value", "key2": "value2"},
}

# Insert the identical items
dataset.insert([item_dict, item_dict])

# Check that create_or_update_dataset_items was called only once
assert (
mock_rest_client.datasets.create_or_update_dataset_items.call_count == 1
), "create_or_update_dataset_items should be called only once"

# Get the arguments passed to create_or_update_dataset_items
call_args = mock_rest_client.datasets.create_or_update_dataset_items.call_args
inserted_items = call_args[1]["items"]

# Check that only one item was inserted
assert len(inserted_items) == 1, "Only one item should be inserted"


def test_insert_deduplication_with_different_items():
# Create a mock REST client
mock_rest_client = Mock()

# Create a Dataset instance
dataset = Dataset("test_dataset", "Test description", mock_rest_client)

# Create two different dictionaries
item_dict1 = {
"input": {"key": "value1"},
"expected_output": {"key": "output1"},
"metadata": {"key": "meta1"},
}
item_dict2 = {
"input": {"key": "value2"},
"expected_output": {"key": "output2"},
"metadata": {"key": "meta2"},
}

# Insert the different items
dataset.insert([item_dict1, item_dict2])

# Check that create_or_update_dataset_items was called only once
assert (
mock_rest_client.datasets.create_or_update_dataset_items.call_count == 1
), "create_or_update_dataset_items should be called only once"

# Get the arguments passed to create_or_update_dataset_items
call_args = mock_rest_client.datasets.create_or_update_dataset_items.call_args
inserted_items = call_args[1]["items"]

# Check that two items were inserted
assert len(inserted_items) == 2, "Two items should be inserted"


def test_insert_deduplication_with_partial_overlap():
# Create a mock REST client
mock_rest_client = Mock()

# Create a Dataset instance
dataset = Dataset("test_dataset", "Test description", mock_rest_client)

# Create three dictionaries, two of which are identical
item_dict1 = {
"input": {"key": "value1"},
"expected_output": {"key": "output1"},
"metadata": {"key": "meta1"},
}
item_dict2 = {
"input": {"key": "value2"},
"expected_output": {"key": "output2"},
"metadata": {"key": "meta2"},
}

# Insert the items
dataset.insert([item_dict1, item_dict2, item_dict1])

# Check that create_or_update_dataset_items was called only once
assert (
mock_rest_client.datasets.create_or_update_dataset_items.call_count == 1
), "create_or_update_dataset_items should be called only once"

# Get the arguments passed to create_or_update_dataset_items
call_args = mock_rest_client.datasets.create_or_update_dataset_items.call_args
inserted_items = call_args[1]["items"]

# Check that two items were inserted
assert len(inserted_items) == 2, "Two items should be inserted"


def test_update_flow():
# Create a mock REST client
mock_rest_client = Mock()

# Create a Dataset instance
dataset = Dataset("test_dataset", "Test description", mock_rest_client)

# Create an initial item
initial_item = {
"input": {"key": "initial_value"},
"expected_output": {"key": "initial_output"},
"metadata": {"key": "initial_metadata"},
}

# Insert the initial item
dataset.insert([initial_item])

# Check that create_or_update_dataset_items was called once for insertion
assert (
mock_rest_client.datasets.create_or_update_dataset_items.call_count == 1
), "create_or_update_dataset_items should be called once for insertion"

# Get the arguments passed to create_or_update_dataset_items for insertion
insert_call_args = (
mock_rest_client.datasets.create_or_update_dataset_items.call_args
)
inserted_items = insert_call_args[1]["items"]

# Check that one item was inserted
assert len(inserted_items) == 1, "One item should be inserted"

# Create an updated version of the item
updated_item = DatasetItem(
id=inserted_items[0].id,
input={"key": "updated_value"},
expected_output={"key": "updated_output"},
metadata={"key": "updated_metadata"},
)

# Update the item
dataset.update([updated_item])

# Check that create_or_update_dataset_items was called twice in total (once for insertion, once for update)
assert (
mock_rest_client.datasets.create_or_update_dataset_items.call_count == 2
), "create_or_update_dataset_items should be called twice in total"

# Get the arguments passed to create_or_update_dataset_items for update
update_call_args = (
mock_rest_client.datasets.create_or_update_dataset_items.call_args
)
updated_items = update_call_args[1]["items"]

# Check that one item was updated
assert len(updated_items) == 1, "One item should be updated"

# Verify the content of the updated item
assert updated_items[0].input == {"key": "updated_value"}, "Input should be updated"
assert updated_items[0].expected_output == {
"key": "updated_output"
}, "Expected output should be updated"
assert updated_items[0].metadata == {
"key": "updated_metadata"
}, "Metadata should be updated"


if __name__ == "__main__":
unittest.main()
