Async pipeline improvements (#123)
* Add failing test

* Define a "run_id" in Orchestrator - save results per run_id

* Make unit test work

* Make intermediate results accessible from outside pipeline for investigation

* Remove unused imports

* Update examples and CHANGELOG

* Cleaning: remove deprecated code

* Fix ruff

* Fix examples

* Fix examples again

* Move status to store

* PR reviews

* Removing useless status assignment

* Remove unused import

* Move status to store

* Return RunStatus from method

* Fix bad merge

* Fix comments

* Deal with None statuses in the method dedicated to fetching status - Remove unused statuses

* Fix error message

* Update error message
stellasia authored Sep 10, 2024
1 parent 19fbace commit b65b34f
Showing 5 changed files with 165 additions and 102 deletions.
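The central change: component results and statuses are now keyed by a per-run `run_id` and kept in the pipeline's result store, so concurrent runs no longer clobber each other and intermediate outputs stay inspectable after a run completes. Below is a minimal sketch of the resulting access pattern; the pipeline assembly is elided and the "splitter" component name is illustrative, mirroring the e2e tests further down:

```python
async def inspect_run(pipeline) -> None:
    # `pipeline` is assumed to be an already-assembled Pipeline instance.
    res = await pipeline.run({"splitter": {"text": "some document"}})
    # Final results are namespaced by run_id.
    print(res.run_id, res.result)
    # Intermediate component outputs remain available for investigation:
    chunks = await pipeline.store.get_result_for_component(res.run_id, "splitter")
    print(chunks)
```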
105 changes: 48 additions & 57 deletions src/neo4j_genai/experimental/pipeline/pipeline.py
@@ -48,10 +48,7 @@

 class RunStatus(enum.Enum):
     UNKNOWN = "UNKNOWN"
-    SCHEDULED = "SCHEDULED"
-    WAITING = "WAITING"
     RUNNING = "RUNNING"
-    SKIP = "SKIP"
     DONE = "DONE"

@@ -76,37 +73,6 @@ def __init__(self, name: str, component: Component):
         """
         super().__init__(name, {})
         self.component = component
-        self.status: dict[str, RunStatus] = {}
-        self._lock = asyncio.Lock()
-        """This lock is used to make sure we're not trying
-        to update the status in //. This should prevent the task to
-        be executed multiple times because the status was not known
-        by the orchestrator.
-        """
-
-    async def set_status(self, run_id: str, status: RunStatus) -> None:
-        """Set a new status
-        Args:
-            run_id (str): Unique ID for the current pipeline run
-            status (RunStatus): New status
-        Raises:
-            PipelineStatusUpdateError if the new status is not
-                compatible with the current one.
-        """
-        async with self._lock:
-            current_status = self.status.get(run_id)
-            if status == current_status:
-                raise PipelineStatusUpdateError()
-            if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
-                # can't go back to RUNNING from DONE
-                raise PipelineStatusUpdateError()
-            self.status[run_id] = status
-
-    async def read_status(self, run_id: str) -> RunStatus:
-        async with self._lock:
-            return self.status.get(run_id, RunStatus.UNKNOWN)

     async def execute(self, **kwargs: Any) -> RunResult | None:
         """Execute the task
@@ -163,31 +129,52 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
             None
         """
         input_config = await self.get_input_config_for_task(task)
-        inputs = self.get_component_inputs(task.name, input_config, data)
+        inputs = await self.get_component_inputs(task.name, input_config, data)
         try:
-            await task.set_status(self.run_id, RunStatus.RUNNING)
+            await self.set_task_status(task.name, RunStatus.RUNNING)
         except PipelineStatusUpdateError:
-            logger.info(
-                f"Component {task.name} already running or done {task.status.get(self.run_id)}"
-            )
+            logger.info(f"Component {task.name} already running or done")
             return None
         res = await task.run(inputs)
-        await task.set_status(self.run_id, RunStatus.DONE)
+        await self.set_task_status(task.name, RunStatus.DONE)
         if res:
             await self.on_task_complete(data=data, task=task, result=res)

+    async def set_task_status(self, task_name: str, status: RunStatus) -> None:
+        """Set a new status
+        Args:
+            task_name (str): Name of the component
+            status (RunStatus): New status
+        Raises:
+            PipelineStatusUpdateError if the new status is not
+                compatible with the current one.
+        """
+        # prevent the method from being called by two concurrent async calls
+        async with asyncio.Lock():
+            current_status = await self.get_status_for_component(task_name)
+            if status == current_status:
+                raise PipelineStatusUpdateError(f"Status is already '{status}'")
+            if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
+                raise PipelineStatusUpdateError("Can't go from DONE to RUNNING")
+            return await self.pipeline.store.add_status_for_component(
+                self.run_id, task_name, status.value
+            )
+
     async def on_task_complete(
         self, data: dict[str, Any], task: TaskPipelineNode, result: RunResult
     ) -> None:
         """When a given task is complete, it will call this method
         to find the next tasks to run.
         """
-        # first call the method for the pipeline
-        # this is where the results can be saved
+        # first save this component results
         res_to_save = None
         if result.result:
             res_to_save = result.result.model_dump()
-        self.add_result_for_component(task.name, res_to_save, is_final=task.is_leaf())
+        await self.add_result_for_component(
+            task.name, res_to_save, is_final=task.is_leaf()
+        )
         # then get the next tasks to be executed
         # and run them in //
         await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)])
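The new `set_task_status` above replaces the per-node `set_status`/`read_status` pair: statuses now live in the shared store, keyed by run, with the same transition rules. A sketch of the enforced transitions; the orchestrator construction is elided and "extractor" is an illustrative component name:

```python
async def demo_status_transitions(orchestrator) -> None:
    # Happy path: a component goes UNKNOWN -> RUNNING -> DONE.
    await orchestrator.set_task_status("extractor", RunStatus.RUNNING)
    await orchestrator.set_task_status("extractor", RunStatus.DONE)
    try:
        # Re-setting the current status, or moving DONE -> RUNNING, raises.
        await orchestrator.set_task_status("extractor", RunStatus.RUNNING)
    except PipelineStatusUpdateError:
        pass  # expected: can't go back to RUNNING once DONE
```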
@@ -200,8 +187,7 @@ async def check_dependencies_complete(self, task: TaskPipelineNode) -> None:
         """
         dependencies = self.pipeline.previous_edges(task.name)
         for d in dependencies:
-            start_node = self.pipeline.get_node_by_name(d.start)
-            d_status = await start_node.read_status(self.run_id)
+            d_status = await self.get_status_for_component(d.start)
             if d_status != RunStatus.DONE:
                 logger.warning(
                     f"Missing dependency {d.start} for {task.name} (status: {d_status})"
@@ -223,7 +209,7 @@ async def next(
         for next_edge in possible_next:
             next_node = self.pipeline.get_node_by_name(next_edge.end)
             # check status
-            next_node_status = await next_node.read_status(self.run_id)
+            next_node_status = await self.get_status_for_component(next_node.name)
             if next_node_status in [RunStatus.RUNNING, RunStatus.DONE]:
                 # already running
                 continue
@@ -251,8 +237,7 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s
         # make sure dependencies are satisfied
         # and save the inputs defs that needs to be propagated from parent components
         for prev_edge in self.pipeline.previous_edges(task.name):
-            prev_node = self.pipeline.get_node_by_name(prev_edge.start)
-            prev_status = await prev_node.read_status(self.run_id)
+            prev_status = await self.get_status_for_component(prev_edge.start)
             if prev_status != RunStatus.DONE:
                 logger.critical(f"Missing dependency {prev_edge.start}")
                 raise PipelineMissingDependencyError(f"{prev_edge.start} not ready")
@@ -261,7 +246,7 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s
                 input_config.update(**prev_edge_data)
         return input_config

-    def get_component_inputs(
+    async def get_component_inputs(
         self,
         component_name: str,
         input_config: dict[str, Any],
@@ -287,7 +272,7 @@ def get_component_inputs(
             # component as input
             component = mapping
             output_param = None
-        component_result = self.get_results_for_component(component)
+        component_result = await self.get_results_for_component(component)
         if output_param is not None:
             value = component_result.get(output_param)
         else:
@@ -299,25 +284,31 @@ def get_component_inputs(
                 component_inputs[parameter] = value
         return component_inputs

-    def add_result_for_component(
+    async def add_result_for_component(
         self, name: str, result: dict[str, Any] | None, is_final: bool = False
     ) -> None:
         """This is where we save the results in the result store and, optionally,
         in the final result store.
         """
-        self.pipeline.store.add_result_for_component(self.run_id, name, result)
+        await self.pipeline.store.add_result_for_component(self.run_id, name, result)
         if is_final:
             # The pipeline only returns the results
             # of the leaf nodes
             # TODO: make this configurable in the future.
-            existing_results = self.pipeline.final_results.get(self.run_id) or {}
+            existing_results = await self.pipeline.final_results.get(self.run_id) or {}
             existing_results[name] = result
-            self.pipeline.final_results.add(
+            await self.pipeline.final_results.add(
                 self.run_id, existing_results, overwrite=True
             )

-    def get_results_for_component(self, name: str) -> Any:
-        return self.pipeline.store.get_result_for_component(self.run_id, name)
+    async def get_results_for_component(self, name: str) -> Any:
+        return await self.pipeline.store.get_result_for_component(self.run_id, name)
+
+    async def get_status_for_component(self, name: str) -> RunStatus:
+        status = await self.pipeline.store.get_status_for_component(self.run_id, name)
+        if status is None:
+            return RunStatus.UNKNOWN
+        return RunStatus(status)

     async def run(self, data: dict[str, Any]) -> None:
         """Run the pipline, starting from the root nodes
@@ -500,5 +491,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
         )
         return PipelineResult(
             run_id=orchestrator.run_id,
-            result=self.final_results.get(orchestrator.run_id),
+            result=await self.final_results.get(orchestrator.run_id),
         )
53 changes: 37 additions & 16 deletions src/neo4j_genai/experimental/pipeline/stores.py
@@ -19,14 +19,15 @@
 from __future__ import annotations

 import abc
+import asyncio
 from typing import Any


 class Store(abc.ABC):
     """An interface to save component outputs"""

     @abc.abstractmethod
-    def add(self, key: str, value: Any, overwrite: bool = True) -> None:
+    async def add(self, key: str, value: Any, overwrite: bool = True) -> None:
         """
         Args:
             key (str): The key to access the data.
@@ -41,7 +42,7 @@ def add(self, key: str, value: Any, overwrite: bool = True) -> None:
         pass

     @abc.abstractmethod
-    def get(self, key: str) -> Any:
+    async def get(self, key: str) -> Any:
         """Retrieve value for `key`.
         If key not found, returns None.
         """
@@ -62,16 +63,32 @@ def empty(self) -> None:

 class ResultStore(Store, abc.ABC):
     @staticmethod
-    def get_key(run_id: str, task_name: str) -> str:
-        return f"{run_id}:{task_name}"
+    def get_key(run_id: str, task_name: str, suffix: str = "") -> str:
+        key = f"{run_id}:{task_name}"
+        if suffix:
+            key += f":{suffix}"
+        return key
+
+    async def add_status_for_component(
+        self,
+        run_id: str,
+        task_name: str,
+        status: str,
+    ) -> None:
+        await self.add(
+            self.get_key(run_id, task_name, "status"), status, overwrite=True
+        )
+
+    async def get_status_for_component(self, run_id: str, task_name: str) -> Any:
+        return await self.get(self.get_key(run_id, task_name, "status"))

-    def add_result_for_component(
+    async def add_result_for_component(
         self, run_id: str, task_name: str, result: Any, overwrite: bool = False
     ) -> None:
-        self.add(self.get_key(run_id, task_name), result, overwrite=overwrite)
+        await self.add(self.get_key(run_id, task_name), result, overwrite=overwrite)

-    def get_result_for_component(self, run_id: str, task_name: str) -> Any:
-        return self.get(self.get_key(run_id, task_name))
+    async def get_result_for_component(self, run_id: str, task_name: str) -> Any:
+        return await self.get(self.get_key(run_id, task_name))


 class InMemoryStore(ResultStore):
@@ -80,14 +97,18 @@ class InMemoryStore(ResultStore):

     def __init__(self) -> None:
         self._data: dict[str, Any] = {}
-
-    def add(self, key: str, value: Any, overwrite: bool = True) -> None:
-        if (not overwrite) and key in self._data:
-            raise KeyError(f"{key} already exists")
-        self._data[key] = value
-
-    def get(self, key: str) -> Any:
-        return self._data.get(key)
+        self._lock = asyncio.Lock()
+        """This lock is used to prevent read while a write in ongoing and vice-versa."""
+
+    async def add(self, key: str, value: Any, overwrite: bool = True) -> None:
+        async with self._lock:
+            if (not overwrite) and key in self._data:
+                raise KeyError(f"{key} already exists")
+            self._data[key] = value
+
+    async def get(self, key: str) -> Any:
+        async with self._lock:
+            return self._data.get(key)

     def all(self) -> dict[str, Any]:
         return self._data
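With statuses folded into the store, keys are namespaced by run, task, and an optional record-type suffix ("status" for status records). A quick sketch against the in-memory implementation; the import path follows the file header above:

```python
import asyncio

from neo4j_genai.experimental.pipeline.stores import InMemoryStore, ResultStore

# Statuses reuse the per-run result key with a "status" suffix:
assert ResultStore.get_key("run-1", "splitter") == "run-1:splitter"
assert ResultStore.get_key("run-1", "splitter", "status") == "run-1:splitter:status"

async def main() -> None:
    store = InMemoryStore()
    await store.add_status_for_component("run-1", "splitter", "DONE")
    print(await store.get_status_for_component("run-1", "splitter"))  # "DONE"

asyncio.run(main())
```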
16 changes: 12 additions & 4 deletions tests/e2e/test_kg_builder_pipeline_e2e.py
@@ -261,9 +261,13 @@ async def test_pipeline_builder_happy_path(
     assert res.run_id is not None
     assert res.result == {"writer": {"status": "SUCCESS"}}
     # check component's results
-    chunks = kg_builder_pipeline.store.get_result_for_component(res.run_id, "splitter")
+    chunks = await kg_builder_pipeline.store.get_result_for_component(
+        res.run_id, "splitter"
+    )
     assert len(chunks["chunks"]) == 3
-    graph = kg_builder_pipeline.store.get_result_for_component(res.run_id, "extractor")
+    graph = await kg_builder_pipeline.store.get_result_for_component(
+        res.run_id, "extractor"
+    )
     # 3 entities + 3 chunks + 1 document
     nodes = graph["nodes"]
     assert len(nodes) == 7
@@ -463,9 +467,13 @@ async def test_pipeline_builder_failing_chunk_do_not_raise(
     assert res.run_id is not None
     assert res.result == {"writer": {"status": "SUCCESS"}}
     # check component's results
-    chunks = kg_builder_pipeline.store.get_result_for_component(res.run_id, "splitter")
+    chunks = await kg_builder_pipeline.store.get_result_for_component(
+        res.run_id, "splitter"
+    )
     assert len(chunks["chunks"]) == 3
-    graph = kg_builder_pipeline.store.get_result_for_component(res.run_id, "extractor")
+    graph = await kg_builder_pipeline.store.get_result_for_component(
+        res.run_id, "extractor"
+    )
     # 3 entities + 3 chunks
     nodes = graph["nodes"]
     assert len(nodes) == 6
