Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OPIK-615 [SDK] DSPY integration #940

Merged
merged 15 commits into from
Dec 27, 2024
Empty file.
99 changes: 99 additions & 0 deletions sdks/python/src/opik/integrations/dspy/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Any, Dict, Optional

from dspy.utils.callback import BaseCallback

from opik import context_storage, opik_context
from opik.api_objects import opik_client, span, trace
from opik.decorator import error_info_collector


class OpikCallback(BaseCallback):

def __init__(
self,
project_name: Optional[str] = None,
):
self._opik_trace_data: Optional[trace.TraceData] = None
japdubengsub marked this conversation as resolved.
Show resolved Hide resolved
self._project_name = project_name
self._opik_client = opik_client.Opik(
_use_batching=True,
project_name=project_name,
)

def on_module_start(
self,
call_id: str,
instance: Any,
inputs: Dict[str, Any],
):
print(f"on_module_start() is called with call_id: {call_id}, instance: {instance.__class__.__name__}, inputs: {inputs}")

if self._opik_trace_data is not None:
return

existing_trace_data = opik_context.get_current_trace_data()
if existing_trace_data:
self._opik_trace_data = existing_trace_data
japdubengsub marked this conversation as resolved.
Show resolved Hide resolved
else:
trace_data = trace.TraceData(
name=instance.__class__.__name__,
metadata={"created_from": "dspy"},
project_name=self._project_name,
)
self._opik_trace_data = trace_data

def on_module_end(
self,
call_id: str,
outputs: Optional[Any],
exception: Optional[Exception] = None,
):
print(f"on_module_end() is called with call_id: {call_id}, outputs: {outputs}, exception: {exception}")

if self._opik_trace_data is None:
return

if exception:
error_info = error_info_collector.collect(exception)
self._opik_trace_data.update(error_info=error_info)

self._opik_trace_data.init_end_time()
self._opik_client.trace(**self._opik_trace_data.__dict__)
self._opik_trace_data = None

def on_lm_start(
self,
call_id: str,
instance: Any,
inputs: Dict[str, Any],
):
print(f"LM is called with inputs: {inputs}")

span_data = span.SpanData(
trace_id=self._opik_trace_data.id,
name=instance.__class__.__name__,
parent_span_id=None,
type="llm",
input=inputs,
project_name=self._opik_trace_data.project_name,
)
context_storage.add_span_data(span_data)
japdubengsub marked this conversation as resolved.
Show resolved Hide resolved


def on_lm_end(
self,
call_id: str,
outputs: Optional[Dict[str, Any]],
exception: Optional[Exception] = None,
):
print(f"LM is finished with outputs: {outputs}")

span_data = context_storage.pop_span_data()

if exception:
error_info = error_info_collector.collect(exception)
span_data.update(error_info=error_info)

span_data.update(output={"output": outputs}).init_end_time()

self._opik_client.span(**span_data.__dict__)
2 changes: 1 addition & 1 deletion sdks/python/src/opik/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class FeedbackScoreDict(TypedDict):

class ErrorInfoDict(TypedDict):
"""
A TypedDict representing the information about the error occured.
A TypedDict representing the information about the error occurred.
"""

exception_type: str
Expand Down
Loading