Skip to content

Commit

Permalink
Add @track(flush=True)
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 13, 2024
1 parent 501aacd commit b209cf1
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions sdks/python/src/opik/decorator/base_track_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def track(
capture_input: bool = True,
capture_output: bool = True,
generations_aggregator: Optional[Callable[[List[Any]], Any]] = None,
flush: bool = False,
) -> Union[Callable, Callable[[Callable], Callable]]:
"""
Decorator to track the execution of a function.
Expand Down Expand Up @@ -81,6 +82,7 @@ def track(
capture_input=capture_input,
capture_output=capture_output,
generations_aggregator=generations_aggregator,
flush=flush,
)

def decorator(func: Callable) -> Callable:
Expand All @@ -93,6 +95,7 @@ def decorator(func: Callable) -> Callable:
capture_input=capture_input,
capture_output=capture_output,
generations_aggregator=generations_aggregator,
flush=flush,
)

return decorator
Expand All @@ -107,6 +110,7 @@ def _decorate(
capture_input: bool,
capture_output: bool,
generations_aggregator: Optional[Callable[[List[Any]], Any]],
flush: bool,
) -> Callable:
if not inspect_helpers.is_async(func):
return self._tracked_sync(
Expand All @@ -118,6 +122,7 @@ def _decorate(
capture_input=capture_input,
capture_output=capture_output,
generations_aggregator=generations_aggregator,
flush=flush,
)

return self._tracked_async(
Expand All @@ -129,6 +134,7 @@ def _decorate(
capture_input=capture_input,
capture_output=capture_output,
generations_aggregator=generations_aggregator,
flush=flush,
)

def _tracked_sync(
Expand All @@ -141,6 +147,7 @@ def _tracked_sync(
capture_input: bool,
capture_output: bool,
generations_aggregator: Optional[Callable[[List[Any]], str]],
flush: bool,
) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any: # type: ignore
Expand Down Expand Up @@ -180,6 +187,9 @@ def wrapper(*args, **kwargs) -> Any: # type: ignore
output=result,
capture_output=capture_output,
)
if flush:
opik_ = opik_client.get_client_cached()
opik_.flush()
if result is not None:
return result

Expand All @@ -195,6 +205,7 @@ def _tracked_async(
capture_input: bool,
capture_output: bool,
generations_aggregator: Optional[Callable[[List[Any]], str]],
flush: bool,
) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs) -> Any: # type: ignore
Expand Down Expand Up @@ -233,6 +244,9 @@ async def wrapper(*args, **kwargs) -> Any: # type: ignore
output=result,
capture_output=capture_output,
)
if flush:
opik_ = opik_client.get_client_cached()
opik_.flush()
if result is not None:
return result

Expand Down

0 comments on commit b209cf1

Please sign in to comment.