diff --git a/sdks/python/src/opik/decorator/base_track_decorator.py b/sdks/python/src/opik/decorator/base_track_decorator.py index 5a759df085..6682bdfe2c 100644 --- a/sdks/python/src/opik/decorator/base_track_decorator.py +++ b/sdks/python/src/opik/decorator/base_track_decorator.py @@ -41,6 +41,7 @@ def track( capture_input: bool = True, capture_output: bool = True, generations_aggregator: Optional[Callable[[List[Any]], Any]] = None, + flush: bool = False, ) -> Union[Callable, Callable[[Callable], Callable]]: """ Decorator to track the execution of a function. @@ -81,6 +82,7 @@ def track( capture_input=capture_input, capture_output=capture_output, generations_aggregator=generations_aggregator, + flush=flush, ) def decorator(func: Callable) -> Callable: @@ -93,6 +95,7 @@ def decorator(func: Callable) -> Callable: capture_input=capture_input, capture_output=capture_output, generations_aggregator=generations_aggregator, + flush=flush, ) return decorator @@ -107,6 +110,7 @@ def _decorate( capture_input: bool, capture_output: bool, generations_aggregator: Optional[Callable[[List[Any]], Any]], + flush: bool, ) -> Callable: if not inspect_helpers.is_async(func): return self._tracked_sync( @@ -118,6 +122,7 @@ def _decorate( capture_input=capture_input, capture_output=capture_output, generations_aggregator=generations_aggregator, + flush=flush, ) return self._tracked_async( @@ -129,6 +134,7 @@ def _decorate( capture_input=capture_input, capture_output=capture_output, generations_aggregator=generations_aggregator, + flush=flush, ) def _tracked_sync( @@ -141,6 +147,7 @@ def _tracked_sync( capture_input: bool, capture_output: bool, generations_aggregator: Optional[Callable[[List[Any]], str]], + flush: bool, ) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs) -> Any: # type: ignore @@ -180,6 +187,9 @@ def wrapper(*args, **kwargs) -> Any: # type: ignore output=result, capture_output=capture_output, ) + if flush: + opik_ = opik_client.get_client_cached() + opik_.flush() if result is not None: return result @@ -195,6 +205,7 @@ def _tracked_async( capture_input: bool, capture_output: bool, generations_aggregator: Optional[Callable[[List[Any]], str]], + flush: bool, ) -> Callable: @functools.wraps(func) async def wrapper(*args, **kwargs) -> Any: # type: ignore @@ -233,6 +244,9 @@ async def wrapper(*args, **kwargs) -> Any: # type: ignore output=result, capture_output=capture_output, ) + if flush: + opik_ = opik_client.get_client_cached() + opik_.flush() if result is not None: return result