diff --git a/sdks/python/src/opik/api_objects/experiment/__init__.py b/sdks/python/src/opik/api_objects/experiment/__init__.py index 8fbfc7c5b7..1d18caff3d 100644 --- a/sdks/python/src/opik/api_objects/experiment/__init__.py +++ b/sdks/python/src/opik/api_objects/experiment/__init__.py @@ -1,3 +1,4 @@ from .experiment import Experiment +from .helpers import build_metadata_and_prompt_version -__all__ = ["Experiment"] +__all__ = ["Experiment", "build_metadata_and_prompt_version"] diff --git a/sdks/python/src/opik/api_objects/experiment/helpers.py b/sdks/python/src/opik/api_objects/experiment/helpers.py new file mode 100644 index 0000000000..4a02b6f10c --- /dev/null +++ b/sdks/python/src/opik/api_objects/experiment/helpers.py @@ -0,0 +1,40 @@ +from typing import Optional, Dict, Mapping, Tuple, Any +from .. import prompt +import logging +from opik import jsonable_encoder + +LOGGER = logging.getLogger(__name__) + + +def build_metadata_and_prompt_version( + experiment_config: Optional[Dict[str, Any]], prompt: Optional[prompt.Prompt] +) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, str]]]: + metadata = None + prompt_version: Optional[Dict[str, str]] = None + + if experiment_config is None: + experiment_config = {} + + if not isinstance(experiment_config, Mapping): + LOGGER.error( + "Experiment config must be dictionary, but %s was provided. Provided config will be ignored.", + experiment_config, + ) + experiment_config = {} + + if prompt is not None and "prompt" in experiment_config: + LOGGER.warning( + "The prompt parameter will not be added to experiment since there is already `prompt` specified in experiment_config" + ) + return (experiment_config, None) + + if prompt is not None: + prompt_version = {"id": prompt.__internal_api__version_id__} + experiment_config["prompt"] = prompt.prompt + + if experiment_config == {}: + return None, None + + metadata = jsonable_encoder.jsonable_encoder(experiment_config) + + return metadata, prompt_version diff --git a/sdks/python/src/opik/api_objects/opik_client.py b/sdks/python/src/opik/api_objects/opik_client.py index b34730bf92..9bb63d8e09 100644 --- a/sdks/python/src/opik/api_objects/opik_client.py +++ b/sdks/python/src/opik/api_objects/opik_client.py @@ -3,7 +3,7 @@ import datetime import logging -from typing import Optional, Any, Dict, List, Mapping +from typing import Optional, Any, Dict, List from .prompt import Prompt from .prompt.client import PromptClient @@ -29,7 +29,6 @@ datetime_helpers, config, httpx_client, - jsonable_encoder, url_helpers, rest_client_configurator, ) @@ -490,23 +489,10 @@ def create_experiment( experiment.Experiment: The newly created experiment object. """ id = helpers.generate_id() - metadata = None - prompt_version: Optional[Dict[str, str]] = None - if isinstance(experiment_config, Mapping): - if prompt is not None: - prompt_version = {"id": prompt.__internal_api__version_id__} - - if "prompt" not in experiment_config: - experiment_config["prompt"] = prompt.prompt - - metadata = jsonable_encoder.jsonable_encoder(experiment_config) - - elif experiment_config is not None: - LOGGER.error( - "Experiment config must be dictionary, but %s was provided. Config will not be logged.", - experiment_config, - ) + metadata, prompt_version = experiment.build_metadata_and_prompt_version( + experiment_config=experiment_config, prompt=prompt + ) self._rest_client.experiments.create_experiment( name=name, diff --git a/sdks/python/tests/unit/api_objects/experiment/__init__.py b/sdks/python/tests/unit/api_objects/experiment/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sdks/python/tests/unit/api_objects/experiment/test_helpers.py b/sdks/python/tests/unit/api_objects/experiment/test_helpers.py new file mode 100644 index 0000000000..9699392800 --- /dev/null +++ b/sdks/python/tests/unit/api_objects/experiment/test_helpers.py @@ -0,0 +1,60 @@ +from opik.api_objects import experiment +import pytest +import types + + +def fake_prompt(): + return types.SimpleNamespace( + __internal_api__version_id__="some-prompt-version-id", + prompt="some-prompt-value", + ) + + +@pytest.mark.parametrize( + argnames="input_kwargs,expected", + argvalues=[ + ( + {"experiment_config": None, "prompt": None}, + {"metadata": None, "prompt_version": None}, + ), + ( + {"experiment_config": {}, "prompt": None}, + {"metadata": None, "prompt_version": None}, + ), + ( + {"experiment_config": None, "prompt": fake_prompt()}, + { + "metadata": {"prompt": "some-prompt-value"}, + "prompt_version": {"id": "some-prompt-version-id"}, + }, + ), + ( + {"experiment_config": {}, "prompt": fake_prompt()}, + { + "metadata": {"prompt": "some-prompt-value"}, + "prompt_version": {"id": "some-prompt-version-id"}, + }, + ), + ( + {"experiment_config": {"some-key": "some-value"}, "prompt": None}, + {"metadata": {"some-key": "some-value"}, "prompt_version": None}, + ), + ( + { + "experiment_config": "NOT-DICT-VALUE-THAT-WILL-BE-IGNORED-AND-REPLACED-WITH-DICT-WITH-PROMPT", + "prompt": fake_prompt(), + }, + { + "metadata": {"prompt": "some-prompt-value"}, + "prompt_version": {"id": "some-prompt-version-id"}, + }, + ), + ], +) +def test_experiment_build_metadata_from_prompt_version(input_kwargs, expected): + metadata, prompt_version = experiment.build_metadata_and_prompt_version( + **input_kwargs + ) + + assert metadata == expected["metadata"] + assert prompt_version == expected["prompt_version"]