Skip to content

Commit

Permalink
Make pipeline work with parallel runs (#119)
Browse files Browse the repository at this point in the history
* Add failing test

* Define a "run_id" in Orchestrator - save results per run_id

* Make unit test work

* Make intermediate results accessible from outside pipeline for investigation

* Remove unused imports

* Update examples and CHANGELOG

* Cleaning: remove deprecated code

* Fix ruff

* Fix examples

* Fix examples again

* PR reviews

* Removing useless status assignment
  • Loading branch information
stellasia authored Sep 8, 2024
1 parent 411b5ea commit c284b08
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 152 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Next

### Fixed
- Pipelines now return correct results when the same pipeline is run in parallel.

## 0.5.0

### Added
Expand Down
5 changes: 3 additions & 2 deletions examples/pipeline/kg_builder_from_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@
LangChainTextSplitterAdapter,
)
from neo4j_genai.experimental.pipeline import Component, DataModel
from neo4j_genai.experimental.pipeline.pipeline import PipelineResult
from neo4j_genai.llm import OpenAILLM
from pydantic import BaseModel, validate_call

logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)


class DocumentChunkModel(DataModel):
Expand Down Expand Up @@ -98,7 +99,7 @@ async def run(self, graph: Neo4jGraph) -> WriterModel:
)


async def main(neo4j_driver: neo4j.Driver) -> dict[str, Any]:
async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
from neo4j_genai.experimental.pipeline import Pipeline

# Instantiate Entity and Relation objects
Expand Down
7 changes: 5 additions & 2 deletions examples/pipeline/kg_builder_from_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

import asyncio
import logging.config
from typing import Any

import neo4j
from langchain_text_splitters import CharacterTextSplitter
from neo4j_genai.embeddings.openai import OpenAIEmbeddings
from neo4j_genai.experimental.components.embedder import TextChunkEmbedder
from neo4j_genai.experimental.components.entity_relation_extractor import (
LLMEntityRelationExtractor,
OnError,
Expand All @@ -35,6 +36,7 @@
LangChainTextSplitterAdapter,
)
from neo4j_genai.experimental.pipeline import Pipeline
from neo4j_genai.experimental.pipeline.pipeline import PipelineResult
from neo4j_genai.llm import OpenAILLM

# set log level to DEBUG for all neo4j_genai.* loggers
Expand All @@ -58,7 +60,7 @@
)


async def main(neo4j_driver: neo4j.Driver) -> dict[str, Any]:
async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
"""This is where we define and run the KG builder pipeline, instantiating a few
components:
- Text Splitter: in this example we use a text splitter from the LangChain package
Expand All @@ -80,6 +82,7 @@ async def main(neo4j_driver: neo4j.Driver) -> dict[str, Any]:
),
"splitter",
)
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
pipe.add_component(SchemaBuilder(), "schema")
pipe.add_component(
LLMEntityRelationExtractor(
Expand Down
28 changes: 16 additions & 12 deletions examples/pipeline/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
from __future__ import annotations

import asyncio
from typing import List

import neo4j
from neo4j_genai.embeddings.openai import OpenAIEmbeddings
from neo4j_genai.experimental.pipeline import Component, Pipeline
from neo4j_genai.experimental.pipeline.component import DataModel
from neo4j_genai.experimental.pipeline.pipeline import PipelineResult
from neo4j_genai.experimental.pipeline.types import (
ComponentConfig,
ConnectionConfig,
Expand All @@ -37,35 +39,37 @@
from neo4j_genai.retrievers.base import Retriever


class StringDataModel(DataModel):
result: str
class ComponentResultDataModel(DataModel):
"""A simple DataModel with a single text field"""

text: str


class RetrieverComponent(Component):
def __init__(self, retriever: Retriever) -> None:
self.retriever = retriever

async def run(self, query: str) -> StringDataModel:
async def run(self, query: str) -> ComponentResultDataModel:
res = self.retriever.search(query_text=query)
return StringDataModel(result="\n".join(c.content for c in res.items))
return ComponentResultDataModel(text="\n".join(c.content for c in res.items))


class PromptTemplateComponent(Component):
def __init__(self, prompt: PromptTemplate) -> None:
self.prompt = prompt

async def run(self, query: str, context: list[str]) -> StringDataModel:
async def run(self, query: str, context: List[str]) -> ComponentResultDataModel:
prompt = self.prompt.format(query, context, examples="")
return StringDataModel(result=prompt)
return ComponentResultDataModel(text=prompt)


class LLMComponent(Component):
def __init__(self, llm: LLMInterface) -> None:
self.llm = llm

async def run(self, prompt: str) -> StringDataModel:
async def run(self, prompt: str) -> ComponentResultDataModel:
llm_response = self.llm.invoke(prompt)
return StringDataModel(result=llm_response.content)
return ComponentResultDataModel(text=llm_response.content)


if __name__ == "__main__":
Expand Down Expand Up @@ -96,21 +100,21 @@ async def run(self, prompt: str) -> StringDataModel:
ConnectionConfig(
start="retrieve",
end="augment",
input_config={"context": "retrieve.result"},
input_config={"context": "retrieve.text"},
),
ConnectionConfig(
start="augment",
end="generate",
input_config={"prompt": "augment.result"},
input_config={"prompt": "augment.text"},
),
],
)
)

query = "A movie about the US presidency"
result = asyncio.run(
pipe_output: PipelineResult = asyncio.run(
pipe.run({"retrieve": {"query": query}, "augment": {"query": query}})
)
print(result["generate"]["result"])
print(pipe_output.result["generate"]["text"])

driver.close()
Loading

0 comments on commit c284b08

Please sign in to comment.