diff --git a/sdks/python/src/opik/api_objects/span.py b/sdks/python/src/opik/api_objects/span.py index ec75b15012..99abcfd125 100644 --- a/sdks/python/src/opik/api_objects/span.py +++ b/sdks/python/src/opik/api_objects/span.py @@ -271,6 +271,12 @@ def update(self, **new_data: Any) -> "SpanData": if key == "metadata": self._update_metadata(value) continue + elif key == "output": + self._update_output(value) + continue + elif key == "input": + self._update_input(value) + continue self.__dict__[key] = value @@ -282,6 +288,18 @@ def _update_metadata(self, new_metadata: Dict[str, Any]) -> None: else: self.metadata = dict_utils.deepmerge(self.metadata, new_metadata) + def _update_output(self, new_output: Dict[str, Any]) -> None: + if self.output is None: + self.output = new_output + else: + self.output = dict_utils.deepmerge(self.output, new_output) + + def _update_input(self, new_input: Dict[str, Any]) -> None: + if self.input is None: + self.input = new_input + else: + self.input = dict_utils.deepmerge(self.input, new_input) + def init_end_time(self) -> "SpanData": self.end_time = datetime_helpers.local_timestamp() diff --git a/sdks/python/src/opik/api_objects/trace.py b/sdks/python/src/opik/api_objects/trace.py index 3a80c210c5..37bb8701df 100644 --- a/sdks/python/src/opik/api_objects/trace.py +++ b/sdks/python/src/opik/api_objects/trace.py @@ -242,6 +242,12 @@ def update(self, **new_data: Any) -> "TraceData": if key == "metadata": self._update_metadata(value) continue + elif key == "output": + self._update_output(value) + continue + elif key == "input": + self._update_input(value) + continue self.__dict__[key] = value @@ -253,6 +259,18 @@ def _update_metadata(self, new_metadata: Dict[str, Any]) -> None: else: self.metadata = dict_utils.deepmerge(self.metadata, new_metadata) + def _update_output(self, new_output: Dict[str, Any]) -> None: + if self.output is None: + self.output = new_output + else: + self.output = dict_utils.deepmerge(self.output, new_output) + + def _update_input(self, new_input: Dict[str, Any]) -> None: + if self.input is None: + self.input = new_input + else: + self.input = dict_utils.deepmerge(self.input, new_input) + def init_end_time(self) -> "TraceData": self.end_time = datetime_helpers.local_timestamp() return self diff --git a/sdks/python/src/opik/decorator/base_track_decorator.py b/sdks/python/src/opik/decorator/base_track_decorator.py index 814e240c2c..f1d551366d 100644 --- a/sdks/python/src/opik/decorator/base_track_decorator.py +++ b/sdks/python/src/opik/decorator/base_track_decorator.py @@ -399,7 +399,7 @@ def _after_call( ) client = opik_client.get_client_cached() - + print("span_data_to_end", span_data_to_end) span_data_to_end.init_end_time().update( **end_arguments.to_kwargs(), ) diff --git a/sdks/python/tests/unit/decorator/test_tracker_outputs.py b/sdks/python/tests/unit/decorator/test_tracker_outputs.py index aa542b1a62..3889fdb7c2 100644 --- a/sdks/python/tests/unit/decorator/test_tracker_outputs.py +++ b/sdks/python/tests/unit/decorator/test_tracker_outputs.py @@ -936,6 +936,44 @@ def f(x): assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) +def test_track__span_and_trace_output_updated_via_opik_context(fake_backend): + @tracker.track + def f(x): + opik_context.update_current_span( + output={"span-output-key": "span-output-value"}, + ) + opik_context.update_current_trace( + output={"trace-output-key": "trace-output-value"}, + ) + + return "f-output" + + f("f-input") + tracker.flush_tracker() + + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="f", + input={"x": "f-input"}, + output={"output": "f-output", "trace-output-key": "trace-output-value"}, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="f", + input={"x": "f-input"}, + output={"output": "f-output", "span-output-key": "span-output-value"}, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[], + ) + ], + ) + + assert len(fake_backend.trace_trees) == 1 + + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) def test_track__span_and_trace_updated_via_opik_context_with_feedback_scores__feedback_scores_are_also_logged( fake_backend, @@ -989,6 +1027,7 @@ def f(x): assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) + def test_tracker__ignore_list_was_passed__ignored_inputs_are_not_logged(fake_backend): @tracker.track(ignore_arguments=["a", "c", "e", "unknown_argument"]) def f(a, b, c=3, d=4, e=5):