Skip to content

Commit

Permalink
fixed review and updated unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jynx10 committed Nov 8, 2023
1 parent 9149bd0 commit 2d926b3
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 23 deletions.
6 changes: 5 additions & 1 deletion src/comet_llm/chains/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
experiment_api,
experiment_info,
llm_result,
logging_messages,
)
from ..types import JSONEncodable
from . import chain, state
Expand Down Expand Up @@ -72,6 +73,7 @@ def start_chain(
state.set_global_chain(global_chain)


@exceptions.filter(allow_raising=config.raising_enabled(), summary=app.SUMMARY)
def end_chain(
outputs: Dict[str, JSONEncodable],
metadata: Optional[Dict[str, JSONEncodable]] = None,
Expand All @@ -90,7 +92,9 @@ def end_chain(
"""
global_chain = state.get_global_chain()
if global_chain is None:
return # type: ignore
raise exceptions.CometLLMException(
logging_messages.GLOBAL_CHAIN_NOT_INITIALIZED % "`end_chain`"
)

global_chain.set_outputs(outputs=outputs, metadata=metadata)
return log_chain(global_chain)
Expand Down
28 changes: 23 additions & 5 deletions src/comet_llm/chains/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@
# LICENSE file in the root directory of this package.
# *******************************************************

import logging
from typing import TYPE_CHECKING, Dict, List, Optional

from .. import datetimes
from comet_llm import logging as comet_logging

from .. import config, datetimes, exceptions, logging_messages
from ..types import JSONEncodable
from . import deepmerge, state

if TYPE_CHECKING:
from . import chain

LOGGER = logging.getLogger(__name__)


class Span:
"""
Expand Down Expand Up @@ -76,15 +81,28 @@ def name(self) -> str: # pragma: no cover

def __enter__(self) -> "Span":
chain = state.get_global_chain()

if chain is None:
chain_not_initialized_exception = exceptions.CometLLMException(
logging_messages.GLOBAL_CHAIN_NOT_INITIALIZED % "`Span`"
)
if config.raising_enabled():
raise chain_not_initialized_exception

comet_logging.log_once_at_level(
LOGGER, logging.ERROR, str(chain_not_initialized_exception)
)

return self

self.__api__start__(chain)
return self

def __api__start__(self, chain: "chain.Chain") -> None:
if chain is not None:
self._connect_to_chain(chain)
self._connect_to_chain(chain)

self._timer.start()
self._chain.context.add(self.id) # type: ignore
self._timer.start()
self._chain.context.add(self.id) # type: ignore

def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore
self.__api__end__()
Expand Down
6 changes: 0 additions & 6 deletions src/comet_llm/chains/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,6 @@ def chain_exists(self) -> bool:
@property
def chain(self) -> "chain.Chain":
result: "chain.Chain" = self._thread_context_registry.get("global-chain")
if result is None:
raise exceptions.CometLLMException(
"Global chain is not initialized for this thread. Initialize it with `comet_llm.start_chain(...)`"
)

return result

@chain.setter
Expand All @@ -59,7 +54,6 @@ def global_chain_exists() -> bool:
return _APP_STATE.chain_exists()


@exceptions.filter(allow_raising=config.raising_enabled(), summary=app.SUMMARY)
def get_global_chain() -> "chain.Chain":
return _APP_STATE.chain

Expand Down
2 changes: 2 additions & 0 deletions src/comet_llm/logging_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@
)

INVALID_TIMESTAMP = "Invalid timestamp: %s. Timestamp must be in seconds if specified."

GLOBAL_CHAIN_NOT_INITIALIZED = "Global chain is not initialized for this thread. Initialize it with `comet_llm.start_chain(...)` if you wish to use %s"
33 changes: 31 additions & 2 deletions tests/unit/chains/test_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
from testix import saveargument

from comet_llm.chains import span
from comet_llm.exceptions import CometLLMException


@pytest.fixture(autouse=True)
def mock_imports(patch_module):
patch_module(span, "state")
patch_module(span, "datetimes")
patch_module(span, "convert")
patch_module(span, "comet_logging")
patch_module(span, "config")
patch_module(span, "LOGGER", "logger")


def _construct(
Expand Down Expand Up @@ -183,7 +187,7 @@ def test_set_output__new_metadata_is_not_None__existing_metadata_is_merged_with_
}


def test_span__no_chain_started__wont_connect_to_chain():
def test_span__no_chain_started_raising_exceptions_disabled__wont_connect_to_chain():
START_TIMESTAMP = 10
END_TIMESTAMP = 15
DURATION = 5
Expand All @@ -200,6 +204,12 @@ def test_span__no_chain_started__wont_connect_to_chain():

s.state.get_global_chain() >> None

s.config.raising_enabled() >> False

s.comet_logging.log_once_at_level(
"logger", 40, "Global chain is not initialized for this thread. Initialize it with `comet_llm.start_chain(...)` if you wish to use `Span`"
)

with span.Span(
category="llm-call",
inputs={"input": "input"},
Expand All @@ -217,4 +227,23 @@ def test_span__no_chain_started__wont_connect_to_chain():
"end_timestamp": END_TIMESTAMP,
"parent_ids": None,
"metadata": {},
}
}


def test_span__no_chain_started_raising_exceptions_enabled__exception_raised():
with Scenario() as s:
s.state.get_new_id() >> "example_id"
timer = Fake("timer")

s.datetimes.Timer() >> timer

s.state.get_global_chain() >> None

s.config.raising_enabled() >> True

with pytest.raises(CometLLMException):
with span.Span(
category="llm-call",
inputs={"input": "input"},
) as tested_span:
tested_span.set_outputs({"outputs": "outputs"})
9 changes: 0 additions & 9 deletions tests/unit/chains/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,6 @@ def test_new_id__happyflow():
assert tested.new_id() == 2


def test_chain_property_chain_was_not_set__exception_raised():
tested = _construct()

with Scenario() as s:
s.registry.get("global-chain") >> None
with pytest.raises(exceptions.CometLLMException):
tested.chain


def test_chain_exists__chain_was_not_set__returned_False():
tested = _construct()

Expand Down

0 comments on commit 2d926b3

Please sign in to comment.