Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alexkuzmik/update opik context api #203

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions sdks/python/examples/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
@track()
def f3(x):
# creates span3 attached to trace1 with parent span2
span = opik_context.get_current_span()
span.update(tags=["f3-tag"])
opik_context.update_current_span(tags=["tag-f3"])
print("Done f3")
return "f3 output"

Expand Down
5 changes: 2 additions & 3 deletions sdks/python/examples/distributed_tracing_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import threading
import time
from opik import track, flush_tracker
from opik.opik_context import get_current_span
from opik.opik_context import get_distributed_trace_headers


@track()
Expand All @@ -16,8 +16,7 @@ def remote_node(x, opik_headers):

@track()
def local_function(x):
current_span = get_current_span()
opik_headers = current_span.get_distributed_trace_headers()
opik_headers = get_distributed_trace_headers()

t1 = threading.Thread(
target=remote_node, args=("remote-function-input", opik_headers)
Expand Down
2 changes: 1 addition & 1 deletion sdks/python/examples/manual_chain_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
os.environ["OPIK_URL_OVERRIDE"] = "http://localhost:5173/api"
client = opik.Opik()

trace = client.trace()
trace = client.trace(name="trace-1")
span1 = trace.span(name="span-1")
span2 = span1.span(name="span-2")
span2.end()
Expand Down
3 changes: 1 addition & 2 deletions sdks/python/examples/threaded_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ def f3(x, thread_name):
# creates trace1 with span2_1 in thread 1
# creates trace2 with span2_2 in thread 2
print(f"Done f3 from {thread_name}")
span = opik_context.get_current_span()
span.update(tags=[f"f3-thread-{thread_name}"])
opik_context.update_current_span(tags=[f"f3-thread-{thread_name}"])
return f"f3 output from {thread_name}"


Expand Down
14 changes: 14 additions & 0 deletions sdks/python/src/opik/api_objects/opik_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,23 @@ def flush(self, timeout: Optional[int] = None) -> None:
self._streamer.flush(timeout)

def get_trace_content(self, id: str) -> trace_public.TracePublic:
"""
Args:
id (str): trace id
Returns:
trace_public.TracePublic: pydantic model object with all the data associated with the trace found.
Raises an error if trace was not found.
"""
return self._rest_client.traces.get_trace_by_id(id)

def get_span_content(self, id: str) -> span_public.SpanPublic:
"""
Args:
id (str): span id
Returns:
span_public.SpanPublic: pydantic model object with all the data associated with the span found.
Raises an error if span was not found.
"""
return self._rest_client.spans.get_span_by_id(id)


Expand Down
4 changes: 4 additions & 0 deletions sdks/python/src/opik/api_objects/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,8 @@ def log_feedback_score(
self._streamer.put(add_span_feedback_batch_message)

def get_distributed_trace_headers(self) -> DistributedTraceHeadersDict:
"""
Returns headers dictionary to be passed into tracked
function on remote node.
"""
return {"opik_parent_span_id": self.id, "opik_trace_id": self.trace_id}
77 changes: 73 additions & 4 deletions sdks/python/src/opik/opik_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Dict, List, Any
from opik.types import UsageDict, DistributedTraceHeadersDict

from . import context_storage
from . import context_storage, dict_utils, exceptions

if TYPE_CHECKING:
from .api_objects import trace, span
Expand All @@ -25,7 +26,75 @@ def get_current_trace() -> Optional["trace.Trace"]:
return context_storage.get_trace()


def get_distributed_trace_headers() -> DistributedTraceHeadersDict:
"""
Returns headers dictionary to be passed into tracked
function on remote node.
"""
current_span = context_storage.top_span()
if current_span is None:
raise exceptions.OpikException(
"There is no span in the context to get distributed trace headers from."
)

return current_span.get_distributed_trace_headers()


def update_current_span(
input: Optional[Dict[str, Any]] = None,
output: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[List[str]] = None,
usage: Optional[UsageDict] = None,
) -> None:
"""
Updates current span created by track() decorator or raises exception if no span was found.
Context-wise.
"""
new_params = dict_utils.remove_none_from_dict(
{
"input": input,
"output": output,
"metadata": metadata,
"tags": tags,
"usage": usage,
}
)
current_span = context_storage.top_span()
if current_span is None:
raise exceptions.OpikException("There is no span in the context.")

current_span.update(**new_params)


def update_current_trace(
input: Optional[Dict[str, Any]] = None,
output: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[List[str]] = None,
) -> None:
"""
Updates current trace created by track() decorator or raises exception if no trace was found.
Context-wise.
"""

new_params = dict_utils.remove_none_from_dict(
{
"input": input,
"output": output,
"metadata": metadata,
"tags": tags,
}
)
current_trace = context_storage.get_trace()

if current_trace is None:
raise exceptions.OpikException("There is no trace in the context.")

current_trace.update(**new_params)


__all__ = [
"get_current_span",
"get_current_trace",
"update_current_span",
"update_current_trace",
]
2 changes: 1 addition & 1 deletion sdks/python/tests/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _random_chars(n: int = 6) -> str:
@pytest.fixture(scope="session")
def configure_e2e_tests_env():
os.environ["OPIK_PROJECT_NAME"] = "e2e-tests"
# os.environ["OPIK_URL_OVERRIDE"] = "http://localhost:5173/api"
os.environ["OPIK_URL_OVERRIDE"] = "http://localhost:5173/api"


@pytest.fixture()
Expand Down
3 changes: 1 addition & 2 deletions sdks/python/tests/unit/decorator/test_tracker_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,8 +1036,7 @@ def distributed_node_runner(y, thread_id, opik_headers):

@tracker.track(capture_output=True)
def f_outer(x):
current_span = opik_context.get_current_span()
distributed_trace_headers = current_span.get_distributed_trace_headers()
distributed_trace_headers = opik_context.get_distributed_trace_headers()
t1 = threading.Thread(
target=distributed_node_runner,
args=("remote-input-1", "thread-1", distributed_trace_headers),
Expand Down
Loading