From 2d926b3c55ebf45e3cabd28b4de6b0c5308608f0 Mon Sep 17 00:00:00 2001 From: Guy Avnon Date: Wed, 8 Nov 2023 18:19:57 +0200 Subject: [PATCH] fixed review and updated unit tests --- src/comet_llm/chains/api.py | 6 +++++- src/comet_llm/chains/span.py | 28 +++++++++++++++++++++----- src/comet_llm/chains/state.py | 6 ------ src/comet_llm/logging_messages.py | 2 ++ tests/unit/chains/test_span.py | 33 +++++++++++++++++++++++++++++-- tests/unit/chains/test_state.py | 9 --------- 6 files changed, 61 insertions(+), 23 deletions(-) diff --git a/src/comet_llm/chains/api.py b/src/comet_llm/chains/api.py index 56f99cae80..7c215d28f4 100644 --- a/src/comet_llm/chains/api.py +++ b/src/comet_llm/chains/api.py @@ -24,6 +24,7 @@ experiment_api, experiment_info, llm_result, + logging_messages, ) from ..types import JSONEncodable from . import chain, state @@ -72,6 +73,7 @@ def start_chain( state.set_global_chain(global_chain) +@exceptions.filter(allow_raising=config.raising_enabled(), summary=app.SUMMARY) def end_chain( outputs: Dict[str, JSONEncodable], metadata: Optional[Dict[str, JSONEncodable]] = None, @@ -90,7 +92,9 @@ def end_chain( """ global_chain = state.get_global_chain() if global_chain is None: - return # type: ignore + raise exceptions.CometLLMException( + logging_messages.GLOBAL_CHAIN_NOT_INITIALIZED % "`end_chain`" + ) global_chain.set_outputs(outputs=outputs, metadata=metadata) return log_chain(global_chain) diff --git a/src/comet_llm/chains/span.py b/src/comet_llm/chains/span.py index 21c05aa301..308a57a74d 100644 --- a/src/comet_llm/chains/span.py +++ b/src/comet_llm/chains/span.py @@ -12,15 +12,20 @@ # LICENSE file in the root directory of this package. # ******************************************************* +import logging from typing import TYPE_CHECKING, Dict, List, Optional -from .. import datetimes +from comet_llm import logging as comet_logging + +from .. import config, datetimes, exceptions, logging_messages from ..types import JSONEncodable from . import deepmerge, state if TYPE_CHECKING: from . import chain +LOGGER = logging.getLogger(__name__) + class Span: """ @@ -76,15 +81,28 @@ def name(self) -> str: # pragma: no cover def __enter__(self) -> "Span": chain = state.get_global_chain() + + if chain is None: + chain_not_initialized_exception = exceptions.CometLLMException( + logging_messages.GLOBAL_CHAIN_NOT_INITIALIZED % "`Span`" + ) + if config.raising_enabled(): + raise chain_not_initialized_exception + + comet_logging.log_once_at_level( + LOGGER, logging.ERROR, str(chain_not_initialized_exception) + ) + + return self + self.__api__start__(chain) return self def __api__start__(self, chain: "chain.Chain") -> None: - if chain is not None: - self._connect_to_chain(chain) + self._connect_to_chain(chain) - self._timer.start() - self._chain.context.add(self.id) # type: ignore + self._timer.start() + self._chain.context.add(self.id) # type: ignore def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore self.__api__end__() diff --git a/src/comet_llm/chains/state.py b/src/comet_llm/chains/state.py index 5660e1eb3b..465dd608d0 100644 --- a/src/comet_llm/chains/state.py +++ b/src/comet_llm/chains/state.py @@ -35,11 +35,6 @@ def chain_exists(self) -> bool: @property def chain(self) -> "chain.Chain": result: "chain.Chain" = self._thread_context_registry.get("global-chain") - if result is None: - raise exceptions.CometLLMException( - "Global chain is not initialized for this thread. Initialize it with `comet_llm.start_chain(...)`" - ) - return result @chain.setter @@ -59,7 +54,6 @@ def global_chain_exists() -> bool: return _APP_STATE.chain_exists() -@exceptions.filter(allow_raising=config.raising_enabled(), summary=app.SUMMARY) def get_global_chain() -> "chain.Chain": return _APP_STATE.chain diff --git a/src/comet_llm/logging_messages.py b/src/comet_llm/logging_messages.py index 99d03533a1..1e0f2c0b24 100644 --- a/src/comet_llm/logging_messages.py +++ b/src/comet_llm/logging_messages.py @@ -28,3 +28,5 @@ ) INVALID_TIMESTAMP = "Invalid timestamp: %s. Timestamp must be in seconds if specified." + +GLOBAL_CHAIN_NOT_INITIALIZED = "Global chain is not initialized for this thread. Initialize it with `comet_llm.start_chain(...)` if you wish to use %s" diff --git a/tests/unit/chains/test_span.py b/tests/unit/chains/test_span.py index f9efdbe1ed..89e0695d17 100644 --- a/tests/unit/chains/test_span.py +++ b/tests/unit/chains/test_span.py @@ -4,6 +4,7 @@ from testix import saveargument from comet_llm.chains import span +from comet_llm.exceptions import CometLLMException @pytest.fixture(autouse=True) @@ -11,6 +12,9 @@ def mock_imports(patch_module): patch_module(span, "state") patch_module(span, "datetimes") patch_module(span, "convert") + patch_module(span, "comet_logging") + patch_module(span, "config") + patch_module(span, "LOGGER", "logger") def _construct( @@ -183,7 +187,7 @@ def test_set_output__new_metadata_is_not_None__existing_metadata_is_merged_with_ } -def test_span__no_chain_started__wont_connect_to_chain(): +def test_span__no_chain_started_raising_exceptions_disabled__wont_connect_to_chain(): START_TIMESTAMP = 10 END_TIMESTAMP = 15 DURATION = 5 @@ -200,6 +204,12 @@ def test_span__no_chain_started__wont_connect_to_chain(): s.state.get_global_chain() >> None + s.config.raising_enabled() >> False + + s.comet_logging.log_once_at_level( + "logger", 40, "Global chain is not initialized for this thread. Initialize it with `comet_llm.start_chain(...)` if you wish to use `Span`" + ) + with span.Span( category="llm-call", inputs={"input": "input"}, @@ -217,4 +227,23 @@ def test_span__no_chain_started__wont_connect_to_chain(): "end_timestamp": END_TIMESTAMP, "parent_ids": None, "metadata": {}, - } \ No newline at end of file + } + + +def test_span__no_chain_started_raising_exceptions_enabled__exception_raised(): + with Scenario() as s: + s.state.get_new_id() >> "example_id" + timer = Fake("timer") + + s.datetimes.Timer() >> timer + + s.state.get_global_chain() >> None + + s.config.raising_enabled() >> True + + with pytest.raises(CometLLMException): + with span.Span( + category="llm-call", + inputs={"input": "input"}, + ) as tested_span: + tested_span.set_outputs({"outputs": "outputs"}) diff --git a/tests/unit/chains/test_state.py b/tests/unit/chains/test_state.py index 2e9b4f3733..4b3b7f1198 100644 --- a/tests/unit/chains/test_state.py +++ b/tests/unit/chains/test_state.py @@ -25,15 +25,6 @@ def test_new_id__happyflow(): assert tested.new_id() == 2 -def test_chain_property_chain_was_not_set__exception_raised(): - tested = _construct() - - with Scenario() as s: - s.registry.get("global-chain") >> None - with pytest.raises(exceptions.CometLLMException): - tested.chain - - def test_chain_exists__chain_was_not_set__returned_False(): tested = _construct()