diff --git a/sdks/python/examples/decorators.py b/sdks/python/examples/decorators.py index f7d5fcac46..6da4484d97 100644 --- a/sdks/python/examples/decorators.py +++ b/sdks/python/examples/decorators.py @@ -5,8 +5,7 @@ @track() def f3(x): # creates span3 attached to trace1 with parent span2 - span = opik_context.get_current_span() - span.update(tags=["f3-tag"]) + opik_context.update_current_span(tags=["tag-f3"]) print("Done f3") return "f3 output" diff --git a/sdks/python/examples/distributed_tracing_example.py b/sdks/python/examples/distributed_tracing_example.py index 4a3ff0b978..1e1d58ecf4 100644 --- a/sdks/python/examples/distributed_tracing_example.py +++ b/sdks/python/examples/distributed_tracing_example.py @@ -1,7 +1,7 @@ import threading import time from opik import track, flush_tracker -from opik.opik_context import get_current_span +from opik.opik_context import get_distributed_trace_headers @track() @@ -16,8 +16,7 @@ def remote_node(x, opik_headers): @track() def local_function(x): - current_span = get_current_span() - opik_headers = current_span.get_distributed_trace_headers() + opik_headers = get_distributed_trace_headers() t1 = threading.Thread( target=remote_node, args=("remote-function-input", opik_headers) diff --git a/sdks/python/examples/manual_chain_building.py b/sdks/python/examples/manual_chain_building.py index aa171cf9e9..0a7d1aedc6 100644 --- a/sdks/python/examples/manual_chain_building.py +++ b/sdks/python/examples/manual_chain_building.py @@ -4,7 +4,7 @@ os.environ["OPIK_URL_OVERRIDE"] = "http://localhost:5173/api" client = opik.Opik() -trace = client.trace() +trace = client.trace(name="trace-1") span1 = trace.span(name="span-1") span2 = span1.span(name="span-2") span2.end() diff --git a/sdks/python/examples/threaded_decorators.py b/sdks/python/examples/threaded_decorators.py index e10e5f3543..01bc06c2f4 100644 --- a/sdks/python/examples/threaded_decorators.py +++ b/sdks/python/examples/threaded_decorators.py @@ -8,8 +8,7 @@ def f3(x, thread_name): # creates trace1 with span2_1 in thread 1 # creates trace2 with span2_2 in thread 2 print(f"Done f3 from {thread_name}") - span = opik_context.get_current_span() - span.update(tags=[f"f3-thread-{thread_name}"]) + opik_context.update_current_span(tags=[f"f3-thread-{thread_name}"]) return f"f3 output from {thread_name}" diff --git a/sdks/python/src/opik/api_objects/opik_client.py b/sdks/python/src/opik/api_objects/opik_client.py index ecd578bc07..9c73560bf4 100644 --- a/sdks/python/src/opik/api_objects/opik_client.py +++ b/sdks/python/src/opik/api_objects/opik_client.py @@ -378,9 +378,23 @@ def flush(self, timeout: Optional[int] = None) -> None: self._streamer.flush(timeout) def get_trace_content(self, id: str) -> trace_public.TracePublic: + """ + Args: + id (str): trace id + Returns: + trace_public.TracePublic: pydantic model object with all the data associated with the trace found. + Raises an error if trace was not found. + """ return self._rest_client.traces.get_trace_by_id(id) def get_span_content(self, id: str) -> span_public.SpanPublic: + """ + Args: + id (str): span id + Returns: + span_public.SpanPublic: pydantic model object with all the data associated with the span found. + Raises an error if span was not found. + """ return self._rest_client.spans.get_span_by_id(id) diff --git a/sdks/python/src/opik/api_objects/span.py b/sdks/python/src/opik/api_objects/span.py index df2ddee6c7..8fda0d3ece 100644 --- a/sdks/python/src/opik/api_objects/span.py +++ b/sdks/python/src/opik/api_objects/span.py @@ -204,4 +204,8 @@ def log_feedback_score( self._streamer.put(add_span_feedback_batch_message) def get_distributed_trace_headers(self) -> DistributedTraceHeadersDict: + """ + Returns headers dictionary to be passed into tracked + function on remote node. + """ return {"opik_parent_span_id": self.id, "opik_trace_id": self.trace_id} diff --git a/sdks/python/src/opik/opik_context.py b/sdks/python/src/opik/opik_context.py index 2744b4713c..756a4b479e 100644 --- a/sdks/python/src/opik/opik_context.py +++ b/sdks/python/src/opik/opik_context.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Dict, List, Any +from opik.types import UsageDict, DistributedTraceHeadersDict -from . import context_storage +from . import context_storage, dict_utils, exceptions if TYPE_CHECKING: from .api_objects import trace, span @@ -25,7 +26,75 @@ def get_current_trace() -> Optional["trace.Trace"]: return context_storage.get_trace() +def get_distributed_trace_headers() -> DistributedTraceHeadersDict: + """ + Returns headers dictionary to be passed into tracked + function on remote node. + """ + current_span = context_storage.top_span() + if current_span is None: + raise exceptions.OpikException( + "There is no span in the context to get distributed trace headers from." + ) + + return current_span.get_distributed_trace_headers() + + +def update_current_span( + input: Optional[Dict[str, Any]] = None, + output: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + tags: Optional[List[str]] = None, + usage: Optional[UsageDict] = None, +) -> None: + """ + Updates current span created by track() decorator or raises exception if no span was found. + Context-wise. + """ + new_params = dict_utils.remove_none_from_dict( + { + "input": input, + "output": output, + "metadata": metadata, + "tags": tags, + "usage": usage, + } + ) + current_span = context_storage.top_span() + if current_span is None: + raise exceptions.OpikException("There is no span in the context.") + + current_span.update(**new_params) + + +def update_current_trace( + input: Optional[Dict[str, Any]] = None, + output: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + tags: Optional[List[str]] = None, +) -> None: + """ + Updates current trace created by track() decorator or raises exception if no trace was found. + Context-wise. + """ + + new_params = dict_utils.remove_none_from_dict( + { + "input": input, + "output": output, + "metadata": metadata, + "tags": tags, + } + ) + current_trace = context_storage.get_trace() + + if current_trace is None: + raise exceptions.OpikException("There is no trace in the context.") + + current_trace.update(**new_params) + + __all__ = [ - "get_current_span", - "get_current_trace", + "update_current_span", + "update_current_trace", ] diff --git a/sdks/python/tests/e2e/conftest.py b/sdks/python/tests/e2e/conftest.py index ebdbe52262..37cc21d998 100644 --- a/sdks/python/tests/e2e/conftest.py +++ b/sdks/python/tests/e2e/conftest.py @@ -15,7 +15,7 @@ def _random_chars(n: int = 6) -> str: @pytest.fixture(scope="session") def configure_e2e_tests_env(): os.environ["OPIK_PROJECT_NAME"] = "e2e-tests" - # os.environ["OPIK_URL_OVERRIDE"] = "http://localhost:5173/api" + os.environ["OPIK_URL_OVERRIDE"] = "http://localhost:5173/api" @pytest.fixture() diff --git a/sdks/python/tests/unit/decorator/test_tracker_outputs.py b/sdks/python/tests/unit/decorator/test_tracker_outputs.py index 7563d716ad..d2cf721879 100644 --- a/sdks/python/tests/unit/decorator/test_tracker_outputs.py +++ b/sdks/python/tests/unit/decorator/test_tracker_outputs.py @@ -1036,8 +1036,7 @@ def distributed_node_runner(y, thread_id, opik_headers): @tracker.track(capture_output=True) def f_outer(x): - current_span = opik_context.get_current_span() - distributed_trace_headers = current_span.get_distributed_trace_headers() + distributed_trace_headers = opik_context.get_distributed_trace_headers() t1 = threading.Thread( target=distributed_node_runner, args=("remote-input-1", "thread-1", distributed_trace_headers),