Skip to content

Commit

Permalink
OPIK-615 [SDK] DSPY integration (#940)
Browse files Browse the repository at this point in the history
* OPIK-615 [SDK] DSPY integration

* wip

* wip

* wip

* wip

* new UI representation

* new UI representation fix

* detect provider+model

* detect span type

* fix linter

* add integration tests

* handle changing spans order

* handle api_key for dspy tests

---------

Co-authored-by: Aliaksandr Kuzmik <[email protected]>
  • Loading branch information
japdubengsub and alexkuzmik authored Dec 27, 2024
1 parent f3b278a commit 7537149
Show file tree
Hide file tree
Showing 8 changed files with 786 additions and 1 deletion.
51 changes: 51 additions & 0 deletions .github/workflows/lib-dspy-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Workflow to run DSPy tests
#
# Please read inputs to provide correct values.
#
name: SDK Lib DSPy Tests
run-name: "SDK Lib DSPy Tests ${{ github.ref_name }} by @${{ github.actor }}"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
OPENAI_ORG_ID: ${{ secrets.OPENAI_ORG_ID }}
on:
workflow_call:

jobs:
tests:
name: DSPy Python ${{matrix.python_version}}
runs-on: ubuntu-latest
defaults:
run:
working-directory: sdks/python

strategy:
fail-fast: true
matrix:
python_version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- name: Check out code
uses: actions/checkout@v4

- name: Setup Python ${{matrix.python_version}}
uses: actions/setup-python@v5
with:
python-version: ${{matrix.python_version}}

- name: Install opik
run: pip install .

- name: Install test tools
run: |
cd ./tests
pip install --no-cache-dir --disable-pip-version-check -r test_requirements.txt
- name: Install lib
run: |
cd ./tests
pip install --no-cache-dir --disable-pip-version-check -r library_integration/dspy/requirements.txt
- name: Run tests
run: |
cd ./tests/library_integration/dspy/
python -m pytest -vv .
7 changes: 7 additions & 0 deletions .github/workflows/lib-integration-tests-runner.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ on:
- aisuite
- haystack
- guardrails
- dspy
schedule:
- cron: "0 0 */1 * *"
pull_request:
Expand Down Expand Up @@ -87,3 +88,9 @@ jobs:
if: contains(fromJSON('["guardrails", "all"]'), needs.init_environment.outputs.LIBS)
uses: ./.github/workflows/lib-guardrails-tests.yml
secrets: inherit

dspy_tests:
needs: [init_environment]
if: contains(fromJSON('["dspy", "all"]'), needs.init_environment.outputs.LIBS)
uses: ./.github/workflows/lib-dspy-tests.yml
secrets: inherit
Empty file.
246 changes: 246 additions & 0 deletions sdks/python/src/opik/integrations/dspy/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
from contextvars import ContextVar, Token
from typing import Any, Dict, Optional, Union

import dspy
from dspy.utils.callback import BaseCallback

from opik import opik_context
from opik.api_objects import helpers, span, trace
from opik.api_objects.opik_client import get_client_cached
from opik.decorator import error_info_collector

ContextType = Union[span.SpanData, trace.TraceData]


class OpikCallback(BaseCallback):
def __init__(
self,
project_name: Optional[str] = None,
):
self._map_call_id_to_span_data: Dict[str, span.SpanData] = {}
self._map_call_id_to_trace_data: Dict[str, trace.TraceData] = {}
self._map_span_id_or_trace_id_to_token: Dict[str, Token] = {}

self._current_callback_context: ContextVar[Optional[ContextType]] = ContextVar(
"opik_context", default=None
)

self._project_name = project_name

self._opik_client = get_client_cached()

def on_module_start(
self,
call_id: str,
instance: Any,
inputs: Dict[str, Any],
) -> None:
if current_callback_context_data := self._current_callback_context.get():
if isinstance(current_callback_context_data, span.SpanData):
self._attach_span_to_existing_span(
call_id=call_id,
current_span_data=current_callback_context_data,
instance=instance,
inputs=inputs,
)
else:
self._attach_span_to_existing_trace(
call_id=call_id,
current_trace_data=current_callback_context_data,
instance=instance,
inputs=inputs,
)
return

if current_span_data := opik_context.get_current_span_data():
self._attach_span_to_existing_span(
call_id=call_id,
current_span_data=current_span_data,
instance=instance,
inputs=inputs,
)
new_span_data = self._map_call_id_to_span_data[call_id]
self._callback_context_set(new_span_data)
return

if current_trace_data := opik_context.get_current_trace_data():
self._attach_span_to_existing_trace(
call_id=call_id,
current_trace_data=current_trace_data,
instance=instance,
inputs=inputs,
)
new_span_data = self._map_call_id_to_span_data[call_id]
self._callback_context_set(new_span_data)
return

self._start_trace(
call_id=call_id,
instance=instance,
inputs=inputs,
)

def _attach_span_to_existing_span(
self,
call_id: str,
current_span_data: span.SpanData,
instance: Any,
inputs: Dict[str, Any],
) -> None:
project_name = helpers.resolve_child_span_project_name(
parent_project_name=current_span_data.project_name,
child_project_name=self._project_name,
)
span_type = self._get_span_type(instance)

span_data = span.SpanData(
trace_id=current_span_data.trace_id,
parent_span_id=current_span_data.id,
name=instance.__class__.__name__,
input=inputs,
type=span_type,
project_name=project_name,
)
self._map_call_id_to_span_data[call_id] = span_data

def _attach_span_to_existing_trace(
self,
call_id: str,
current_trace_data: trace.TraceData,
instance: Any,
inputs: Dict[str, Any],
) -> None:
project_name = helpers.resolve_child_span_project_name(
current_trace_data.project_name,
self._project_name,
)
span_type = self._get_span_type(instance)

span_data = span.SpanData(
trace_id=current_trace_data.id,
parent_span_id=None,
name=instance.__class__.__name__,
input=inputs,
type=span_type,
project_name=project_name,
)
self._map_call_id_to_span_data[call_id] = span_data

def _start_trace(
self,
call_id: str,
instance: Any,
inputs: Dict[str, Any],
) -> None:
trace_data = trace.TraceData(
name=instance.__class__.__name__,
input=inputs,
metadata={"created_from": "dspy"},
project_name=self._project_name,
)
self._map_call_id_to_trace_data[call_id] = trace_data
self._callback_context_set(trace_data)

def on_module_end(
self,
call_id: str,
outputs: Optional[Any],
exception: Optional[Exception] = None,
) -> None:
self._end_span(
call_id=call_id,
exception=exception,
outputs=outputs,
)
self._end_trace(call_id=call_id)

def _end_trace(self, call_id: str) -> None:
if trace_data := self._map_call_id_to_trace_data.pop(call_id, None):
trace_data.init_end_time()
self._opik_client.trace(**trace_data.__dict__)

# remove trace data from context
if token := self._map_span_id_or_trace_id_to_token.pop(trace_data.id, None):
self._current_callback_context.reset(token)

def _end_span(
self,
call_id: str,
outputs: Optional[Any],
exception: Optional[Exception] = None,
) -> None:
if span_data := self._map_call_id_to_span_data.pop(call_id, None):
if exception:
error_info = error_info_collector.collect(exception)
span_data.update(error_info=error_info)

span_data.update(output={"output": outputs}).init_end_time()
self._opik_client.span(**span_data.__dict__)

# remove span data from context
if token := self._map_span_id_or_trace_id_to_token.pop(span_data.id, None):
self._current_callback_context.reset(token)

def on_lm_start(
self,
call_id: str,
instance: Any,
inputs: Dict[str, Any],
) -> None:
current_callback_context_data = self._current_callback_context.get()
assert current_callback_context_data is not None

project_name = helpers.resolve_child_span_project_name(
current_callback_context_data.project_name,
self._project_name,
)

if isinstance(current_callback_context_data, span.SpanData):
trace_id = current_callback_context_data.trace_id
parent_span_id = current_callback_context_data.id
else:
trace_id = current_callback_context_data.id
parent_span_id = None

provider, model = instance.model.split(r"/", 1)
span_type = self._get_span_type(instance)

span_data = span.SpanData(
trace_id=trace_id,
name=instance.__class__.__name__,
parent_span_id=parent_span_id,
type=span_type,
input=inputs,
project_name=project_name,
provider=provider,
model=model,
)
self._map_call_id_to_span_data[call_id] = span_data

def on_lm_end(
self,
call_id: str,
outputs: Optional[Dict[str, Any]],
exception: Optional[Exception] = None,
) -> None:
self._end_span(
call_id=call_id,
exception=exception,
outputs=outputs,
)
self._end_trace(call_id=call_id)

def flush(self) -> None:
"""Sends pending Opik data to the backend"""
self._opik_client.flush()

def _callback_context_set(self, value: ContextType) -> None:
token = self._current_callback_context.set(value)
self._map_span_id_or_trace_id_to_token[value.id] = token

def _get_span_type(self, instance: Any) -> span.SpanType:
if isinstance(instance, dspy.Predict):
return "llm"
elif isinstance(instance, dspy.LM):
return "llm"
return "general"
2 changes: 1 addition & 1 deletion sdks/python/src/opik/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class FeedbackScoreDict(TypedDict):

class ErrorInfoDict(TypedDict):
"""
A TypedDict representing the information about the error occured.
A TypedDict representing the information about the error occurred.
"""

exception_type: str
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
dspy
Loading

0 comments on commit 7537149

Please sign in to comment.