diff --git a/src/neo4j_genai/experimental/pipeline/pipeline.py b/src/neo4j_genai/experimental/pipeline/pipeline.py
index 7c95a30d..1955caea 100644
--- a/src/neo4j_genai/experimental/pipeline/pipeline.py
+++ b/src/neo4j_genai/experimental/pipeline/pipeline.py
@@ -48,10 +48,7 @@ class RunStatus(enum.Enum):
     UNKNOWN = "UNKNOWN"
-    SCHEDULED = "SCHEDULED"
-    WAITING = "WAITING"
     RUNNING = "RUNNING"
-    SKIP = "SKIP"
     DONE = "DONE"
@@ -76,37 +73,6 @@ def __init__(self, name: str, component: Component):
         """
         super().__init__(name, {})
         self.component = component
-        self.status: dict[str, RunStatus] = {}
-        self._lock = asyncio.Lock()
-        """This lock is used to make sure we're not trying
-        to update the status in //. This should prevent the task to
-        be executed multiple times because the status was not known
-        by the orchestrator.
-        """
-
-    async def set_status(self, run_id: str, status: RunStatus) -> None:
-        """Set a new status
-
-        Args:
-            run_id (str): Unique ID for the current pipeline run
-            status (RunStatus): New status
-
-        Raises:
-            PipelineStatusUpdateError if the new status is not
-            compatible with the current one.
-        """
-        async with self._lock:
-            current_status = self.status.get(run_id)
-            if status == current_status:
-                raise PipelineStatusUpdateError()
-            if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
-                # can't go back to RUNNING from DONE
-                raise PipelineStatusUpdateError()
-            self.status[run_id] = status
-
-    async def read_status(self, run_id: str) -> RunStatus:
-        async with self._lock:
-            return self.status.get(run_id, RunStatus.UNKNOWN)

     async def execute(self, **kwargs: Any) -> RunResult | None:
         """Execute the task
@@ -163,31 +129,52 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
             None
         """
         input_config = await self.get_input_config_for_task(task)
-        inputs = self.get_component_inputs(task.name, input_config, data)
+        inputs = await self.get_component_inputs(task.name, input_config, data)
         try:
-            await task.set_status(self.run_id, RunStatus.RUNNING)
+            await self.set_task_status(task.name, RunStatus.RUNNING)
         except PipelineStatusUpdateError:
-            logger.info(
-                f"Component {task.name} already running or done {task.status.get(self.run_id)}"
-            )
+            logger.info(f"Component {task.name} already running or done")
             return None
         res = await task.run(inputs)
-        await task.set_status(self.run_id, RunStatus.DONE)
+        await self.set_task_status(task.name, RunStatus.DONE)
         if res:
             await self.on_task_complete(data=data, task=task, result=res)

+    async def set_task_status(self, task_name: str, status: RunStatus) -> None:
+        """Set a new status
+
+        Args:
+            task_name (str): Name of the component
+            status (RunStatus): New status
+
+        Raises:
+            PipelineStatusUpdateError if the new status is not
+            compatible with the current one.
+        """
+        # prevent the method from being called by two concurrent async calls
+        async with asyncio.Lock():
+            current_status = await self.get_status_for_component(task_name)
+            if status == current_status:
+                raise PipelineStatusUpdateError(f"Status is already '{status}'")
+            if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
+                raise PipelineStatusUpdateError("Can't go from DONE to RUNNING")
+            return await self.pipeline.store.add_status_for_component(
+                self.run_id, task_name, status.value
+            )
+
""" - # first call the method for the pipeline - # this is where the results can be saved + # first save this component results res_to_save = None if result.result: res_to_save = result.result.model_dump() - self.add_result_for_component(task.name, res_to_save, is_final=task.is_leaf()) + await self.add_result_for_component( + task.name, res_to_save, is_final=task.is_leaf() + ) # then get the next tasks to be executed # and run them in // await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)]) @@ -200,8 +187,7 @@ async def check_dependencies_complete(self, task: TaskPipelineNode) -> None: """ dependencies = self.pipeline.previous_edges(task.name) for d in dependencies: - start_node = self.pipeline.get_node_by_name(d.start) - d_status = await start_node.read_status(self.run_id) + d_status = await self.get_status_for_component(d.start) if d_status != RunStatus.DONE: logger.warning( f"Missing dependency {d.start} for {task.name} (status: {d_status})" @@ -223,7 +209,7 @@ async def next( for next_edge in possible_next: next_node = self.pipeline.get_node_by_name(next_edge.end) # check status - next_node_status = await next_node.read_status(self.run_id) + next_node_status = await self.get_status_for_component(next_node.name) if next_node_status in [RunStatus.RUNNING, RunStatus.DONE]: # already running continue @@ -251,8 +237,7 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s # make sure dependencies are satisfied # and save the inputs defs that needs to be propagated from parent components for prev_edge in self.pipeline.previous_edges(task.name): - prev_node = self.pipeline.get_node_by_name(prev_edge.start) - prev_status = await prev_node.read_status(self.run_id) + prev_status = await self.get_status_for_component(prev_edge.start) if prev_status != RunStatus.DONE: logger.critical(f"Missing dependency {prev_edge.start}") raise PipelineMissingDependencyError(f"{prev_edge.start} not ready") @@ -261,7 +246,7 @@ async def get_input_config_for_task(self, task: TaskPipelineNode) -> dict[str, s input_config.update(**prev_edge_data) return input_config - def get_component_inputs( + async def get_component_inputs( self, component_name: str, input_config: dict[str, Any], @@ -287,7 +272,7 @@ def get_component_inputs( # component as input component = mapping output_param = None - component_result = self.get_results_for_component(component) + component_result = await self.get_results_for_component(component) if output_param is not None: value = component_result.get(output_param) else: @@ -299,25 +284,31 @@ def get_component_inputs( component_inputs[parameter] = value return component_inputs - def add_result_for_component( + async def add_result_for_component( self, name: str, result: dict[str, Any] | None, is_final: bool = False ) -> None: """This is where we save the results in the result store and, optionally, in the final result store. """ - self.pipeline.store.add_result_for_component(self.run_id, name, result) + await self.pipeline.store.add_result_for_component(self.run_id, name, result) if is_final: # The pipeline only returns the results # of the leaf nodes # TODO: make this configurable in the future. 
-            existing_results = self.pipeline.final_results.get(self.run_id) or {}
+            existing_results = await self.pipeline.final_results.get(self.run_id) or {}
             existing_results[name] = result
-            self.pipeline.final_results.add(
+            await self.pipeline.final_results.add(
                 self.run_id, existing_results, overwrite=True
             )

-    def get_results_for_component(self, name: str) -> Any:
-        return self.pipeline.store.get_result_for_component(self.run_id, name)
+    async def get_results_for_component(self, name: str) -> Any:
+        return await self.pipeline.store.get_result_for_component(self.run_id, name)
+
+    async def get_status_for_component(self, name: str) -> RunStatus:
+        status = await self.pipeline.store.get_status_for_component(self.run_id, name)
+        if status is None:
+            return RunStatus.UNKNOWN
+        return RunStatus(status)

     async def run(self, data: dict[str, Any]) -> None:
         """Run the pipeline, starting from the root nodes
@@ -500,5 +491,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
         )
         return PipelineResult(
             run_id=orchestrator.run_id,
-            result=self.final_results.get(orchestrator.run_id),
+            result=await self.final_results.get(orchestrator.run_id),
         )
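For reference, a minimal sketch of the status rules that set_task_status() now enforces, exercised directly against the store API introduced in the next file. This is not part of the patch: it assumes the neo4j_genai package from this diff is importable, and the "run-1"/"splitter" identifiers are made up.

import asyncio

from neo4j_genai.experimental.pipeline.pipeline import RunStatus
from neo4j_genai.experimental.pipeline.stores import InMemoryStore


async def main() -> None:
    store = InMemoryStore()
    run_id, task = "run-1", "splitter"  # hypothetical identifiers

    # nothing stored yet: the orchestrator maps None back to RunStatus.UNKNOWN
    assert await store.get_status_for_component(run_id, task) is None

    # statuses are persisted as plain strings (status.value in the diff above)
    await store.add_status_for_component(run_id, task, RunStatus.RUNNING.value)
    await store.add_status_for_component(run_id, task, RunStatus.DONE.value)

    current = RunStatus(await store.get_status_for_component(run_id, task))
    # set_task_status() would now reject both a repeated DONE and a move
    # back to RUNNING, raising PipelineStatusUpdateError in either case
    assert current is RunStatus.DONE


asyncio.run(main())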
""" @@ -62,16 +63,32 @@ def empty(self) -> None: class ResultStore(Store, abc.ABC): @staticmethod - def get_key(run_id: str, task_name: str) -> str: - return f"{run_id}:{task_name}" + def get_key(run_id: str, task_name: str, suffix: str = "") -> str: + key = f"{run_id}:{task_name}" + if suffix: + key += f":{suffix}" + return key + + async def add_status_for_component( + self, + run_id: str, + task_name: str, + status: str, + ) -> None: + await self.add( + self.get_key(run_id, task_name, "status"), status, overwrite=True + ) + + async def get_status_for_component(self, run_id: str, task_name: str) -> Any: + return await self.get(self.get_key(run_id, task_name, "status")) - def add_result_for_component( + async def add_result_for_component( self, run_id: str, task_name: str, result: Any, overwrite: bool = False ) -> None: - self.add(self.get_key(run_id, task_name), result, overwrite=overwrite) + await self.add(self.get_key(run_id, task_name), result, overwrite=overwrite) - def get_result_for_component(self, run_id: str, task_name: str) -> Any: - return self.get(self.get_key(run_id, task_name)) + async def get_result_for_component(self, run_id: str, task_name: str) -> Any: + return await self.get(self.get_key(run_id, task_name)) class InMemoryStore(ResultStore): @@ -80,14 +97,18 @@ class InMemoryStore(ResultStore): def __init__(self) -> None: self._data: dict[str, Any] = {} - - def add(self, key: str, value: Any, overwrite: bool = True) -> None: - if (not overwrite) and key in self._data: - raise KeyError(f"{key} already exists") - self._data[key] = value - - def get(self, key: str) -> Any: - return self._data.get(key) + self._lock = asyncio.Lock() + """This lock is used to prevent read while a write in ongoing and vice-versa.""" + + async def add(self, key: str, value: Any, overwrite: bool = True) -> None: + async with self._lock: + if (not overwrite) and key in self._data: + raise KeyError(f"{key} already exists") + self._data[key] = value + + async def get(self, key: str) -> Any: + async with self._lock: + return self._data.get(key) def all(self) -> dict[str, Any]: return self._data diff --git a/tests/e2e/test_kg_builder_pipeline_e2e.py b/tests/e2e/test_kg_builder_pipeline_e2e.py index 868a5e49..db27d825 100644 --- a/tests/e2e/test_kg_builder_pipeline_e2e.py +++ b/tests/e2e/test_kg_builder_pipeline_e2e.py @@ -261,9 +261,13 @@ async def test_pipeline_builder_happy_path( assert res.run_id is not None assert res.result == {"writer": {"status": "SUCCESS"}} # check component's results - chunks = kg_builder_pipeline.store.get_result_for_component(res.run_id, "splitter") + chunks = await kg_builder_pipeline.store.get_result_for_component( + res.run_id, "splitter" + ) assert len(chunks["chunks"]) == 3 - graph = kg_builder_pipeline.store.get_result_for_component(res.run_id, "extractor") + graph = await kg_builder_pipeline.store.get_result_for_component( + res.run_id, "extractor" + ) # 3 entities + 3 chunks + 1 document nodes = graph["nodes"] assert len(nodes) == 7 @@ -463,9 +467,13 @@ async def test_pipeline_builder_failing_chunk_do_not_raise( assert res.run_id is not None assert res.result == {"writer": {"status": "SUCCESS"}} # check component's results - chunks = kg_builder_pipeline.store.get_result_for_component(res.run_id, "splitter") + chunks = await kg_builder_pipeline.store.get_result_for_component( + res.run_id, "splitter" + ) assert len(chunks["chunks"]) == 3 - graph = kg_builder_pipeline.store.get_result_for_component(res.run_id, "extractor") + graph = await 
diff --git a/tests/e2e/test_kg_builder_pipeline_e2e.py b/tests/e2e/test_kg_builder_pipeline_e2e.py
index 868a5e49..db27d825 100644
--- a/tests/e2e/test_kg_builder_pipeline_e2e.py
+++ b/tests/e2e/test_kg_builder_pipeline_e2e.py
@@ -261,9 +261,13 @@ async def test_pipeline_builder_happy_path(
     assert res.run_id is not None
     assert res.result == {"writer": {"status": "SUCCESS"}}
     # check component's results
-    chunks = kg_builder_pipeline.store.get_result_for_component(res.run_id, "splitter")
+    chunks = await kg_builder_pipeline.store.get_result_for_component(
+        res.run_id, "splitter"
+    )
     assert len(chunks["chunks"]) == 3
-    graph = kg_builder_pipeline.store.get_result_for_component(res.run_id, "extractor")
+    graph = await kg_builder_pipeline.store.get_result_for_component(
+        res.run_id, "extractor"
+    )
     # 3 entities + 3 chunks + 1 document
     nodes = graph["nodes"]
     assert len(nodes) == 7
@@ -463,9 +467,13 @@ async def test_pipeline_builder_failing_chunk_do_not_raise(
     assert res.run_id is not None
     assert res.result == {"writer": {"status": "SUCCESS"}}
     # check component's results
-    chunks = kg_builder_pipeline.store.get_result_for_component(res.run_id, "splitter")
+    chunks = await kg_builder_pipeline.store.get_result_for_component(
+        res.run_id, "splitter"
+    )
     assert len(chunks["chunks"]) == 3
-    graph = kg_builder_pipeline.store.get_result_for_component(res.run_id, "extractor")
+    graph = await kg_builder_pipeline.store.get_result_for_component(
+        res.run_id, "extractor"
+    )
     # 3 entities + 3 chunks
     nodes = graph["nodes"]
     assert len(nodes) == 6
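The unit tests below keep plain @patch(...) decorators even though the patched methods are now coroutines; on Python 3.8+ unittest.mock substitutes an AsyncMock when the target is an async function, so a side_effect list yields one value per await. A standalone sketch of that behaviour (the Task class here is illustrative only, not part of the codebase):

import asyncio
from unittest import mock


class Task:
    async def status(self) -> str:
        return "DONE"


async def main() -> None:
    # patch.object() detects the async function and installs an AsyncMock
    with mock.patch.object(Task, "status") as mock_status:
        mock_status.side_effect = ["UNKNOWN", "DONE"]
        task = Task()
        assert await task.status() == "UNKNOWN"  # first await -> first item
        assert await task.status() == "DONE"  # second await -> second item


asyncio.run(main())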
""" @@ -74,14 +81,15 @@ def test_orchestrator_get_component_inputs_from_parent_all(mock_result: Mock) -> mock_result.return_value = {"result": "output from component a"} orchestrator = Orchestrator(pipe) - data = orchestrator.get_component_inputs("b", {"value": "a"}, {}) + data = await orchestrator.get_component_inputs("b", {"value": "a"}, {}) assert data == {"value": {"result": "output from component a"}} @patch( "neo4j_genai.experimental.pipeline.pipeline.Orchestrator.get_results_for_component" ) -def test_orchestrator_get_component_inputs_from_parent_and_input( +@pytest.mark.asyncio +async def test_orchestrator_get_component_inputs_from_parent_and_input( mock_result: Mock, ) -> None: """Some parameters from user input, some other parameter from previous component.""" @@ -94,7 +102,7 @@ def test_orchestrator_get_component_inputs_from_parent_and_input( mock_result.return_value = {"result": "output from component a"} orchestrator = Orchestrator(pipe) - data = orchestrator.get_component_inputs( + data = await orchestrator.get_component_inputs( "b", {"value": "a"}, {"b": {"other_value": "user input for component b 'other_value' param"}}, @@ -108,7 +116,8 @@ def test_orchestrator_get_component_inputs_from_parent_and_input( @patch( "neo4j_genai.experimental.pipeline.pipeline.Orchestrator.get_results_for_component" ) -def test_orchestrator_get_component_inputs_ignore_user_input_if_input_def_provided( +@pytest.mark.asyncio +async def test_orchestrator_get_component_inputs_ignore_user_input_if_input_def_provided( mock_result: Mock, ) -> None: """If a parameter is defined both in the user input and in an input definition @@ -125,7 +134,7 @@ def test_orchestrator_get_component_inputs_ignore_user_input_if_input_def_provid orchestrator = Orchestrator(pipe) with pytest.warns(Warning) as w: - data = orchestrator.get_component_inputs( + data = await orchestrator.get_component_inputs( "b", {"value": "a"}, {"b": {"value": "user input for component a"}} ) assert data == {"value": {"result": "output from component a"}} @@ -158,27 +167,60 @@ def pipeline_aggregation() -> Pipeline: @pytest.mark.asyncio -async def test_orchestrator_branch(pipeline_branch: Pipeline) -> None: +@patch( + "neo4j_genai.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +async def test_orchestrator_branch( + mock_status: Mock, pipeline_branch: Pipeline +) -> None: + """a -> b, c""" orchestrator = Orchestrator(pipeline=pipeline_branch) node_a = pipeline_branch.get_node_by_name("a") - node_a.status = {orchestrator.run_id: RunStatus.DONE} + mock_status.side_effect = [ + # next b + RunStatus.UNKNOWN, + # dep of b = a + RunStatus.DONE, + # next c + RunStatus.UNKNOWN, + # dep of c = a + RunStatus.DONE, + ] next_tasks = [n async for n in orchestrator.next(node_a)] next_task_names = [n.name for n in next_tasks] assert next_task_names == ["b", "c"] @pytest.mark.asyncio -async def test_orchestrator_aggregation(pipeline_aggregation: Pipeline) -> None: +@patch( + "neo4j_genai.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +async def test_orchestrator_aggregation( + mock_status: Mock, pipeline_aggregation: Pipeline +) -> None: + """a, b -> c""" orchestrator = Orchestrator(pipeline=pipeline_aggregation) node_a = pipeline_aggregation.get_node_by_name("a") - node_a.status = {orchestrator.run_id: RunStatus.DONE} - next_tasks = [n async for n in orchestrator.next(node_a)] - next_task_names = [n.name for n in next_tasks] - # "c" not ready yet + mock_status.side_effect = [ + # next c: + RunStatus.UNKNOWN, 
diff --git a/tests/unit/experimental/pipeline/test_store.py b/tests/unit/experimental/pipeline/test_store.py
index 5de7c1f8..bcaee20b 100644
--- a/tests/unit/experimental/pipeline/test_store.py
+++ b/tests/unit/experimental/pipeline/test_store.py
@@ -2,12 +2,14 @@
 from neo4j_genai.experimental.pipeline.stores import InMemoryStore


-def test_memory_store() -> None:
+@pytest.mark.asyncio
+async def test_memory_store() -> None:
     store = InMemoryStore()
-    store.add("key", "value")
-    assert store.get("key") == "value"
+    await store.add("key", "value")
+    res = await store.get("key")
+    assert res == "value"
     with pytest.raises(KeyError):
-        store.add("key", "value", overwrite=False)
+        await store.add("key", "value", overwrite=False)
     assert store.all() == {"key": "value"}
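As a closing illustration of what the new per-store lock buys: once every store access is awaited, concurrent component writes serialize on the InMemoryStore's asyncio.Lock instead of interleaving the "check key, then write" section. A sketch, assuming only the InMemoryStore as patched above; the key names are made up.

import asyncio

from neo4j_genai.experimental.pipeline.stores import InMemoryStore


async def main() -> None:
    store = InMemoryStore()
    # ten concurrent writers, one key each; the lock keeps the
    # overwrite=False existence check and the write atomic per call
    await asyncio.gather(
        *(store.add(f"run-1:task-{i}", i, overwrite=False) for i in range(10))
    )
    assert len(store.all()) == 10


asyncio.run(main())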