diff --git a/CHANGELOG.md b/CHANGELOG.md index 7254d570..2ac264f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ## Added - Integrated json-repair package to handle and repair invalid JSON generated by LLMs. - Introduced InvalidJSONError exception for handling cases where JSON repair fails. +- Ability to create a Pipeline or SimpleKGPipeline from a config file. See [the example](examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py). ## Changed - Updated LLM prompts to include stricter instructions for generating valid JSON. diff --git a/docs/source/api.rst b/docs/source/api.rst index d1b1066d..7d46b3f0 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -9,17 +9,18 @@ API Documentation Components ********** -KGWriter -======== +DataLoader +========== -.. autoclass:: neo4j_graphrag.experimental.components.kg_writer.KGWriter - :members: run +.. autoclass:: neo4j_graphrag.experimental.components.pdf_loader.DataLoader + :members: run, get_document_metadata -Neo4jWriter -=========== -.. autoclass:: neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter - :members: run +PdfLoader +========= + +.. autoclass:: neo4j_graphrag.experimental.components.pdf_loader.PdfLoader + :members: run, load_file TextSplitter ============ @@ -85,6 +86,17 @@ LLMEntityRelationExtractor .. autoclass:: neo4j_graphrag.experimental.components.entity_relation_extractor.LLMEntityRelationExtractor :members: run +KGWriter +======== + +.. autoclass:: neo4j_graphrag.experimental.components.kg_writer.KGWriter + :members: run + +Neo4jWriter +=========== + +.. autoclass:: neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter + :members: run SinglePropertyExactMatchResolver ================================ @@ -112,6 +124,23 @@ SimpleKGPipeline :members: run_async +************ +Config files +************ + + +SimpleKGPipelineConfig +====================== + +.. autoclass:: neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig + + +PipelineRunner +============== + +.. autoclass:: neo4j_graphrag.experimental.pipeline.config.runner.PipelineRunner + + .. _retrievers-section: ********** diff --git a/docs/source/images/kg_builder_pipeline.png b/docs/source/images/kg_builder_pipeline.png index b8faf2ba..935c2a12 100644 Binary files a/docs/source/images/kg_builder_pipeline.png and b/docs/source/images/kg_builder_pipeline.png differ diff --git a/docs/source/types.rst b/docs/source/types.rst index 6322b29b..253994ad 100644 --- a/docs/source/types.rst +++ b/docs/source/types.rst @@ -82,3 +82,62 @@ SchemaConfig ============ .. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaConfig + +LexicalGraphConfig +=================== + +.. autoclass:: neo4j_graphrag.experimental.components.types.LexicalGraphConfig + + +Neo4jDriverType +=============== + +.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverType + + +Neo4jDriverConfig +================= + +.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig + + +LLMType +======= + +.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.LLMType + + +LLMConfig +========= + +.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig + + +EmbedderType +============ + +.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderType + + +EmbedderConfig +============== + +.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig + + +ComponentType +============= + +.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType + + +ComponentConfig +=============== + +.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.ComponentConfig + + +ParamFromEnvConfig +================== + +.. autoclass:: neo4j_graphrag.experimental.pipeline.config.param_resolver.ParamFromEnvConfig diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index b758767a..87ce2f42 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -18,11 +18,11 @@ Pipeline Structure A Knowledge Graph (KG) construction pipeline requires a few components (some of the below components are optional): -- **Document parser**: extract text from files (PDFs, ...). -- **Document chunker**: split the text into smaller pieces of text, manageable by the LLM context window (token limit). +- **Data loader**: extract text from files (PDFs, ...). +- **Text splitter**: split the text into smaller pieces of text (chunks), manageable by the LLM context window (token limit). - **Chunk embedder** (optional): compute the chunk embeddings. - **Schema builder**: provide a schema to ground the LLM extracted entities and relations and obtain an easily navigable KG. -- **LexicalGraphBuilder**: build the lexical graph (Document, Chunk and their relationships) (optional). +- **Lexical graph builder**: build the lexical graph (Document, Chunk and their relationships) (optional). - **Entity and relation extractor**: extract relevant entities and relations from the text. - **Knowledge Graph writer**: save the identified entities and relations. - **Entity resolver**: merge similar entities into a single node. @@ -34,7 +34,486 @@ A Knowledge Graph (KG) construction pipeline requires a few components (some of This package contains the interface and implementations for each of these components, which are detailed in the following sections. To see an end-to-end example of a Knowledge Graph construction pipeline, -refer to the `example folder `_ in the project GitHub repository. +visit the `example folder `_ +in the project's GitHub repository. + + +****************** +Simple KG Pipeline +****************** + +The simplest way to begin building a KG from unstructured data using this package +is utilizing the `SimpleKGPipeline` interface: + +.. code:: python + + from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline + + kg_builder = SimpleKGPipeline( + llm=llm, # an LLMInterface for Entity and Relation extraction + driver=neo4j_driver, # a neo4j driver to write results to graph + embedder=embedder, # an Embedder for chunks + from_pdf=True, # set to False if parsing an already extracted text + ) + await kg_builder.run_async(file_path=str(file_path)) + # await kg_builder.run_async(text="my text") # if using from_pdf=False + + +See: + +- :ref:`Using Another LLM Model` to learn how to instantiate the `llm` +- :ref:`Embedders` to learn how to instantiate the `embedder` + + +The following section outlines the configuration parameters for this class. + +Customizing the SimpleKGPipeline +================================ + +Graph Schema +------------ + +It is possible to guide the LLM by supplying a list of entities, relationships, +and instructions on how to connect them. However, note that the extracted graph +may not fully adhere to these guidelines. Entities and relationships can be +represented as either simple strings (for their labels) or dictionaries. If using +a dictionary, it must include a label key and can optionally include description +and properties keys, as shown below: + +.. code:: python + + ENTITIES = [ + # entities can be defined with a simple label... + "Person", + # ... or with a dict if more details are needed, + # such as a description: + {"label": "House", "description": "Family the person belongs to"}, + # or a list of properties the LLM will try to attach to the entity: + {"label": "Planet", "properties": [{"name": "weather", "type": "STRING"}]}, + ] + # same thing for relationships: + RELATIONS = [ + "PARENT_OF", + { + "label": "HEIR_OF", + "description": "Used for inheritor relationship between father and sons", + }, + {"label": "RULES", "properties": [{"name": "fromYear", "type": "INTEGER"}]}, + ] + +The `potential_schema` is defined by a list of triplet in the format: +`(source_node_label, relationship_label, target_node_label)`. For instance: + + +.. code:: python + + POTENTIAL_SCHEMA = [ + ("Person", "PARENT_OF", "Person"), + ("Person", "HEIR_OF", "House"), + ("House", "RULES", "Planet"), + ] + +This schema information can be provided to the `SimpleKGBuilder` as demonstrated below: + +.. code:: python + + kg_builder = SimpleKGPipeline( + # ... + entities=ENTITIES, + relations=RELATIONS, + potential_schema=POTENTIAL_SCHEMA, + # ... + ) + +Prompt Template, Lexical Graph Config and Error Behavior +-------------------------------------------------------- + +These parameters are part of the `EntityAndRelationExtractor` component. +For detailed information, refer to the section on :ref:`Entity and Relation Extractor`. +They are also accessible via the `SimpleKGPipeline` interface. + +.. code:: python + + kg_builder = SimpleKGPipeline( + # ... + prompt_template="", + lexical_graph_config=my_config, + on_error="RAISE", + # ... + ) + +Skip Entity Resolution +---------------------- + +By default, after each run, an Entity Resolution step is performed to merge nodes +that share the same label and name property. To disable this behavior, adjust +the following parameter: + +.. code:: python + + kg_builder = SimpleKGPipeline( + # ... + perform_entity_resolution=False, + # ... + ) + +Neo4j Database +-------------- + +To write to a non-default Neo4j database, specify the database name using this parameter: + +.. code:: python + + kg_builder = SimpleKGPipeline( + # ... + neo4j_database="myDb", + # ... + ) + +Using Custom Components +----------------------- + +For advanced customization or when using a custom implementation, you can pass +instances of specific components to the `SimpleKGPipeline`. The components that can +customized at the moment are: + +- `text_splitter`: must be an instance of :ref:`TextSplitter` +- `pdf_loader`: must be an instance of :ref:`PdfLoader` +- `kg_writer`: must be an instance of :ref:`KGWriter` + +For instance, the following code can be used to customize the chunk size and +chunk overlap in the text splitter component: + +.. code:: python + + from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( + FixedSizeSplitter, + ) + + text_splitter = FixedSizeSplitter(chunk_size=500, chunk_overlap=100) + + kg_builder = SimpleKGPipeline( + # ... + text_splitter=text_splitter, + # ... + ) + + +Using a Config file +=================== + +.. code:: python + + from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner + + file_path = "my_config.json" + + pipeline = PipelineRunner.from_config_file(file_path) + await pipeline.run({"text": "my text"}) + + +The config file can be written in either JSON or YAML format. + +Here is an example of a base configuration file in JSON format: + +.. code:: json + + { + "version_": 1, + "template_": "SimpleKGPipeline", + "neo4j_config": {}, + "llm_config": {}, + "embedder_config": {} + } + +And like this in YAML: + +.. code:: yaml + + version_: 1 + template_: SimpleKGPipeline + neo4j_config: + llm_config: + embedder_config: + + +Defining a Neo4j Driver +----------------------- + +Below is an example of configuring a Neo4j driver in a JSON configuration file: + +.. code:: json + + { + "neo4j_config": { + "params_": { + "uri": "bolt://...", + "user": "neo4j", + "password": "password" + } + } + } + +Same for YAML: + +.. code:: yaml + + neo4j_config: + params_: + uri: bolt:// + user: neo4j + password: password + +In some cases, it may be necessary to avoid hard-coding sensitive values, +such as passwords or API keys, to ensure security. To address this, the configuration +parser supports parameter resolution methods. + +Parameter resolution +-------------------- + +To instruct the configuration parser to read a parameter from an environment variable, +use the following syntax: + +.. code:: json + + { + "neo4j_config": { + "params_": { + "uri": "bolt://...", + "user": "neo4j", + "password": { + "resolver_": "ENV", + "var_": "NEO4J_PASSWORD" + } + } + } + } + +And for YAML: + +.. code:: yaml + + neo4j_config: + params_: + uri: bolt:// + user: neo4j + password: + resolver_: ENV + var_: NEO4J_PASSWORD + +- The `resolver_=ENV` key is mandatory and its value cannot be altered. +- The `var_` key specifies the name of the environment variable to be read. + +This syntax can be applied to all parameters. + + +Defining an LLM +---------------- + +Below is an example of configuring an LLM in a JSON configuration file: + +.. code:: json + + { + "llm_config": { + "class_": "OpenAILLM", + "params_": { + "mode_name": "gpt-4o", + "api_key": { + "resolver_": "ENV", + "var_": "OPENAI_API_KEY", + }, + "model_params": { + "temperature": 0, + "max_tokens": 2000, + "response_format": {"type": "json_object"} + } + } + } + } + +And the equivalent YAML: + +.. code:: yaml + + llm_config: + class_: OpenAILLM + params_: + model_name: gpt-4o + api_key: + resolver_: ENV + var_: OPENAI_API_KEY + model_params: + temperature: 0 + max_tokens: 2000 + response_format: + type: json_object + +- The `class_` key specifies the path to the class to be instantiated. +- The `params_` key contains the parameters to be passed to the class constructor. + +When using an LLM implementation provided by this package, the full path in the `class_` key +can be omitted (the parser will automatically import from `neo4j_graphrag.llm`). +For custom implementations, the full path must be explicitly specified, +for example: `my_package.my_llm.MyLLM`. + +Defining an Embedder +-------------------- + +The same principles apply to `embedder_config`: + +.. code:: json + + { + "embedder_config": { + "class_": "OpenAIEmbeddings", + "params_": { + "mode": "text-embedding-ada-002", + "api_key": { + "resolver_": "ENV", + "var_": "OPENAI_API_KEY", + } + } + } + } + +Or the YAML version: + +.. code:: yaml + + embedder_config: + class_: OpenAIEmbeddings + params_: + api_key: + resolver_: ENV + var_: OPENAI_API_KEY + +- For embedder implementations from this package, the full path can be omitted in the `class_` key (the parser will import from `neo4j_graphrag.embeddings`). +- For custom implementations, the full path must be provided, for example: `my_package.my_embedding.MyEmbedding`. + + +Other configuration +------------------- + +The other parameters exposed in the :ref:`SimpleKGPipeline` can also be configured +within the configuration file. + +.. code:: json + + { + "from_pdf": false, + "perform_entity_resolution": true, + "neo4j_database": "myDb", + "on_error": "IGNORE", + "prompt_template": "...", + "entities": [ + "Person", + { + "label": "House", + "description": "Family the person belongs to", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Planet", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "weather", "type": "STRING"} + ] + } + ], + "relations": [ + "PARENT_OF", + { + "label": "HEIR_OF", + "description": "Used for inheritor relationship between father and sons" + }, + { + "label": "RULES", + "properties": [ + {"name": "fromYear", "type": "INTEGER"} + ] + } + ], + "potential_schema": [ + ["Person", "PARENT_OF", "Person"], + ["Person", "HEIR_OF", "House"], + ["House", "RULES", "Planet"] + ], + "lexical_graph_config": { + "chunk_node_label": "TextPart" + } + } + + +or in YAML: + +.. code:: yaml + + from_pdf: false + perform_entity_resolution: true + neo4j_database: myDb + on_error: IGNORE + prompt_template: ... + entities: + - label: Person + - label: House + description: Family the person belongs to + properties: + - name: name + type: STRING + - label: Planet + properties: + - name: name + type: STRING + - name: weather + type: STRING + relations: + - label: PARENT_OF + - label: HEIR_OF + description: Used for inheritor relationship between father and sons + - label: RULES + properties: + - name: fromYear + type: INTEGER + potential_schema: + - ["Person", "PARENT_OF", "Person"] + - ["Person", "HEIR_OF", "House"] + - ["House", "RULES", "Planet"] + lexical_graph_config: + chunk_node_label: TextPart + + +It is also possible to further customize components, with a syntax similar to the one +used for `llm_config` or `embedder_config`: + +.. code:: json + + { + "text_splitter": { + "class_": "text_splitters.FixedSizeSplitter", + "params_": { + "chunk_size": 500, + "chunk_overlap": 100 + } + } + + } + +The YAML equivalent: + +.. code:: yaml + + text_splitter: + class_: text_splitters.fixed_size_splitter.FixedSizeSplitter + params_: + chunk_size: 100 + chunk_overlap: 10 + +The `neo4j_graphrag.experimental.components` prefix will be appended automatically +if needed. + ********************************** Knowledge Graph Builder Components @@ -63,10 +542,10 @@ They can also be used within a pipeline: pipeline.add_component(my_component, "component_name") -Document Parser -=============== +Data Loader +============ -Document parsers start from a file path and return the text extracted from this file. +Data loaders start from a file path and return the text extracted from this file. This package currently supports text extraction from PDFs: @@ -92,8 +571,8 @@ To implement your own loader, use the `DataLoader` interface: -Document Splitter -================= +Text Splitter +============== Document splitters, as the name indicate, split documents into smaller chunks that can be processed within the LLM token limits: diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index d939ed85..2b2b528b 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -397,7 +397,7 @@ However, in most cases, a text (from the user) will be provided instead of a vec In this scenario, an `Embedder` is required. Search Similar Text ------------------------------ +-------------------- When searching for a text, specifying how the retriever transforms (embeds) the text into a vector is required. Therefore, the retriever requires knowledge of an embedder: @@ -418,7 +418,7 @@ into a vector is required. Therefore, the retriever requires knowledge of an emb Embedders ------------------------------ +--------- Currently, this package supports the following embedders: @@ -485,7 +485,7 @@ using the `return_properties` parameter: Pre-Filters ------------------------------ +----------- When performing a similarity search, one may have constraints to apply. For instance, filtering out movies released before 2000. This can be achieved diff --git a/examples/README.md b/examples/README.md index 95f9443a..2faed5f8 100644 --- a/examples/README.md +++ b/examples/README.md @@ -14,6 +14,7 @@ are listed in [the last section of this file](#customize). - [End to end PDF to graph simple pipeline](build_graph/simple_kg_builder_from_pdf.py) - [End to end text to graph simple pipeline](build_graph/simple_kg_builder_from_text.py) +- [Build KG pipeline from config file](build_graph/from_config_files/simple_kg_pipeline_from_config_file.py) ## Retrieve @@ -93,6 +94,7 @@ are listed in [the last section of this file](#customize). - [End to end example with explicit components and PDF input](./customize/build_graph/pipeline/kg_builder_from_pdf.py) - [Process multiple documents](./customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py) - [Export lexical graph creation into another pipeline](./customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py) +- [Build pipeline from config file](customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py) #### Components diff --git a/examples/build_graph/from_config_files/simple_kg_pipeline_config.json b/examples/build_graph/from_config_files/simple_kg_pipeline_config.json new file mode 100644 index 00000000..ef251624 --- /dev/null +++ b/examples/build_graph/from_config_files/simple_kg_pipeline_config.json @@ -0,0 +1,112 @@ +{ + "version_": "1", + "template_": "SimpleKGPipeline", + "neo4j_config": { + "params_": { + "uri": { + "resolver_": "ENV", + "var_": "NEO4J_URI" + }, + "user": { + "resolver_": "ENV", + "var_": "NEO4J_USER" + }, + "password": { + "resolver_": "ENV", + "var_": "NEO4J_PASSWORD" + } + } + }, + "llm_config": { + "class_": "OpenAILLM", + "params_": { + "api_key": { + "resolver_": "ENV", + "var_": "OPENAI_API_KEY" + }, + "model_name": "gpt-4o", + "model_params": { + "temperature": 0, + "max_tokens": 2000, + "response_format": {"type": "json_object"} + } + } + }, + "embedder_config": { + "class_": "OpenAIEmbeddings", + "params_": { + "api_key": { + "resolver_": "ENV", + "var_": "OPENAI_API_KEY" + } + } + }, + "from_pdf": false, + "entities": [ + "Person", + { + "label": "House", + "description": "Family the person belongs to", + "properties": [ + { + "name": "name", + "type": "STRING" + } + ] + }, + { + "label": "Planet", + "properties": [ + { + "name": "name", + "type": "STRING" + }, + { + "name": "weather", + "type": "STRING" + } + ] + } + ], + "relations": [ + "PARENT_OF", + { + "label": "HEIR_OF", + "description": "Used for inheritor relationship between father and sons" + }, + { + "label": "RULES", + "properties": [ + { + "name": "fromYear", + "type": "INTEGER" + } + ] + } + ], + "potential_schema": [ + [ + "Person", + "PARENT_OF", + "Person" + ], + [ + "Person", + "HEIR_OF", + "House" + ], + [ + "House", + "RULES", + "Planet" + ] + ], + "text_splitter": { + "class_": "text_splitters.fixed_size_splitter.FixedSizeSplitter", + "params_": { + "chunk_size": 100, + "chunk_overlap": 10 + } + }, + "perform_entity_resolution": true +} diff --git a/examples/build_graph/from_config_files/simple_kg_pipeline_config.yaml b/examples/build_graph/from_config_files/simple_kg_pipeline_config.yaml new file mode 100644 index 00000000..8917e8ca --- /dev/null +++ b/examples/build_graph/from_config_files/simple_kg_pipeline_config.yaml @@ -0,0 +1,63 @@ +version_: "1" +template_: SimpleKGPipeline +neo4j_config: + params_: + uri: + resolver_: ENV + var_: NEO4J_URI + user: + resolver_: ENV + var_: NEO4J_USER + password: + resolver_: ENV + var_: NEO4J_PASSWORD +llm_config: + class_: OpenAILLM + params_: + api_key: + resolver_: ENV + var_: OPENAI_API_KEY + model_name: gpt-4o + model_params: + temperature: 0 + max_tokens: 2000 + response_format: + type: json_object +embedder_config: + class_: OpenAIEmbeddings + params_: + api_key: + resolver_: ENV + var_: OPENAI_API_KEY +from_pdf: false +entities: + - label: Person + - label: House + description: Family the person belongs to + properties: + - name: name + type: STRING + - label: Planet + properties: + - name: name + type: STRING + - name: weather + type: STRING +relations: + - label: PARENT_OF + - label: HEIR_OF + description: Used for inheritor relationship between father and sons + - label: RULES + properties: + - name: fromYear + type: INTEGER +potential_schema: + - ["Person", "PARENT_OF", "Person"] + - ["Person", "HEIR_OF", "House"] + - ["House", "RULES", "Planet"] +text_splitter: + class_: text_splitters.fixed_size_splitter.FixedSizeSplitter + params_: + chunk_size: 100 + chunk_overlap: 10 +perform_entity_resolution: true diff --git a/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py b/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py new file mode 100644 index 00000000..62ba6c85 --- /dev/null +++ b/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py @@ -0,0 +1,47 @@ +"""In this example, the pipeline is defined in a JSON ('simple_kg_pipeline_config.json') +or YAML ('simple_kg_pipeline_config.yaml') file. + +According to the configuration file, some parameters will be read from the env vars +(Neo4j credentials and the OpenAI API key). +""" + +import asyncio +import logging + +## If env vars are in a .env file, uncomment: +## (requires pip install python-dotenv) +# from dotenv import load_dotenv +# load_dotenv() +# env vars manually set for testing: +import os +from pathlib import Path + +from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner +from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult + +logging.basicConfig() +logging.getLogger("neo4j_graphrag").setLevel(logging.DEBUG) + +os.environ["NEO4J_URI"] = "bolt://localhost:7687" +os.environ["NEO4J_USER"] = "neo4j" +os.environ["NEO4J_PASSWORD"] = "password" +# os.environ["OPENAI_API_KEY"] = "sk-..." + + +root_dir = Path(__file__).parent +file_path = root_dir / "simple_kg_pipeline_config.yaml" +# file_path = root_dir / "simple_kg_pipeline_config.json" + + +# Text to process +TEXT = """The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House Atreides, +an aristocratic family that rules the planet Caladan, the rainy planet, since 10191.""" + + +async def main() -> PipelineResult: + pipeline = PipelineRunner.from_config_file(file_path) + return await pipeline.run({"text": TEXT}) + + +if __name__ == "__main__": + print(asyncio.run(main())) diff --git a/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.json b/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.json new file mode 100644 index 00000000..ce815412 --- /dev/null +++ b/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.json @@ -0,0 +1,68 @@ +{ + "version_": "1", + "template_": "none", + "name": "", + "neo4j_config": { + "params_": { + "uri": { + "resolver_": "ENV", + "var_": "NEO4J_URI" + }, + "user": { + "resolver_": "ENV", + "var_": "NEO4J_USER" + }, + "password": { + "resolver_": "ENV", + "var_": "NEO4J_PASSWORD" + } + } + }, + "extras": { + "database": "neo4j" + }, + "component_config": { + "splitter": { + "class_": "text_splitters.fixed_size_splitter.FixedSizeSplitter" + }, + "builder": { + "class_": "lexical_graph.LexicalGraphBuilder", + "params_": { + "config": { + "chunk_node_label": "TextPart" + } + } + }, + "writer": { + "name_": "writer", + "class_": "kg_writer.Neo4jWriter", + "params_": { + "driver": { + "resolver_": "CONFIG_KEY", + "key_": "neo4j_config.default" + }, + "neo4j_database": { + "resolver_": "CONFIG_KEY", + "key_": "extras.database" + } + } + } + }, + "connection_config": [ + { + "start": "splitter", + "end": "builder", + "input_config": { + "text_chunks": "splitter" + } + }, + { + "start": "builder", + "end": "writer", + "input_config": { + "graph": "builder.graph", + "lexical_graph_config": "builder.config" + } + } + ] +} diff --git a/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.yaml b/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.yaml new file mode 100644 index 00000000..87ac905e --- /dev/null +++ b/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.yaml @@ -0,0 +1,45 @@ +version_: "1" +template_: none +neo4j_config: + params_: + uri: + resolver_: ENV + var_: NEO4J_URI + user: + resolver_: ENV + var_: NEO4J_USER + password: + resolver_: ENV + var_: NEO4J_PASSWORD +extras: + database: neo4j +component_config: + splitter: + class_: text_splitters.fixed_size_splitter.FixedSizeSplitter + params_: + chunk_size: 100 + chunk_overlap: 10 + builder: + class_: lexical_graph.LexicalGraphBuilder + params_: + config: + chunk_node_label: TextPart + writer: + class_: kg_writer.Neo4jWriter + params_: + driver: + resolver_: CONFIG_KEY + key_: neo4j_config.default + neo4j_database: + resolver_: CONFIG_KEY + key_: extras.database +connection_config: + - start: splitter + end: builder + input_config: + text_chunks: splitter + - start: builder + end: writer + input_config: + graph: builder.graph + lexical_graph_config: builder.config diff --git a/examples/customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py b/examples/customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py new file mode 100644 index 00000000..9a2a1680 --- /dev/null +++ b/examples/customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py @@ -0,0 +1,40 @@ +"""In this example, the pipeline is defined in a JSON file 'pipeline_config.json'. +According to the configuration file, some parameters will be read from the env vars +(Neo4j credentials and the OpenAI API key). +""" + +import asyncio + +## If env vars are in a .env file, uncomment: +## (requires pip install python-dotenv) +# from dotenv import load_dotenv +# load_dotenv() +# env vars manually set for testing: +import os +from pathlib import Path + +from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner +from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult + +os.environ["NEO4J_URI"] = "bolt://localhost:7687" +os.environ["NEO4J_USER"] = "neo4j" +os.environ["NEO4J_PASSWORD"] = "password" +# os.environ["OPENAI_API_KEY"] = "sk-..." + + +root_dir = Path(__file__).parent +# file_path = root_dir / "pipeline_config.json" +file_path = root_dir / "pipeline_config.yaml" + +# Text to process +TEXT = """The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House Atreides, +an aristocratic family that rules the planet Caladan, the rainy planet, since 10191.""" + + +async def main() -> PipelineResult: + pipeline = PipelineRunner.from_config_file(file_path) + return await pipeline.run({"splitter": {"text": TEXT}}) + + +if __name__ == "__main__": + print(asyncio.run(main())) diff --git a/poetry.lock b/poetry.lock index e7bf3849..a662e717 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1905,18 +1905,18 @@ files = [ [[package]] name = "langchain-core" -version = "0.3.23" +version = "0.3.24" description = "Building applications with LLMs through composability" optional = true python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_core-0.3.23-py3-none-any.whl", hash = "sha256:550c0b996990830fa6515a71a1192a8a0343367999afc36d4ede14222941e420"}, - {file = "langchain_core-0.3.23.tar.gz", hash = "sha256:f9e175e3b82063cc3b160c2ca2b155832e1c6f915312e1204828f97d4aabf6e1"}, + {file = "langchain_core-0.3.24-py3-none-any.whl", hash = "sha256:97192552ef882a3dd6ae3b870a180a743801d0137a1159173f51ac555eeb7eec"}, + {file = "langchain_core-0.3.24.tar.gz", hash = "sha256:460851e8145327f70b70aad7dce2cdbd285e144d14af82b677256b941fc99656"}, ] [package.dependencies] jsonpatch = ">=1.33,<2.0" -langsmith = ">=0.1.125,<0.2.0" +langsmith = ">=0.1.125,<0.3" packaging = ">=23.2,<25" pydantic = [ {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, @@ -1976,13 +1976,13 @@ langchain-core = ">=0.3.15,<0.4.0" [[package]] name = "langsmith" -version = "0.1.147" +version = "0.2.2" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = true -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langsmith-0.1.147-py3-none-any.whl", hash = "sha256:7166fc23b965ccf839d64945a78e9f1157757add228b086141eb03a60d699a15"}, - {file = "langsmith-0.1.147.tar.gz", hash = "sha256:2e933220318a4e73034657103b3b1a3a6109cc5db3566a7e8e03be8d6d7def7a"}, + {file = "langsmith-0.2.2-py3-none-any.whl", hash = "sha256:4786d7dcdbc25e43d4a1bf70bbe12938a9eb2364feec8f6fc4d967162519b367"}, + {file = "langsmith-0.2.2.tar.gz", hash = "sha256:6f515ee41ae80968a7d552be1154414ccde57a0a534c960c8c3cd1835734095f"}, ] [package.dependencies] @@ -2866,13 +2866,13 @@ files = [ [[package]] name = "openai" -version = "1.57.1" +version = "1.57.2" description = "The official Python library for the openai API" optional = true python-versions = ">=3.8" files = [ - {file = "openai-1.57.1-py3-none-any.whl", hash = "sha256:3865686c927e93492d1145938d4a24b634951531c4b2769d43ca5dbd4b25d8fd"}, - {file = "openai-1.57.1.tar.gz", hash = "sha256:a95f22e04ab3df26e64a15d958342265e802314131275908b3b3e36f8c5d4377"}, + {file = "openai-1.57.2-py3-none-any.whl", hash = "sha256:f7326283c156fdee875746e7e54d36959fb198eadc683952ee05e3302fbd638d"}, + {file = "openai-1.57.2.tar.gz", hash = "sha256:5f49fd0f38e9f2131cda7deb45dafdd1aee4f52a637e190ce0ecf40147ce8cee"}, ] [package.dependencies] @@ -4928,6 +4928,17 @@ build = ["cmake (>=3.20)", "lit"] tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"] tutorials = ["matplotlib", "pandas", "tabulate"] +[[package]] +name = "types-pyyaml" +version = "6.0.12.20240917" +description = "Typing stubs for PyYAML" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-PyYAML-6.0.12.20240917.tar.gz", hash = "sha256:d1405a86f9576682234ef83bcb4e6fff7c9305c8b1fbad5e0bcd4f7dbdc9c587"}, + {file = "types_PyYAML-6.0.12.20240917-py3-none-any.whl", hash = "sha256:392b267f1c0fe6022952462bf5d6523f31e37f6cea49b14cee7ad634b6301570"}, +] + [[package]] name = "types-requests" version = "2.31.0.6" @@ -5267,4 +5278,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = "^3.9.0" -content-hash = "eb706cb5d3d47a4d3ac32bc7f228bda27de77c87398999d1e610552ff19af93f" +content-hash = "5f729b5f7f31021258d04fcf26e2310f685f1b97113e888ba346df5c7393d4e4" diff --git a/pyproject.toml b/pyproject.toml index 6007709d..604fde99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ openai = {version = "^1.51.1", optional = true } anthropic = { version = "^0.36.0", optional = true} sentence-transformers = {version = "^3.0.0", optional = true } json-repair = "^0.30.2" +types-pyyaml = "^6.0.12.20240917" [tool.poetry.group.dev.dependencies] urllib3 = "<2" diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 225b4c0f..d4070aea 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -23,7 +23,6 @@ from typing import Any, List, Optional, Union, cast import json_repair - from pydantic import ValidationError, validate_call from neo4j_graphrag.exceptions import LLMGenerationError diff --git a/src/neo4j_graphrag/experimental/components/lexical_graph.py b/src/neo4j_graphrag/experimental/components/lexical_graph.py index 92681a8b..ce96b9fd 100644 --- a/src/neo4j_graphrag/experimental/components/lexical_graph.py +++ b/src/neo4j_graphrag/experimental/components/lexical_graph.py @@ -42,6 +42,7 @@ class LexicalGraphBuilder(Component): - A relationship between a chunk and the next one in the document """ + @validate_call def __init__( self, config: LexicalGraphConfig = LexicalGraphConfig(), diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 64e908ed..0f118545 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -17,9 +17,14 @@ from typing import Any, Dict, List, Literal, Optional, Tuple from pydantic import BaseModel, ValidationError, model_validator, validate_call +from typing_extensions import Self from neo4j_graphrag.exceptions import SchemaValidationError from neo4j_graphrag.experimental.pipeline.component import Component, DataModel +from neo4j_graphrag.experimental.pipeline.types import ( + EntityInputType, + RelationInputType, +) class SchemaProperty(BaseModel): @@ -55,6 +60,14 @@ class SchemaEntity(BaseModel): description: str = "" properties: List[SchemaProperty] = [] + @classmethod + def from_text_or_dict(cls, input: EntityInputType) -> Self: + if isinstance(input, SchemaEntity): + return input + if isinstance(input, str): + return cls(label=input) + return cls.model_validate(input) + class SchemaRelation(BaseModel): """ @@ -65,6 +78,14 @@ class SchemaRelation(BaseModel): description: str = "" properties: List[SchemaProperty] = [] + @classmethod + def from_text_or_dict(cls, input: RelationInputType) -> Self: + if isinstance(input, SchemaRelation): + return input + if isinstance(input, str): + return cls(label=input) + return cls.model_validate(input) + class SchemaConfig(DataModel): """ diff --git a/src/neo4j_graphrag/experimental/pipeline/config/__init__.py b/src/neo4j_graphrag/experimental/pipeline/config/__init__.py new file mode 100644 index 00000000..c0199c14 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/neo4j_graphrag/experimental/pipeline/config/base.py b/src/neo4j_graphrag/experimental/pipeline/config/base.py new file mode 100644 index 00000000..665a56d0 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/base.py @@ -0,0 +1,62 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Abstract class for all pipeline configs.""" + +from __future__ import annotations + +import logging +from typing import Any + +from pydantic import BaseModel, PrivateAttr + +from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( + ParamConfig, + ParamToResolveConfig, +) + +logger = logging.getLogger(__name__) + + +class AbstractConfig(BaseModel): + """Base class for all configs. + Provides methods to get a class from a string and resolve a parameter defined by + a dict with a 'resolver_' key. + + Each subclass must implement a 'parse' method that returns the relevant object. + """ + + _global_data: dict[str, Any] = PrivateAttr({}) + """Additional parameter ignored by all Pydantic model_* methods.""" + + def resolve_param(self, param: ParamConfig) -> Any: + """Finds the parameter value from its definition.""" + if not isinstance(param, ParamToResolveConfig): + # some parameters do not have to be resolved, real + # values are already provided + return param + return param.resolve(self._global_data) + + def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]: + """Resolve all parameters + + Returning dict[str, Any] because parameters can be anything (str, float, list, dict...) + """ + return { + param_name: self.resolve_param(param) + for param_name, param in params.items() + } + + def parse(self, resolved_data: dict[str, Any] | None = None) -> Any: + raise NotImplementedError() diff --git a/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py b/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py new file mode 100644 index 00000000..c3df28d9 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py @@ -0,0 +1,85 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Read JSON or YAML files and returns a dict. +No data validation performed at this stage. +""" + +import json +import logging +from pathlib import Path +from typing import Any, Optional + +import fsspec +import yaml +from fsspec.implementations.local import LocalFileSystem + +logger = logging.getLogger(__name__) + + +class ConfigReader: + """Reads config from a file (JSON or YAML format) + and returns a dict. + + File format is guessed from the extension. Supported extensions are + (lower or upper case): + + - .json + - .yaml, .yml + + Example: + + .. code-block:: python + + from pathlib import Path + from neo4j_graphrag.experimental.pipeline.config.reader import ConfigReader + reader = ConfigReader() + reader.read(Path("my_file.json")) + + If reading a file with a different extension but still in JSON or YAML format, + it is possible to call directly the `read_json` or `read_yaml` methods: + + .. code-block:: python + + reader.read_yaml(Path("my_file.txt")) + + """ + + def __init__(self, fs: Optional[fsspec.AbstractFileSystem] = None) -> None: + self.fs = fs or LocalFileSystem() + + def read_json(self, file_path: str) -> Any: + logger.debug(f"CONFIG_READER: read from json {file_path}") + with self.fs.open(file_path, "r") as f: + return json.load(f) + + def read_yaml(self, file_path: str) -> Any: + logger.debug(f"CONFIG_READER: read from yaml {file_path}") + with self.fs.open(file_path, "r") as f: + return yaml.safe_load(f) + + def _guess_format_and_read(self, file_path: str) -> dict[str, Any]: + p = Path(file_path) + extension = p.suffix.lower() + # Note: .suffix returns an empty string if Path has no extension + # if not returning a dict, parsing will fail later on + if extension in [".json"]: + return self.read_json(file_path) # type: ignore[no-any-return] + if extension in [".yaml", ".yml"]: + return self.read_yaml(file_path) # type: ignore[no-any-return] + raise ValueError(f"Unsupported extension: {extension}") + + def read(self, file_path: str) -> dict[str, Any]: + data = self._guess_format_and_read(file_path) + return data diff --git a/src/neo4j_graphrag/experimental/pipeline/config/object_config.py b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py new file mode 100644 index 00000000..cb47b380 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py @@ -0,0 +1,266 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Config for all parameters that can be both provided as object instance or +config dict with 'class_' and 'params_' keys. + +Nomenclature in this file: + +- `*Config` models are used to represent "things" as dict to be used in a config file. + e.g.: + - neo4j.Driver => {"uri": "", "user": "", "password": ""} + - LLMInterface => {"class_": "OpenAI", "params_": {"model_name": "gpt-4o"}} +- `*Type` models are wrappers around an object and a 'Config' the object can be created + from. They are used to allow the instantiation of "PipelineConfig" either from + instantiated objects (when used in code) and from a config dict (when used to + load config from file). +""" + +from __future__ import annotations + +import importlib +import logging +from typing import ( + Any, + ClassVar, + Generic, + Optional, + TypeVar, + Union, + cast, +) + +import neo4j +from pydantic import ( + ConfigDict, + Field, + RootModel, + field_validator, +) + +from neo4j_graphrag.embeddings import Embedder +from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig +from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( + ParamConfig, +) +from neo4j_graphrag.llm import LLMInterface + +logger = logging.getLogger(__name__) + + +T = TypeVar("T") +"""Generic type to help mypy with the parse method when we know the exact +expected return type (e.g. for the Neo4jDriverConfig below). +""" + + +class ObjectConfig(AbstractConfig, Generic[T]): + """A config class to represent an object from a class name + and its constructor parameters. + """ + + class_: str | None = Field(default=None, validate_default=True) + """Path to class to be instantiated.""" + params_: dict[str, ParamConfig] = {} + """Initialization parameters.""" + + DEFAULT_MODULE: ClassVar[str] = "." + """Default module to import the class from.""" + INTERFACE: ClassVar[type] = object + """Constraint on the class (must be a subclass of).""" + REQUIRED_PARAMS: ClassVar[list[str]] = [] + """List of required parameters for this object constructor.""" + + @field_validator("params_") + @classmethod + def validate_params(cls, params_: dict[str, Any]) -> dict[str, Any]: + """Make sure all required parameters are provided.""" + for p in cls.REQUIRED_PARAMS: + if p not in params_: + raise ValueError(f"Missing parameter {p}") + return params_ + + def get_module(self) -> str: + return self.DEFAULT_MODULE + + def get_interface(self) -> type: + return self.INTERFACE + + @classmethod + def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> type: + """Get class from string and an optional module + + Will first try to import the class from `class_path` alone. If it results in an ImportError, + will try to import from `f'{optional_module}.{class_path}'` + + Args: + class_path (str): Class path with format 'my_module.MyClass'. + optional_module (Optional[str]): Optional module path. Used to provide a default path for some known objects and simplify the notation. + + Raises: + ValueError: if the class can't be imported, even using the optional module. + """ + *modules, class_name = class_path.rsplit(".", 1) + module_name = modules[0] if modules else optional_module + if module_name is None: + raise ValueError("Must specify a module to import class from") + try: + module = importlib.import_module(module_name) + klass = getattr(module, class_name) + except (ImportError, AttributeError): + if optional_module and module_name != optional_module: + full_klass_path = optional_module + "." + class_path + return cls._get_class(full_klass_path) + raise ValueError(f"Could not find {class_name} in {module_name}") + return cast(type, klass) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> T: + """Import `class_`, resolve `params_` and instantiate object.""" + self._global_data = resolved_data or {} + logger.debug(f"OBJECT_CONFIG: parsing {self} using {resolved_data}") + if self.class_ is None: + raise ValueError(f"`class_` is not required to parse object {self}") + klass = self._get_class(self.class_, self.get_module()) + if not issubclass(klass, self.get_interface()): + raise ValueError( + f"Invalid class '{klass}'. Expected a subclass of '{self.get_interface()}'" + ) + params = self.resolve_params(self.params_) + try: + obj = klass(**params) + except TypeError as e: + logger.error( + "OBJECT_CONFIG: failed to instantiate object due to improperly configured parameters" + ) + raise e + return cast(T, obj) + + +class Neo4jDriverConfig(ObjectConfig[neo4j.Driver]): + REQUIRED_PARAMS = ["uri", "user", "password"] + + @field_validator("class_", mode="before") + @classmethod + def validate_class(cls, class_: Any) -> str: + """`class_` parameter is not used because we're always using the sync driver.""" + if class_: + logger.info("Parameter class_ is not used for Neo4jDriverConfig") + # not used + return "not used" + + def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver: + params = self.resolve_params(self.params_) + # we know these params are there because of the required params validator + uri = params.pop("uri") + user = params.pop("user") + password = params.pop("password") + driver = neo4j.GraphDatabase.driver(uri, auth=(user, password), **params) + return driver + + +# note: using the notation with RootModel + root: field +# instead of RootModel[] for clarity +# but this requires the type: ignore comment below +class Neo4jDriverType(RootModel): # type: ignore[type-arg] + """A model to wrap neo4j.Driver and Neo4jDriverConfig objects. + + The `parse` method always returns a neo4j.Driver. + """ + + root: Union[neo4j.Driver, Neo4jDriverConfig] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver: + if isinstance(self.root, neo4j.Driver): + return self.root + # self.root is a Neo4jDriverConfig object + return self.root.parse(resolved_data) + + +class LLMConfig(ObjectConfig[LLMInterface]): + """Configuration for any LLMInterface object. + + By default, will try to import from `neo4j_graphrag.llm`. + """ + + DEFAULT_MODULE = "neo4j_graphrag.llm" + INTERFACE = LLMInterface + + +class LLMType(RootModel): # type: ignore[type-arg] + """A model to wrap LLMInterface and LLMConfig objects. + + The `parse` method always returns an object inheriting from LLMInterface. + """ + + root: Union[LLMInterface, LLMConfig] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> LLMInterface: + if isinstance(self.root, LLMInterface): + return self.root + return self.root.parse(resolved_data) + + +class EmbedderConfig(ObjectConfig[Embedder]): + """Configuration for any Embedder object. + + By default, will try to import from `neo4j_graphrag.embeddings`. + """ + + DEFAULT_MODULE = "neo4j_graphrag.embeddings" + INTERFACE = Embedder + + +class EmbedderType(RootModel): # type: ignore[type-arg] + """A model to wrap Embedder and EmbedderConfig objects. + + The `parse` method always returns an object inheriting from Embedder. + """ + + root: Union[Embedder, EmbedderConfig] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder: + if isinstance(self.root, Embedder): + return self.root + return self.root.parse(resolved_data) + + +class ComponentConfig(ObjectConfig[Component]): + """A config model for all components. + + In addition to the object config, components can have pre-defined parameters + that will be passed to the `run` method, ie `run_params_`. + """ + + run_params_: dict[str, ParamConfig] = {} + + DEFAULT_MODULE = "neo4j_graphrag.experimental.components" + INTERFACE = Component + + +class ComponentType(RootModel): # type: ignore[type-arg] + root: Union[Component, ComponentConfig] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> Component: + if isinstance(self.root, Component): + return self.root + return self.root.parse(resolved_data) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/param_resolver.py b/src/neo4j_graphrag/experimental/pipeline/config/param_resolver.py new file mode 100644 index 00000000..24190add --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/param_resolver.py @@ -0,0 +1,60 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import os +from typing import Any, ClassVar, Literal, Union + +from pydantic import BaseModel + + +class ParamResolverEnum(str, enum.Enum): + ENV = "ENV" + CONFIG_KEY = "CONFIG_KEY" + + +class ParamToResolveConfig(BaseModel): + def resolve(self, data: dict[str, Any]) -> Any: + raise NotImplementedError + + +class ParamFromEnvConfig(ParamToResolveConfig): + resolver_: Literal[ParamResolverEnum.ENV] = ParamResolverEnum.ENV + var_: str + + def resolve(self, data: dict[str, Any]) -> Any: + return os.environ.get(self.var_) + + +class ParamFromKeyConfig(ParamToResolveConfig): + resolver_: Literal[ParamResolverEnum.CONFIG_KEY] = ParamResolverEnum.CONFIG_KEY + key_: str + + KEY_SEP: ClassVar[str] = "." + + def resolve(self, data: dict[str, Any]) -> Any: + d = data + for k in self.key_.split(self.KEY_SEP): + d = d[k] + return d + + +ParamConfig = Union[ + float, + str, + ParamFromEnvConfig, + ParamFromKeyConfig, + dict[str, Any], +] diff --git a/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py new file mode 100644 index 00000000..c3871179 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py @@ -0,0 +1,199 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, ClassVar, Literal, Optional, Union + +import neo4j +from pydantic import field_validator + +from neo4j_graphrag.embeddings import Embedder +from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig +from neo4j_graphrag.experimental.pipeline.config.object_config import ( + ComponentType, + EmbedderType, + LLMType, + Neo4jDriverType, +) +from neo4j_graphrag.experimental.pipeline.config.types import PipelineType +from neo4j_graphrag.experimental.pipeline.types import ( + ComponentDefinition, + ConnectionDefinition, + PipelineDefinition, +) +from neo4j_graphrag.llm import LLMInterface + +logger = logging.getLogger(__name__) + + +class AbstractPipelineConfig(AbstractConfig): + """This class defines the fields possibly used by all pipelines: neo4j drivers, LLMs... + neo4j_config, llm_config can be provided by user as a single item or a dict of items. + Validators deal with type conversion so that the field in all instances is a dict of items. + """ + + neo4j_config: dict[str, Neo4jDriverType] = {} + llm_config: dict[str, LLMType] = {} + embedder_config: dict[str, EmbedderType] = {} + # extra parameters values that can be used in different places of the config file + extras: dict[str, Any] = {} + + DEFAULT_NAME: ClassVar[str] = "default" + """Name of the default item in dict + """ + + @field_validator("neo4j_config", mode="before") + @classmethod + def validate_drivers( + cls, drivers: Union[Neo4jDriverType, dict[str, Any]] + ) -> dict[str, Any]: + if not isinstance(drivers, dict) or "params_" in drivers: + return {cls.DEFAULT_NAME: drivers} + return drivers + + @field_validator("llm_config", mode="before") + @classmethod + def validate_llms(cls, llms: Union[LLMType, dict[str, Any]]) -> dict[str, Any]: + if not isinstance(llms, dict) or "class_" in llms: + return {cls.DEFAULT_NAME: llms} + return llms + + @field_validator("embedder_config", mode="before") + @classmethod + def validate_embedders( + cls, embedders: Union[EmbedderType, dict[str, Any]] + ) -> dict[str, Any]: + if not isinstance(embedders, dict) or "class_" in embedders: + return {cls.DEFAULT_NAME: embedders} + return embedders + + def _resolve_component_definition( + self, name: str, config: ComponentType + ) -> ComponentDefinition: + component = config.parse(self._global_data) + if hasattr(config.root, "run_params_"): + component_run_params = self.resolve_params(config.root.run_params_) + else: + component_run_params = {} + component_def = ComponentDefinition( + name=name, + component=component, + run_params=component_run_params, + ) + logger.debug(f"PIPELINE_CONFIG: resolved component {component_def}") + return component_def + + def _parse_global_data(self) -> dict[str, Any]: + """Global data contains data that can be referenced in other parts of the + config. + + Typically, neo4j drivers, LLMs and embedders can be referenced in component + input parameters. + """ + # 'extras' parameters can be referenced in other configs, + # that's why they are parsed before the others + # e.g., an API key used for both LLM and Embedder can be stored only + # once in extras. + extra_data = { + "extras": self.resolve_params(self.extras), + } + logger.debug(f"PIPELINE_CONFIG: resolved 'extras': {extra_data}") + drivers: dict[str, neo4j.Driver] = { + driver_name: driver_config.parse(extra_data) + for driver_name, driver_config in self.neo4j_config.items() + } + llms: dict[str, LLMInterface] = { + llm_name: llm_config.parse(extra_data) + for llm_name, llm_config in self.llm_config.items() + } + embedders: dict[str, Embedder] = { + embedder_name: embedder_config.parse(extra_data) + for embedder_name, embedder_config in self.embedder_config.items() + } + global_data = { + **extra_data, + "neo4j_config": drivers, + "llm_config": llms, + "embedder_config": embedders, + } + logger.debug(f"PIPELINE_CONFIG: resolved globals: {global_data}") + return global_data + + def _get_components(self) -> list[ComponentDefinition]: + return [] + + def _get_connections(self) -> list[ConnectionDefinition]: + return [] + + def parse( + self, resolved_data: Optional[dict[str, Any]] = None + ) -> PipelineDefinition: + """Parse the full config and returns a PipelineDefinition object, containing instantiated + components and a list of connections. + """ + self._global_data = self._parse_global_data() + return PipelineDefinition( + components=self._get_components(), + connections=self._get_connections(), + ) + + def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: + return user_input + + async def close(self) -> None: + drivers = self._global_data.get("neo4j_config", {}) + for driver_name in drivers: + driver = drivers[driver_name] + logger.debug(f"PIPELINE_CONFIG: closing driver {driver_name}: {driver}") + driver.close() + + def get_neo4j_driver_by_name(self, name: str) -> neo4j.Driver: + drivers: dict[str, neo4j.Driver] = self._global_data.get("neo4j_config", {}) + return drivers[name] + + def get_default_neo4j_driver(self) -> neo4j.Driver: + return self.get_neo4j_driver_by_name(self.DEFAULT_NAME) + + def get_llm_by_name(self, name: str) -> LLMInterface: + llms: dict[str, LLMInterface] = self._global_data.get("llm_config", {}) + return llms[name] + + def get_default_llm(self) -> LLMInterface: + return self.get_llm_by_name(self.DEFAULT_NAME) + + def get_embedder_by_name(self, name: str) -> Embedder: + embedders: dict[str, Embedder] = self._global_data.get("embedder_config", {}) + return embedders[name] + + def get_default_embedder(self) -> Embedder: + return self.get_embedder_by_name(self.DEFAULT_NAME) + + +class PipelineConfig(AbstractPipelineConfig): + """Configuration class for raw pipelines. + Config must contain all components and connections.""" + + component_config: dict[str, ComponentType] + connection_config: list[ConnectionDefinition] + template_: Literal[PipelineType.NONE] = PipelineType.NONE + + def _get_connections(self) -> list[ConnectionDefinition]: + return self.connection_config + + def _get_components(self) -> list[ComponentDefinition]: + return [ + self._resolve_component_definition(name, component_config) + for name, component_config in self.component_config.items() + ] diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py new file mode 100644 index 00000000..a1a22585 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -0,0 +1,132 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pipeline config wrapper (router based on 'template_' key) +and pipeline runner. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import ( + Annotated, + Any, + Union, +) + +from pydantic import ( + BaseModel, + Discriminator, + Field, + Tag, +) +from pydantic.v1.utils import deep_update +from typing_extensions import Self + +from neo4j_graphrag.experimental.pipeline import Pipeline +from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader +from neo4j_graphrag.experimental.pipeline.config.pipeline_config import ( + AbstractPipelineConfig, + PipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder import ( + SimpleKGPipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.config.types import PipelineType +from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult +from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition + +logger = logging.getLogger(__name__) + + +def _get_discriminator_value(model: Any) -> PipelineType: + template_ = None + if "template_" in model: + template_ = model["template_"] + if hasattr(model, "template_"): + template_ = model.template_ + return PipelineType(template_) or PipelineType.NONE + + +class PipelineConfigWrapper(BaseModel): + """The pipeline config wrapper will parse the right pipeline config based on the `template_` field.""" + + config: Union[ + Annotated[PipelineConfig, Tag(PipelineType.NONE)], + Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)], + ] = Field(discriminator=Discriminator(_get_discriminator_value)) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition: + return self.config.parse(resolved_data) + + def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: + return self.config.get_run_params(user_input) + + +class PipelineRunner: + """Pipeline runner builds a pipeline from different objects and exposes a run method to run pipeline + + Pipeline can be built from: + - A PipelineDefinition (`__init__` method) + - A PipelineConfig (`from_config` method) + - A config file (`from_config_file` method) + """ + + def __init__( + self, + pipeline_definition: PipelineDefinition, + config: AbstractPipelineConfig | None = None, + do_cleaning: bool = False, + ) -> None: + self.config = config + self.pipeline = Pipeline.from_definition(pipeline_definition) + self.run_params = pipeline_definition.get_run_params() + self.do_cleaning = do_cleaning + + @classmethod + def from_config( + cls, config: AbstractPipelineConfig | dict[str, Any], do_cleaning: bool = False + ) -> Self: + wrapper = PipelineConfigWrapper.model_validate({"config": config}) + return cls(wrapper.parse(), config=wrapper.config, do_cleaning=do_cleaning) + + @classmethod + def from_config_file(cls, file_path: Union[str, Path]) -> Self: + if not isinstance(file_path, str): + file_path = str(file_path) + data = ConfigReader().read(file_path) + return cls.from_config(data, do_cleaning=True) + + async def run(self, user_input: dict[str, Any]) -> PipelineResult: + # pipeline_conditional_run_params = self. + if self.config: + run_param = deep_update( + self.run_params, self.config.get_run_params(user_input) + ) + else: + run_param = deep_update(self.run_params, user_input) + logger.info( + f"PIPELINE_RUNNER: starting pipeline {self.pipeline} with run_params={run_param}" + ) + result = await self.pipeline.run(data=run_param) + if self.do_cleaning: + await self.close() + return result + + async def close(self) -> None: + logger.debug("PIPELINE_RUNNER: cleaning up (closing instantiated drivers...)") + if self.config: + await self.config.close() diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/__init__.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/__init__.py new file mode 100644 index 00000000..125a1c87 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .simple_kg_builder import SimpleKGPipelineConfig + +__all__ = [ + "SimpleKGPipelineConfig", +] diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py new file mode 100644 index 00000000..69fbc751 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py @@ -0,0 +1,63 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Any, ClassVar, Optional + +from neo4j_graphrag.experimental.pipeline.config.pipeline_config import ( + AbstractPipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition + +logger = logging.getLogger(__name__) + + +class TemplatePipelineConfig(AbstractPipelineConfig): + """This class represent a 'template' pipeline, ie pipeline with pre-defined default + components and fixed connections. + + Component names are defined in the COMPONENTS class var. For each of them, + a `_get_` method must be implemented that returns the proper + component. Optionally, `_get_run_params_for_` can be implemented + to deal with parameters required by the component's run method and predefined on + template initialization. + """ + + COMPONENTS: ClassVar[list[str]] = [] + + def _get_component(self, component_name: str) -> Optional[ComponentDefinition]: + method = getattr(self, f"_get_{component_name}") + component = method() + if component is None: + return None + method = getattr(self, f"_get_run_params_for_{component_name}", None) + run_params = method() if method else {} + component_definition = ComponentDefinition( + name=component_name, + component=component, + run_params=run_params, + ) + logger.debug(f"TEMPLATE_PIPELINE: resolved component {component_definition}") + return component_definition + + def _get_components(self) -> list[ComponentDefinition]: + components = [] + for component_name in self.COMPONENTS: + comp = self._get_component(component_name) + if comp is not None: + components.append(comp) + return components + + def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: + return {} diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py new file mode 100644 index 00000000..73edfd9a --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -0,0 +1,228 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, ClassVar, Literal, Optional, Sequence, Union + +from pydantic import ConfigDict + +from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder +from neo4j_graphrag.experimental.components.entity_relation_extractor import ( + EntityRelationExtractor, + LLMEntityRelationExtractor, + OnError, +) +from neo4j_graphrag.experimental.components.kg_writer import KGWriter, Neo4jWriter +from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader +from neo4j_graphrag.experimental.components.resolver import ( + EntityResolver, + SinglePropertyExactMatchResolver, +) +from neo4j_graphrag.experimental.components.schema import ( + SchemaBuilder, + SchemaEntity, + SchemaRelation, +) +from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter +from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( + FixedSizeSplitter, +) +from neo4j_graphrag.experimental.components.types import LexicalGraphConfig +from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType +from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import ( + TemplatePipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.config.types import PipelineType +from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.experimental.pipeline.types import ( + ConnectionDefinition, + EntityInputType, + RelationInputType, +) +from neo4j_graphrag.generation.prompts import ERExtractionTemplate + + +class SimpleKGPipelineConfig(TemplatePipelineConfig): + COMPONENTS: ClassVar[list[str]] = [ + "pdf_loader", + "splitter", + "chunk_embedder", + "schema", + "extractor", + "writer", + "resolver", + ] + + template_: Literal[PipelineType.SIMPLE_KG_PIPELINE] = ( + PipelineType.SIMPLE_KG_PIPELINE + ) + + from_pdf: bool = False + entities: Sequence[EntityInputType] = [] + relations: Sequence[RelationInputType] = [] + potential_schema: Optional[list[tuple[str, str, str]]] = None + on_error: OnError = OnError.IGNORE + prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() + perform_entity_resolution: bool = True + lexical_graph_config: Optional[LexicalGraphConfig] = None + neo4j_database: Optional[str] = None + + pdf_loader: Optional[ComponentType] = None + kg_writer: Optional[ComponentType] = None + text_splitter: Optional[ComponentType] = None + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def _get_pdf_loader(self) -> Optional[PdfLoader]: + if not self.from_pdf: + return None + if self.pdf_loader: + return self.pdf_loader.parse(self._global_data) # type: ignore + return PdfLoader() + + def _get_splitter(self) -> TextSplitter: + if self.text_splitter: + return self.text_splitter.parse(self._global_data) # type: ignore + return FixedSizeSplitter() + + def _get_chunk_embedder(self) -> TextChunkEmbedder: + return TextChunkEmbedder(embedder=self.get_default_embedder()) + + def _get_schema(self) -> SchemaBuilder: + return SchemaBuilder() + + def _get_run_params_for_schema(self) -> dict[str, Any]: + return { + "entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities], + "relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations], + "potential_schema": self.potential_schema, + } + + def _get_extractor(self) -> EntityRelationExtractor: + return LLMEntityRelationExtractor( + llm=self.get_default_llm(), + prompt_template=self.prompt_template, + on_error=self.on_error, + ) + + def _get_writer(self) -> KGWriter: + if self.kg_writer: + return self.kg_writer.parse(self._global_data) # type: ignore + return Neo4jWriter( + driver=self.get_default_neo4j_driver(), + neo4j_database=self.neo4j_database, + ) + + def _get_resolver(self) -> Optional[EntityResolver]: + if not self.perform_entity_resolution: + return None + return SinglePropertyExactMatchResolver( + driver=self.get_default_neo4j_driver(), + neo4j_database=self.neo4j_database, + ) + + def _get_connections(self) -> list[ConnectionDefinition]: + connections = [] + if self.from_pdf: + connections.append( + ConnectionDefinition( + start="pdf_loader", + end="splitter", + input_config={"text": "pdf_loader.text"}, + ) + ) + connections.append( + ConnectionDefinition( + start="schema", + end="extractor", + input_config={ + "schema": "schema", + "document_info": "pdf_loader.document_info", + }, + ) + ) + else: + connections.append( + ConnectionDefinition( + start="schema", + end="extractor", + input_config={ + "schema": "schema", + }, + ) + ) + connections.append( + ConnectionDefinition( + start="splitter", + end="chunk_embedder", + input_config={ + "text_chunks": "splitter", + }, + ) + ) + connections.append( + ConnectionDefinition( + start="chunk_embedder", + end="extractor", + input_config={ + "chunks": "chunk_embedder", + }, + ) + ) + connections.append( + ConnectionDefinition( + start="extractor", + end="writer", + input_config={ + "graph": "extractor", + }, + ) + ) + + if self.perform_entity_resolution: + connections.append( + ConnectionDefinition( + start="writer", + end="resolver", + input_config={}, + ) + ) + + return connections + + def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: + run_params = {} + if self.lexical_graph_config: + run_params["extractor"] = { + "lexical_graph_config": self.lexical_graph_config + } + text = user_input.get("text") + file_path = user_input.get("file_path") + if not ((text is None) ^ (file_path is None)): + # exactly one of text or user_input must be set + raise PipelineDefinitionError( + "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument." + ) + if self.from_pdf: + if not file_path: + raise PipelineDefinitionError( + "Expected 'file_path' argument when 'from_pdf' is True." + ) + run_params["pdf_loader"] = {"filepath": file_path} + else: + if not text: + raise PipelineDefinitionError( + "Expected 'text' argument when 'from_pdf' is False." + ) + run_params["splitter"] = {"text": text} + return run_params diff --git a/src/neo4j_graphrag/experimental/pipeline/config/types.py b/src/neo4j_graphrag/experimental/pipeline/config/types.py new file mode 100644 index 00000000..48f91f48 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/types.py @@ -0,0 +1,26 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum + + +class PipelineType(str, enum.Enum): + """Pipeline type: + + NONE => Pipeline + SIMPLE_KG_PIPELINE ~> SimpleKGPipeline + """ + + NONE = "none" + SIMPLE_KG_PIPELINE = "SimpleKGPipeline" diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 58868cb3..3fca0215 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -18,30 +18,16 @@ from typing import Any, List, Optional, Sequence, Union import neo4j -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ValidationError from neo4j_graphrag.embeddings import Embedder -from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder -from neo4j_graphrag.experimental.components.entity_relation_extractor import ( - LLMEntityRelationExtractor, - OnError, -) -from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter -from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader -from neo4j_graphrag.experimental.components.resolver import ( - SinglePropertyExactMatchResolver, -) -from neo4j_graphrag.experimental.components.schema import ( - SchemaBuilder, - SchemaEntity, - SchemaRelation, -) -from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( - FixedSizeSplitter, -) from neo4j_graphrag.experimental.components.types import LexicalGraphConfig +from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner +from neo4j_graphrag.experimental.pipeline.config.template_pipeline import ( + SimpleKGPipelineConfig, +) from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError -from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult +from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.experimental.pipeline.types import ( EntityInputType, RelationInputType, @@ -50,26 +36,6 @@ from neo4j_graphrag.llm.base import LLMInterface -class SimpleKGPipelineConfig(BaseModel): - llm: LLMInterface - driver: neo4j.Driver - from_pdf: bool - embedder: Embedder - entities: list[SchemaEntity] = Field(default_factory=list) - relations: list[SchemaRelation] = Field(default_factory=list) - potential_schema: list[tuple[str, str, str]] = Field(default_factory=list) - pdf_loader: Any = None - kg_writer: Any = None - text_splitter: Any = None - on_error: OnError = OnError.RAISE - prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() - perform_entity_resolution: bool = True - lexical_graph_config: Optional[LexicalGraphConfig] = None - neo4j_database: Optional[str] = None - - model_config = ConfigDict(arbitrary_types_allowed=True) - - class SimpleKGPipeline: """ A class to simplify the process of building a knowledge graph from text documents. @@ -120,133 +86,29 @@ def __init__( lexical_graph_config: Optional[LexicalGraphConfig] = None, neo4j_database: Optional[str] = None, ): - self.potential_schema = potential_schema or [] - self.entities = [self.to_schema_entity(e) for e in entities or []] - self.relations = [self.to_schema_relation(r) for r in relations or []] - try: - on_error_enum = OnError(on_error) - except ValueError: - raise PipelineDefinitionError( - f"Invalid value for on_error: {on_error}. Expected one of {OnError.possible_values()}." - ) - - config = SimpleKGPipelineConfig( - llm=llm, - driver=driver, - entities=self.entities, - relations=self.relations, - potential_schema=self.potential_schema, - from_pdf=from_pdf, - pdf_loader=pdf_loader, - kg_writer=kg_writer, - text_splitter=text_splitter, - on_error=on_error_enum, - prompt_template=prompt_template, - embedder=embedder, - perform_entity_resolution=perform_entity_resolution, - lexical_graph_config=lexical_graph_config, - neo4j_database=neo4j_database, - ) - - self.from_pdf = config.from_pdf - self.llm = config.llm - self.driver = config.driver - self.embedder = config.embedder - self.text_splitter = config.text_splitter or FixedSizeSplitter() - self.on_error = config.on_error - self.pdf_loader = config.pdf_loader if pdf_loader is not None else PdfLoader() - self.kg_writer = ( - config.kg_writer - if kg_writer is not None - else Neo4jWriter(driver, neo4j_database=config.neo4j_database) - ) - self.prompt_template = config.prompt_template - self.perform_entity_resolution = config.perform_entity_resolution - self.lexical_graph_config = config.lexical_graph_config - self.neo4j_database = config.neo4j_database - - self.pipeline = self._build_pipeline() - - @staticmethod - def to_schema_entity(entity: EntityInputType) -> SchemaEntity: - if isinstance(entity, dict): - return SchemaEntity.model_validate(entity) - return SchemaEntity(label=entity) - - @staticmethod - def to_schema_relation(relation: RelationInputType) -> SchemaRelation: - if isinstance(relation, dict): - return SchemaRelation.model_validate(relation) - return SchemaRelation(label=relation) - - def _build_pipeline(self) -> Pipeline: - pipe = Pipeline() - - pipe.add_component(self.text_splitter, "splitter") - pipe.add_component(SchemaBuilder(), "schema") - pipe.add_component( - LLMEntityRelationExtractor( - llm=self.llm, - on_error=self.on_error, - prompt_template=self.prompt_template, - ), - "extractor", - ) - pipe.add_component(TextChunkEmbedder(embedder=self.embedder), "chunk_embedder") - pipe.add_component(self.kg_writer, "writer") - - if self.from_pdf: - pipe.add_component(self.pdf_loader, "pdf_loader") - - pipe.connect( - "pdf_loader", - "splitter", - input_config={"text": "pdf_loader.text"}, - ) - - pipe.connect( - "schema", - "extractor", - input_config={ - "schema": "schema", - "document_info": "pdf_loader.document_info", - }, - ) - else: - pipe.connect( - "schema", - "extractor", - input_config={ - "schema": "schema", - }, - ) - - pipe.connect( - "splitter", "chunk_embedder", input_config={"text_chunks": "splitter"} - ) - - pipe.connect( - "chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"} - ) - - # Connect extractor to writer - pipe.connect( - "extractor", - "writer", - input_config={"graph": "extractor"}, - ) - - if self.perform_entity_resolution: - pipe.add_component( - SinglePropertyExactMatchResolver( - self.driver, neo4j_database=self.neo4j_database - ), - "resolver", + config = SimpleKGPipelineConfig( + # argument type are fixed in the Config object + llm_config=llm, # type: ignore[arg-type] + neo4j_config=driver, # type: ignore[arg-type] + embedder_config=embedder, # type: ignore[arg-type] + entities=entities or [], + relations=relations or [], + potential_schema=potential_schema, + from_pdf=from_pdf, + pdf_loader=pdf_loader, + kg_writer=kg_writer, + text_splitter=text_splitter, + on_error=on_error, # type: ignore[arg-type] + prompt_template=prompt_template, + perform_entity_resolution=perform_entity_resolution, + lexical_graph_config=lexical_graph_config, + neo4j_database=neo4j_database, ) - pipe.connect("writer", "resolver", {}) + except ValidationError as e: + raise PipelineDefinitionError() from e - return pipe + self.runner = PipelineRunner.from_config(config) async def run_async( self, file_path: Optional[str] = None, text: Optional[str] = None @@ -261,39 +123,4 @@ async def run_async( Returns: PipelineResult: The result of the pipeline execution. """ - pipe_inputs = self._prepare_inputs(file_path=file_path, text=text) - return await self.pipeline.run(pipe_inputs) - - def _prepare_inputs( - self, file_path: Optional[str], text: Optional[str] - ) -> dict[str, Any]: - if self.from_pdf: - if file_path is None or text is not None: - raise PipelineDefinitionError( - "Expected 'file_path' argument when 'from_pdf' is True." - ) - else: - if text is None or file_path is not None: - raise PipelineDefinitionError( - "Expected 'text' argument when 'from_pdf' is False." - ) - - pipe_inputs: dict[str, Any] = { - "schema": { - "entities": self.entities, - "relations": self.relations, - "potential_schema": self.potential_schema, - }, - } - - if self.from_pdf: - pipe_inputs["pdf_loader"] = {"filepath": file_path} - else: - pipe_inputs["splitter"] = {"text": text} - - if self.lexical_graph_config: - pipe_inputs["extractor"] = { - "lexical_graph_config": self.lexical_graph_config - } - - return pipe_inputs + return await self.runner.run({"file_path": file_path, "text": text}) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 5edc2783..e3ded494 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -44,9 +44,9 @@ ) from neo4j_graphrag.experimental.pipeline.stores import InMemoryStore, ResultStore from neo4j_graphrag.experimental.pipeline.types import ( - ComponentConfig, - ConnectionConfig, - PipelineConfig, + ComponentDefinition, + ConnectionDefinition, + PipelineDefinition, ) logger = logging.getLogger(__name__) @@ -349,16 +349,34 @@ def __init__(self, store: Optional[ResultStore] = None) -> None: @classmethod def from_template( - cls, pipeline_template: PipelineConfig, store: Optional[ResultStore] = None + cls, pipeline_template: PipelineDefinition, store: Optional[ResultStore] = None ) -> Pipeline: - """Create a Pipeline from a pydantic model defining the components and their connections""" + warnings.warn( + "from_template is deprecated, use from_definition instead", + DeprecationWarning, + stacklevel=2, + ) + return cls.from_definition(pipeline_template, store) + + @classmethod + def from_definition( + cls, + pipeline_definition: PipelineDefinition, + store: Optional[ResultStore] = None, + ) -> Pipeline: + """Create a Pipeline from a pydantic model defining the components and their connections + + Args: + pipeline_definition (PipelineDefinition): An object defining components and how they are connected to each other. + store (Optional[ResultStore]): Where the results are stored. By default, uses the InMemoryStore. + """ pipeline = Pipeline(store=store) - for component in pipeline_template.components: + for component in pipeline_definition.components: pipeline.add_component( component.component, component.name, ) - for edge in pipeline_template.connections: + for edge in pipeline_definition.connections: pipeline_edge = PipelineEdge( edge.start, edge.end, data={"input_config": edge.input_config} ) @@ -369,18 +387,18 @@ def show_as_dict(self) -> dict[str, Any]: component_config = [] for name, task in self._nodes.items(): component_config.append( - ComponentConfig(name=name, component=task.component) + ComponentDefinition(name=name, component=task.component) ) connection_config = [] for edge in self._edges: connection_config.append( - ConnectionConfig( + ConnectionDefinition( start=edge.start, end=edge.end, input_config=edge.data["input_config"] if edge.data else {}, ) ) - pipeline_config = PipelineConfig( + pipeline_config = PipelineDefinition( components=component_config, connections=connection_config ) return pipeline_config.model_dump() diff --git a/src/neo4j_graphrag/experimental/pipeline/types.py b/src/neo4j_graphrag/experimental/pipeline/types.py index ebdf141d..47aafd8b 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types.py +++ b/src/neo4j_graphrag/experimental/pipeline/types.py @@ -14,29 +14,36 @@ # limitations under the License. from __future__ import annotations -from typing import Union +from collections import defaultdict +from typing import Any, Union from pydantic import BaseModel, ConfigDict from neo4j_graphrag.experimental.pipeline.component import Component -class ComponentConfig(BaseModel): +class ComponentDefinition(BaseModel): name: str component: Component + run_params: dict[str, Any] = {} model_config = ConfigDict(arbitrary_types_allowed=True) -class ConnectionConfig(BaseModel): +class ConnectionDefinition(BaseModel): start: str end: str input_config: dict[str, str] -class PipelineConfig(BaseModel): - components: list[ComponentConfig] - connections: list[ConnectionConfig] +class PipelineDefinition(BaseModel): + components: list[ComponentDefinition] + connections: list[ConnectionDefinition] + + def get_run_params(self) -> defaultdict[str, dict[str, Any]]: + return defaultdict( + dict, {c.name: c.run_params for c in self.components if c.run_params} + ) EntityInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 70ca0193..42c21cf8 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import annotations +import os import random import string import uuid @@ -33,6 +34,8 @@ from ..e2e.utils import EMBEDDING_BIOLOGY +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + @pytest.fixture(scope="module") def driver() -> Generator[Any, Any, Any]: @@ -48,6 +51,12 @@ def llm() -> MagicMock: return MagicMock(spec=LLMInterface) +@pytest.fixture +def embedder() -> Embedder: + embedder = MagicMock(spec=Embedder) + return embedder + + class RandomEmbedder(Embedder): def embed_query(self, text: str) -> list[float]: return [random.random() for _ in range(1536)] @@ -75,6 +84,31 @@ def retriever_mock() -> MagicMock: return MagicMock(spec=VectorRetriever) +@pytest.fixture +def harry_potter_text() -> str: + with open(os.path.join(BASE_DIR, "data/documents/harry_potter.txt"), "r") as f: + text = f.read() + return text + + +@pytest.fixture +def harry_potter_text_part1() -> str: + with open( + os.path.join(BASE_DIR, "data/documents/harry_potter_part1.txt"), "r" + ) as f: + text = f.read() + return text + + +@pytest.fixture +def harry_potter_text_part2() -> str: + with open( + os.path.join(BASE_DIR, "data/documents/harry_potter_part2.txt"), "r" + ) as f: + text = f.read() + return text + + @pytest.fixture(scope="module") def setup_neo4j_for_retrieval(driver: Driver) -> None: vector_index_name = "vector-index-name" diff --git a/tests/e2e/data/config_files/pipeline_config.json b/tests/e2e/data/config_files/pipeline_config.json new file mode 100644 index 00000000..fe36624d --- /dev/null +++ b/tests/e2e/data/config_files/pipeline_config.json @@ -0,0 +1,72 @@ +{ + "version_": "1", + "template_": "none", + "name": "", + "neo4j_config": { + "params_": { + "uri": { + "resolver_": "ENV", + "var_": "NEO4J_URI" + }, + "user": { + "resolver_": "ENV", + "var_": "NEO4J_USER" + }, + "password": { + "resolver_": "ENV", + "var_": "NEO4J_PASSWORD" + } + } + }, + "extras": { + "database": "neo4j" + }, + "component_config": { + "splitter": { + "class_": "text_splitters.fixed_size_splitter.FixedSizeSplitter", + "params_": { + "chunk_size": 100, + "chunk_overlap": 10 + } + }, + "builder": { + "class_": "lexical_graph.LexicalGraphBuilder", + "params_": { + "config": { + "chunk_node_label": "TextPart" + } + } + }, + "writer": { + "name_": "writer", + "class_": "kg_writer.Neo4jWriter", + "params_": { + "driver": { + "resolver_": "CONFIG_KEY", + "key_": "neo4j_config.default" + }, + "neo4j_database": { + "resolver_": "CONFIG_KEY", + "key_": "extras.database" + } + } + } + }, + "connection_config": [ + { + "start": "splitter", + "end": "builder", + "input_config": { + "text_chunks": "splitter" + } + }, + { + "start": "builder", + "end": "writer", + "input_config": { + "graph": "builder.graph", + "lexical_graph_config": "builder.config" + } + } + ] +} diff --git a/tests/e2e/data/config_files/pipeline_config.yaml b/tests/e2e/data/config_files/pipeline_config.yaml new file mode 100644 index 00000000..87ac905e --- /dev/null +++ b/tests/e2e/data/config_files/pipeline_config.yaml @@ -0,0 +1,45 @@ +version_: "1" +template_: none +neo4j_config: + params_: + uri: + resolver_: ENV + var_: NEO4J_URI + user: + resolver_: ENV + var_: NEO4J_USER + password: + resolver_: ENV + var_: NEO4J_PASSWORD +extras: + database: neo4j +component_config: + splitter: + class_: text_splitters.fixed_size_splitter.FixedSizeSplitter + params_: + chunk_size: 100 + chunk_overlap: 10 + builder: + class_: lexical_graph.LexicalGraphBuilder + params_: + config: + chunk_node_label: TextPart + writer: + class_: kg_writer.Neo4jWriter + params_: + driver: + resolver_: CONFIG_KEY + key_: neo4j_config.default + neo4j_database: + resolver_: CONFIG_KEY + key_: extras.database +connection_config: + - start: splitter + end: builder + input_config: + text_chunks: splitter + - start: builder + end: writer + input_config: + graph: builder.graph + lexical_graph_config: builder.config diff --git a/tests/e2e/data/config_files/simple_kg_pipeline_config.json b/tests/e2e/data/config_files/simple_kg_pipeline_config.json new file mode 100644 index 00000000..c2d629fb --- /dev/null +++ b/tests/e2e/data/config_files/simple_kg_pipeline_config.json @@ -0,0 +1,64 @@ +{ + "version_": "1", + "template_": "SimpleKGPipeline", + "neo4j_config": { + "params_": { + "uri": { + "resolver_": "ENV", + "var_": "NEO4J_URI" + }, + "user": { + "resolver_": "ENV", + "var_": "NEO4J_USER" + }, + "password": { + "resolver_": "ENV", + "var_": "NEO4J_PASSWORD" + } + } + }, + "llm_config": { + "class_": "OpenAILLM", + "params_": { + "api_key": { + "resolver_": "ENV", + "var_": "OPENAI_API_KEY" + }, + "model_name": "gpt-4o", + "model_params": { + "temperature": 0, + "max_tokens": 2000, + "response_format": {"type": "json_object"} + } + } + }, + "embedder_config": { + "class_": "OpenAIEmbeddings", + "params_": { + "api_key": { + "resolver_": "ENV", + "var_": "OPENAI_API_KEY" + } + } + }, + "from_pdf": true, + "entities": [ + "Person", + "Organization", + "Horcrux", + "Location" + ], + "relations": [ + "SITUATED_AT", + "INTERACTS", + "OWNS", + "LED_BY" + ], + "potential_schema": [ + ["Person", "SITUATED_AT", "Location"], + ["Person", "INTERACTS", "Person"], + ["Person", "OWNS", "Horcrux"], + ["Organization", "LED_BY", "Person"] + ], + "perform_entity_resolution": true +} diff --git a/tests/e2e/data/config_files/simple_kg_pipeline_config.yaml b/tests/e2e/data/config_files/simple_kg_pipeline_config.yaml new file mode 100644 index 00000000..abcf9632 --- /dev/null +++ b/tests/e2e/data/config_files/simple_kg_pipeline_config.yaml @@ -0,0 +1,50 @@ +version_: "1" +template_: SimpleKGPipeline +neo4j_config: + params_: + uri: + resolver_: ENV + var_: NEO4J_URI + user: + resolver_: ENV + var_: NEO4J_USER + password: + resolver_: ENV + var_: NEO4J_PASSWORD +llm_config: + class_: OpenAILLM + params_: + api_key: + resolver_: ENV + var_: OPENAI_API_KEY + model_name: gpt-4o + model_params: + temperature: 0 + max_tokens: 2000 + response_format: + type: json_object +embedder_config: + class_: OpenAIEmbeddings + params_: + api_key: + resolver_: ENV + var_: OPENAI_API_KEY +from_pdf: true +entities: + - Person + - Organization + - Location + - Horcrux +relations: + - SITUATED_AT + - INTERACTS + - OWNS + - LED_BY +potential_schema: + - ["Person", "SITUATED_AT", "Location"] + - ["Person", "INTERACTS", "Person"] + - ["Person", "OWNS", "Horcrux"] + - ["Organization", "LED_BY", "Person"] +text_splitter: + class_: text_splitters.fixed_size_splitter.FixedSizeSplitter +perform_entity_resolution: true diff --git a/tests/e2e/data/documents/harry_potter.pdf b/tests/e2e/data/documents/harry_potter.pdf new file mode 100644 index 00000000..49f0f6bd Binary files /dev/null and b/tests/e2e/data/documents/harry_potter.pdf differ diff --git a/tests/e2e/data/harry_potter.txt b/tests/e2e/data/documents/harry_potter.txt similarity index 100% rename from tests/e2e/data/harry_potter.txt rename to tests/e2e/data/documents/harry_potter.txt diff --git a/tests/e2e/experimental/__init__.py b/tests/e2e/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/e2e/experimental/pipeline/__init__.py b/tests/e2e/experimental/pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/e2e/experimental/pipeline/config/__init__.py b/tests/e2e/experimental/pipeline/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/e2e/experimental/pipeline/config/test_pipeline_runner_e2e.py b/tests/e2e/experimental/pipeline/config/test_pipeline_runner_e2e.py new file mode 100644 index 00000000..9122a935 --- /dev/null +++ b/tests/e2e/experimental/pipeline/config/test_pipeline_runner_e2e.py @@ -0,0 +1,208 @@ +import os +from typing import Any +from unittest.mock import AsyncMock, Mock, patch + +import neo4j +import pytest +from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner +from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult +from neo4j_graphrag.llm import LLMResponse + + +@pytest.fixture(scope="function", autouse=True) +def clear_db(driver: neo4j.Driver) -> Any: + driver.execute_query("MATCH (n) DETACH DELETE n") + yield + + +@pytest.mark.asyncio +async def test_pipeline_from_json_config(harry_potter_text: str, driver: Mock) -> None: + os.environ["NEO4J_URI"] = "neo4j://localhost:7687" + os.environ["NEO4J_USER"] = "neo4j" + os.environ["NEO4J_PASSWORD"] = "password" + + runner = PipelineRunner.from_config_file( + "tests/e2e/data/config_files/pipeline_config.json" + ) + res = await runner.run({"splitter": {"text": harry_potter_text}}) + assert isinstance(res, PipelineResult) + assert res.result["writer"]["metadata"] == { + "node_count": 11, + "relationship_count": 10, + } + nodes = driver.execute_query("MATCH (n) RETURN n") + assert len(nodes.records) == 11 + + +@pytest.mark.asyncio +async def test_pipeline_from_yaml_config(harry_potter_text: str, driver: Mock) -> None: + os.environ["NEO4J_URI"] = "neo4j://localhost:7687" + os.environ["NEO4J_USER"] = "neo4j" + os.environ["NEO4J_PASSWORD"] = "password" + + runner = PipelineRunner.from_config_file( + "tests/e2e/data/config_files/pipeline_config.yaml" + ) + res = await runner.run({"splitter": {"text": harry_potter_text}}) + assert isinstance(res, PipelineResult) + assert res.result["writer"]["metadata"] == { + "node_count": 11, + "relationship_count": 10, + } + + nodes = driver.execute_query("MATCH (n) RETURN n") + assert len(nodes.records) == 11 + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.runner.SimpleKGPipelineConfig.get_default_embedder" +) +@patch( + "neo4j_graphrag.experimental.pipeline.config.runner.SimpleKGPipelineConfig.get_default_llm" +) +@pytest.mark.asyncio +async def test_simple_kg_pipeline_from_json_config( + mock_llm: Mock, mock_embedder: Mock, harry_potter_text: str, driver: Mock +) -> None: + mock_llm.return_value.ainvoke = AsyncMock( + side_effect=[ + LLMResponse( + content="""{ + "nodes": [ + { + "id": "0", + "label": "Person", + "properties": { + "name": "Harry Potter" + } + }, + { + "id": "1", + "label": "Person", + "properties": { + "name": "Alastor Mad-Eye Moody" + } + }, + { + "id": "2", + "label": "Organization", + "properties": { + "name": "The Order of the Phoenix" + } + } + ], + "relationships": [ + { + "type": "KNOWS", + "start_node_id": "0", + "end_node_id": "1" + }, + { + "type": "LED_BY", + "start_node_id": "2", + "end_node_id": "1" + } + ] + }""" + ), + ] + ) + mock_embedder.return_value.embed_query.side_effect = [ + [1.0, 2.0], + ] + + os.environ["NEO4J_URI"] = "neo4j://localhost:7687" + os.environ["NEO4J_USER"] = "neo4j" + os.environ["NEO4J_PASSWORD"] = "password" + os.environ["OPENAI_API_KEY"] = "sk-my-secret-key" + + runner = PipelineRunner.from_config_file( + "tests/e2e/data/config_files/simple_kg_pipeline_config.json" + ) + res = await runner.run({"file_path": "tests/e2e/data/documents/harry_potter.pdf"}) + assert isinstance(res, PipelineResult) + # print(await runner.pipeline.store.get_result_for_component(res.run_id, "splitter")) + assert res.result["resolver"] == { + "number_of_nodes_to_resolve": 3, + "number_of_created_nodes": 3, + } + nodes = driver.execute_query("MATCH (n) RETURN n") + # 1 chunk + 1 document + 3 nodes + assert len(nodes.records) == 5 + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.runner.SimpleKGPipelineConfig.get_default_embedder" +) +@patch( + "neo4j_graphrag.experimental.pipeline.config.runner.SimpleKGPipelineConfig.get_default_llm" +) +@pytest.mark.asyncio +async def test_simple_kg_pipeline_from_yaml_config( + mock_llm: Mock, mock_embedder: Mock, harry_potter_text: str, driver: Mock +) -> None: + mock_llm.return_value.ainvoke = AsyncMock( + side_effect=[ + LLMResponse( + content="""{ + "nodes": [ + { + "id": "0", + "label": "Person", + "properties": { + "name": "Harry Potter" + } + }, + { + "id": "1", + "label": "Person", + "properties": { + "name": "Alastor Mad-Eye Moody" + } + }, + { + "id": "2", + "label": "Organization", + "properties": { + "name": "The Order of the Phoenix" + } + } + ], + "relationships": [ + { + "type": "KNOWS", + "start_node_id": "0", + "end_node_id": "1" + }, + { + "type": "LED_BY", + "start_node_id": "2", + "end_node_id": "1" + } + ] + }""" + ), + ] + ) + mock_embedder.return_value.embed_query.side_effect = [ + [1.0, 2.0], + ] + + os.environ["NEO4J_URI"] = "neo4j://localhost:7687" + os.environ["NEO4J_USER"] = "neo4j" + os.environ["NEO4J_PASSWORD"] = "password" + os.environ["OPENAI_API_KEY"] = "sk-my-secret-key" + + runner = PipelineRunner.from_config_file( + "tests/e2e/data/config_files/simple_kg_pipeline_config.yaml" + ) + res = await runner.run({"file_path": "tests/e2e/data/documents/harry_potter.pdf"}) + assert isinstance(res, PipelineResult) + # print(await runner.pipeline.store.get_result_for_component(res.run_id, "splitter")) + assert res.result["resolver"] == { + "number_of_nodes_to_resolve": 3, + "number_of_created_nodes": 3, + } + nodes = driver.execute_query("MATCH (n) RETURN n") + # 1 chunk + 1 document + 3 nodes + assert len(nodes.records) == 5 diff --git a/tests/e2e/test_kg_builder_pipeline_e2e.py b/tests/e2e/test_kg_builder_pipeline_e2e.py index 1713098b..bf74470c 100644 --- a/tests/e2e/test_kg_builder_pipeline_e2e.py +++ b/tests/e2e/test_kg_builder_pipeline_e2e.py @@ -124,31 +124,6 @@ def kg_builder_pipeline( return pipe -@pytest.fixture -def harry_potter_text() -> str: - with open(os.path.join(BASE_DIR, "data/harry_potter.txt"), "r") as f: - text = f.read() - return text - - -@pytest.fixture -def harry_potter_text_part1() -> str: - with open( - os.path.join(BASE_DIR, "data/documents/harry_potter_part1.txt"), "r" - ) as f: - text = f.read() - return text - - -@pytest.fixture -def harry_potter_text_part2() -> str: - with open( - os.path.join(BASE_DIR, "data/documents/harry_potter_part2.txt"), "r" - ) as f: - text = f.read() - return text - - @pytest.mark.asyncio @pytest.mark.usefixtures("setup_neo4j_for_kg_construction") async def test_pipeline_builder_happy_path( diff --git a/tests/e2e/test_simplekgpipeline_e2e.py b/tests/e2e/test_simplekgpipeline_e2e.py index def867f5..4d3059c7 100644 --- a/tests/e2e/test_simplekgpipeline_e2e.py +++ b/tests/e2e/test_simplekgpipeline_e2e.py @@ -14,36 +14,13 @@ # limitations under the License. from __future__ import annotations -import os from unittest.mock import MagicMock import neo4j import pytest -from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.experimental.components.types import LexicalGraphConfig from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline -from neo4j_graphrag.llm import LLMInterface, LLMResponse - -BASE_DIR = os.path.dirname(os.path.abspath(__file__)) - - -@pytest.fixture -def llm() -> LLMInterface: - llm = MagicMock(spec=LLMInterface) - return llm - - -@pytest.fixture -def embedder() -> Embedder: - embedder = MagicMock(spec=Embedder) - return embedder - - -@pytest.fixture -def harry_potter_text() -> str: - with open(os.path.join(BASE_DIR, "data/harry_potter.txt"), "r") as f: - text = f.read() - return text +from neo4j_graphrag.llm import LLMResponse @pytest.mark.asyncio diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 829cad23..5069ab6e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -19,6 +19,7 @@ import neo4j import pytest from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.experimental.pipeline import Component from neo4j_graphrag.llm import LLMInterface from neo4j_graphrag.retrievers import ( HybridRetriever, @@ -98,3 +99,8 @@ def format_function(record: neo4j.Record) -> RetrieverResultItem: ) return format_function + + +@pytest.fixture(scope="function") +def component() -> MagicMock: + return MagicMock(spec=Component) diff --git a/tests/unit/experimental/pipeline/config/__init__.py b/tests/unit/experimental/pipeline/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/__init__.py b/tests/unit/experimental/pipeline/config/template_pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_base.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_base.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py new file mode 100644 index 00000000..e8d12095 --- /dev/null +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -0,0 +1,301 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock, patch + +import neo4j +import pytest +from neo4j_graphrag.embeddings import Embedder +from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder +from neo4j_graphrag.experimental.components.entity_relation_extractor import ( + LLMEntityRelationExtractor, + OnError, +) +from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter +from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader +from neo4j_graphrag.experimental.components.schema import ( + SchemaBuilder, + SchemaEntity, + SchemaRelation, +) +from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( + FixedSizeSplitter, +) +from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentConfig +from neo4j_graphrag.experimental.pipeline.config.template_pipeline import ( + SimpleKGPipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.generation.prompts import ERExtractionTemplate +from neo4j_graphrag.llm import LLMInterface + + +def test_simple_kg_pipeline_config_pdf_loader_from_pdf_is_false() -> None: + config = SimpleKGPipelineConfig(from_pdf=False) + assert config._get_pdf_loader() is None + + +def test_simple_kg_pipeline_config_pdf_loader_from_pdf_is_true() -> None: + config = SimpleKGPipelineConfig(from_pdf=True) + assert isinstance(config._get_pdf_loader(), PdfLoader) + + +def test_simple_kg_pipeline_config_pdf_loader_from_pdf_is_true_class_overwrite() -> ( + None +): + my_pdf_loader = PdfLoader() + config = SimpleKGPipelineConfig(from_pdf=True, pdf_loader=my_pdf_loader) # type: ignore + assert config._get_pdf_loader() == my_pdf_loader + + +def test_simple_kg_pipeline_config_pdf_loader_class_overwrite_but_from_pdf_is_false() -> ( + None +): + my_pdf_loader = PdfLoader() + config = SimpleKGPipelineConfig(from_pdf=False, pdf_loader=my_pdf_loader) # type: ignore + assert config._get_pdf_loader() is None + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse") +def test_simple_kg_pipeline_config_pdf_loader_from_pdf_is_true_class_overwrite_from_config( + mock_component_parse: Mock, +) -> None: + my_pdf_loader_config = ComponentConfig( + class_="", + ) + my_pdf_loader = PdfLoader() + mock_component_parse.return_value = my_pdf_loader + config = SimpleKGPipelineConfig( + from_pdf=True, + pdf_loader=my_pdf_loader_config, # type: ignore + ) + assert config._get_pdf_loader() == my_pdf_loader + + +def test_simple_kg_pipeline_config_text_splitter() -> None: + config = SimpleKGPipelineConfig() + assert isinstance(config._get_splitter(), FixedSizeSplitter) + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse") +def test_simple_kg_pipeline_config_text_splitter_overwrite( + mock_component_parse: Mock, +) -> None: + my_text_splitter_config = ComponentConfig( + class_="", + ) + my_text_splitter = FixedSizeSplitter() + mock_component_parse.return_value = my_text_splitter + config = SimpleKGPipelineConfig( + text_splitter=my_text_splitter_config, # type: ignore + ) + assert config._get_splitter() == my_text_splitter + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_embedder" +) +def test_simple_kg_pipeline_config_chunk_embedder( + mock_embedder: Mock, embedder: Embedder +) -> None: + mock_embedder.return_value = embedder + config = SimpleKGPipelineConfig() + chunk_embedder = config._get_chunk_embedder() + assert isinstance(chunk_embedder, TextChunkEmbedder) + assert chunk_embedder._embedder == embedder + + +def test_simple_kg_pipeline_config_schema() -> None: + config = SimpleKGPipelineConfig() + assert isinstance(config._get_schema(), SchemaBuilder) + + +def test_simple_kg_pipeline_config_schema_run_params() -> None: + config = SimpleKGPipelineConfig( + entities=["Person"], + relations=["KNOWS"], + potential_schema=[("Person", "KNOWS", "Person")], + ) + assert config._get_run_params_for_schema() == { + "entities": [SchemaEntity(label="Person")], + "relations": [SchemaRelation(label="KNOWS")], + "potential_schema": [ + ("Person", "KNOWS", "Person"), + ], + } + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm" +) +def test_simple_kg_pipeline_config_extractor(mock_llm: Mock, llm: LLMInterface) -> None: + mock_llm.return_value = llm + config = SimpleKGPipelineConfig( + on_error="IGNORE", # type: ignore + prompt_template=ERExtractionTemplate(template="my template {text}"), + ) + extractor = config._get_extractor() + assert isinstance(extractor, LLMEntityRelationExtractor) + assert extractor.llm == llm + assert extractor.on_error == OnError.IGNORE + assert extractor.prompt_template.template == "my template {text}" + + +@patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) +@patch( + "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_neo4j_driver" +) +def test_simple_kg_pipeline_config_writer( + mock_driver: Mock, + _: Mock, + driver: neo4j.Driver, +) -> None: + mock_driver.return_value = driver + config = SimpleKGPipelineConfig( + neo4j_database="my_db", + ) + writer = config._get_writer() + assert isinstance(writer, Neo4jWriter) + assert writer.driver == driver + assert writer.neo4j_database == "my_db" + + +@patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse") +def test_simple_kg_pipeline_config_writer_overwrite( + mock_component_parse: Mock, + _: Mock, + driver: neo4j.Driver, +) -> None: + my_writer_config = ComponentConfig( + class_="", + ) + my_writer = Neo4jWriter(driver, neo4j_database="my_db") + mock_component_parse.return_value = my_writer + config = SimpleKGPipelineConfig( + kg_writer=my_writer_config, # type: ignore + neo4j_database="my_other_db", + ) + writer: Neo4jWriter = config._get_writer() # type: ignore + assert writer == my_writer + # database not changed: + assert writer.neo4j_database == "my_db" + + +def test_simple_kg_pipeline_config_connections_from_pdf() -> None: + config = SimpleKGPipelineConfig( + from_pdf=True, + perform_entity_resolution=False, + ) + connections = config._get_connections() + assert len(connections) == 5 + expected_connections = [ + ("pdf_loader", "splitter"), + ("schema", "extractor"), + ("splitter", "chunk_embedder"), + ("chunk_embedder", "extractor"), + ("extractor", "writer"), + ] + for actual, expected in zip(connections, expected_connections): + assert (actual.start, actual.end) == expected + + +def test_simple_kg_pipeline_config_connections_from_text() -> None: + config = SimpleKGPipelineConfig( + from_pdf=False, + perform_entity_resolution=False, + ) + connections = config._get_connections() + assert len(connections) == 4 + expected_connections = [ + ("schema", "extractor"), + ("splitter", "chunk_embedder"), + ("chunk_embedder", "extractor"), + ("extractor", "writer"), + ] + for actual, expected in zip(connections, expected_connections): + assert (actual.start, actual.end) == expected + + +def test_simple_kg_pipeline_config_connections_with_er() -> None: + config = SimpleKGPipelineConfig( + from_pdf=True, + perform_entity_resolution=True, + ) + connections = config._get_connections() + assert len(connections) == 6 + expected_connections = [ + ("pdf_loader", "splitter"), + ("schema", "extractor"), + ("splitter", "chunk_embedder"), + ("chunk_embedder", "extractor"), + ("extractor", "writer"), + ("writer", "resolver"), + ] + for actual, expected in zip(connections, expected_connections): + assert (actual.start, actual.end) == expected + + +def test_simple_kg_pipeline_config_run_params_from_pdf_file_path() -> None: + config = SimpleKGPipelineConfig(from_pdf=True) + assert config.get_run_params({"file_path": "my_file"}) == { + "pdf_loader": {"filepath": "my_file"} + } + + +def test_simple_kg_pipeline_config_run_params_from_text_text() -> None: + config = SimpleKGPipelineConfig(from_pdf=False) + assert config.get_run_params({"text": "my text"}) == { + "splitter": {"text": "my text"} + } + + +def test_simple_kg_pipeline_config_run_params_from_pdf_text() -> None: + config = SimpleKGPipelineConfig(from_pdf=True) + with pytest.raises(PipelineDefinitionError) as excinfo: + config.get_run_params({"text": "my text"}) + assert "Expected 'file_path' argument when 'from_pdf' is True" in str(excinfo) + + +def test_simple_kg_pipeline_config_run_params_from_text_file_path() -> None: + config = SimpleKGPipelineConfig(from_pdf=False) + with pytest.raises(PipelineDefinitionError) as excinfo: + config.get_run_params({"file_path": "my file"}) + assert "Expected 'text' argument when 'from_pdf' is False" in str(excinfo) + + +def test_simple_kg_pipeline_config_run_params_no_file_no_text() -> None: + config = SimpleKGPipelineConfig(from_pdf=False) + with pytest.raises(PipelineDefinitionError) as excinfo: + config.get_run_params({}) + assert ( + "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument." + in str(excinfo) + ) + + +def test_simple_kg_pipeline_config_run_params_both_file_and_text() -> None: + config = SimpleKGPipelineConfig(from_pdf=False) + with pytest.raises(PipelineDefinitionError) as excinfo: + config.get_run_params({"text": "my text", "file_path": "my file"}) + assert ( + "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument." + in str(excinfo) + ) diff --git a/tests/unit/experimental/pipeline/config/test_base.py b/tests/unit/experimental/pipeline/config/test_base.py new file mode 100644 index 00000000..20332a79 --- /dev/null +++ b/tests/unit/experimental/pipeline/config/test_base.py @@ -0,0 +1,37 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import patch + +from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig +from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( + ParamToResolveConfig, +) + + +def test_resolve_param_with_param_to_resolve_object() -> None: + c = AbstractConfig() + with patch( + "neo4j_graphrag.experimental.pipeline.config.param_resolver.ParamToResolveConfig", + spec=ParamToResolveConfig, + ) as mock_param_class: + mock_param = mock_param_class.return_value + mock_param.resolve.return_value = 1 + assert c.resolve_param(mock_param) == 1 + mock_param.resolve.assert_called_once_with({}) + + +def test_resolve_param_with_other_object() -> None: + c = AbstractConfig() + assert c.resolve_param("value") == "value" diff --git a/tests/unit/experimental/pipeline/config/test_object_config.py b/tests/unit/experimental/pipeline/config/test_object_config.py new file mode 100644 index 00000000..baa3c767 --- /dev/null +++ b/tests/unit/experimental/pipeline/config/test_object_config.py @@ -0,0 +1,164 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import patch + +import neo4j +import pytest +from neo4j_graphrag.embeddings import Embedder, OpenAIEmbeddings +from neo4j_graphrag.experimental.pipeline import Pipeline +from neo4j_graphrag.experimental.pipeline.config.object_config import ( + EmbedderConfig, + EmbedderType, + LLMConfig, + LLMType, + Neo4jDriverConfig, + Neo4jDriverType, + ObjectConfig, +) +from neo4j_graphrag.llm import LLMInterface, OpenAILLM + + +def test_get_class_no_optional_module() -> None: + c: ObjectConfig[object] = ObjectConfig() + klass = c._get_class("neo4j_graphrag.experimental.pipeline.Pipeline") + assert klass == Pipeline + + +def test_get_class_optional_module() -> None: + c: ObjectConfig[object] = ObjectConfig() + klass = c._get_class( + "Pipeline", optional_module="neo4j_graphrag.experimental.pipeline" + ) + assert klass == Pipeline + + +def test_get_class_path_and_optional_module() -> None: + c: ObjectConfig[object] = ObjectConfig() + klass = c._get_class( + "pipeline.Pipeline", optional_module="neo4j_graphrag.experimental" + ) + assert klass == Pipeline + + +def test_get_class_wrong_path() -> None: + c: ObjectConfig[object] = ObjectConfig() + with pytest.raises(ValueError): + c._get_class("MyClass") + + +def test_neo4j_driver_config() -> None: + config = Neo4jDriverConfig.model_validate( + { + "params_": { + "uri": "bolt://", + "user": "a user", + "password": "a password", + } + } + ) + assert config.class_ == "not used" + assert config.params_ == { + "uri": "bolt://", + "user": "a user", + "password": "a password", + } + with patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.neo4j.GraphDatabase.driver" + ) as driver_mock: + driver_mock.return_value = "a driver" + d = config.parse() + driver_mock.assert_called_once_with("bolt://", auth=("a user", "a password")) + assert d == "a driver" # type: ignore + + +def test_neo4j_driver_type_with_driver(driver: neo4j.Driver) -> None: + driver_type = Neo4jDriverType(driver) + assert driver_type.parse() == driver + + +def test_neo4j_driver_type_with_config() -> None: + driver_type = Neo4jDriverType( + Neo4jDriverConfig( + params_={ + "uri": "bolt://", + "user": "", + "password": "", + } + ) + ) + driver = driver_type.parse() + assert isinstance(driver, neo4j.Driver) + + +def test_llm_config() -> None: + config = LLMConfig.model_validate( + { + "class_": "OpenAILLM", + "params_": {"model_name": "gpt-4o", "api_key": "my-api-key"}, + } + ) + assert config.class_ == "OpenAILLM" + assert config.get_module() == "neo4j_graphrag.llm" + assert config.get_interface() == LLMInterface + assert config.params_ == {"model_name": "gpt-4o", "api_key": "my-api-key"} + d = config.parse() + assert isinstance(d, OpenAILLM) + + +def test_llm_type_with_driver(llm: LLMInterface) -> None: + llm_type = LLMType(llm) + assert llm_type.parse() == llm + + +def test_llm_type_with_config() -> None: + llm_type = LLMType( + LLMConfig( + class_="OpenAILLM", + params_={"model_name": "gpt-4o", "api_key": "my-api-key"}, + ) + ) + llm = llm_type.parse() + assert isinstance(llm, OpenAILLM) + + +def test_embedder_config() -> None: + config = EmbedderConfig.model_validate( + { + "class_": "OpenAIEmbeddings", + "params_": {"api_key": "my-api-key"}, + } + ) + assert config.class_ == "OpenAIEmbeddings" + assert config.get_module() == "neo4j_graphrag.embeddings" + assert config.get_interface() == Embedder + assert config.params_ == {"api_key": "my-api-key"} + d = config.parse() + assert isinstance(d, OpenAIEmbeddings) + + +def test_embedder_type_with_embedder(embedder: Embedder) -> None: + embedder_type = EmbedderType(embedder) + assert embedder_type.parse() == embedder + + +def test_embedder_type_with_config() -> None: + embedder_type = EmbedderType( + EmbedderConfig( + class_="OpenAIEmbeddings", + params_={"api_key": "my-api-key"}, + ) + ) + embedder = embedder_type.parse() + assert isinstance(embedder, OpenAIEmbeddings) diff --git a/tests/unit/experimental/pipeline/config/test_param_resolver.py b/tests/unit/experimental/pipeline/config/test_param_resolver.py new file mode 100644 index 00000000..efd4d7e9 --- /dev/null +++ b/tests/unit/experimental/pipeline/config/test_param_resolver.py @@ -0,0 +1,56 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest.mock import patch + +import pytest +from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( + ParamFromEnvConfig, + ParamFromKeyConfig, +) + + +@patch.dict(os.environ, {"MY_KEY": "my_value"}, clear=True) +def test_env_param_config_happy_path() -> None: + resolver = ParamFromEnvConfig(var_="MY_KEY") + assert resolver.resolve({}) == "my_value" + + +@patch.dict(os.environ, {}, clear=True) +def test_env_param_config_missing_env_var() -> None: + resolver = ParamFromEnvConfig(var_="MY_KEY") + assert resolver.resolve({}) is None + + +def test_config_key_param_simple_key() -> None: + resolver = ParamFromKeyConfig(key_="my_key") + assert resolver.resolve({"my_key": "my_value"}) == "my_value" + + +def test_config_key_param_missing_key() -> None: + resolver = ParamFromKeyConfig(key_="my_key") + with pytest.raises(KeyError): + resolver.resolve({}) + + +def test_config_complex_key_param() -> None: + resolver = ParamFromKeyConfig(key_="my_key.my_sub_key") + assert resolver.resolve({"my_key": {"my_sub_key": "value"}}) == "value" + + +def test_config_complex_key_param_missing_subkey() -> None: + resolver = ParamFromKeyConfig(key_="my_key.my_sub_key") + with pytest.raises(KeyError): + resolver.resolve({"my_key": {}}) diff --git a/tests/unit/experimental/pipeline/config/test_pipeline_config.py b/tests/unit/experimental/pipeline/config/test_pipeline_config.py new file mode 100644 index 00000000..4de5874b --- /dev/null +++ b/tests/unit/experimental/pipeline/config/test_pipeline_config.py @@ -0,0 +1,378 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock, patch + +import neo4j +from neo4j_graphrag.embeddings import Embedder +from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.experimental.pipeline.config.object_config import ( + ComponentConfig, + ComponentType, + Neo4jDriverConfig, + Neo4jDriverType, +) +from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( + ParamFromEnvConfig, + ParamFromKeyConfig, +) +from neo4j_graphrag.experimental.pipeline.config.pipeline_config import ( + AbstractPipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition +from neo4j_graphrag.llm import LLMInterface + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse" +) +def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_params_( + mock_neo4j_config: Mock, +) -> None: + mock_neo4j_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + { + "neo4j_config": { + "params_": { + "uri": "bolt://", + "user": "", + "password": "", + } + } + } + ) + assert isinstance(config.neo4j_config, dict) + assert "default" in config.neo4j_config + config.parse() + mock_neo4j_config.assert_called_once() + assert config._global_data["neo4j_config"]["default"] == "text" + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse" +) +def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_names( + mock_neo4j_config: Mock, +) -> None: + mock_neo4j_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + { + "neo4j_config": { + "my_driver": { + "params_": { + "uri": "bolt://", + "user": "", + "password": "", + } + } + } + } + ) + assert isinstance(config.neo4j_config, dict) + assert "my_driver" in config.neo4j_config + config.parse() + mock_neo4j_config.assert_called_once() + assert config._global_data["neo4j_config"]["my_driver"] == "text" + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse" +) +def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_driver( + mock_neo4j_config: Mock, driver: neo4j.Driver +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "neo4j_config": { + "my_driver": driver, + } + } + ) + assert isinstance(config.neo4j_config, dict) + assert "my_driver" in config.neo4j_config + config.parse() + assert not mock_neo4j_config.called + assert config._global_data["neo4j_config"]["my_driver"] == driver + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse" +) +def test_abstract_pipeline_config_neo4j_config_is_a_driver( + mock_neo4j_config: Mock, driver: neo4j.Driver +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "neo4j_config": driver, + } + ) + assert isinstance(config.neo4j_config, dict) + assert "default" in config.neo4j_config + config.parse() + assert not mock_neo4j_config.called + assert config._global_data["neo4j_config"]["default"] == driver + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse") +def test_abstract_pipeline_config_llm_config_is_a_dict_with_params_( + mock_llm_config: Mock, +) -> None: + mock_llm_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + {"llm_config": {"class_": "OpenAILLM", "params_": {"model_name": "gpt-4o"}}} + ) + assert isinstance(config.llm_config, dict) + assert "default" in config.llm_config + config.parse() + mock_llm_config.assert_called_once() + assert config._global_data["llm_config"]["default"] == "text" + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse") +def test_abstract_pipeline_config_llm_config_is_a_dict_with_names( + mock_llm_config: Mock, +) -> None: + mock_llm_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + { + "llm_config": { + "my_llm": {"class_": "OpenAILLM", "params_": {"model_name": "gpt-4o"}} + } + } + ) + assert isinstance(config.llm_config, dict) + assert "my_llm" in config.llm_config + config.parse() + mock_llm_config.assert_called_once() + assert config._global_data["llm_config"]["my_llm"] == "text" + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse") +def test_abstract_pipeline_config_llm_config_is_a_dict_with_llm( + mock_llm_config: Mock, llm: LLMInterface +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "llm_config": { + "my_llm": llm, + } + } + ) + assert isinstance(config.llm_config, dict) + assert "my_llm" in config.llm_config + config.parse() + assert not mock_llm_config.called + assert config._global_data["llm_config"]["my_llm"] == llm + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse") +def test_abstract_pipeline_config_llm_config_is_a_llm( + mock_llm_config: Mock, llm: LLMInterface +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "llm_config": llm, + } + ) + assert isinstance(config.llm_config, dict) + assert "default" in config.llm_config + config.parse() + assert not mock_llm_config.called + assert config._global_data["llm_config"]["default"] == llm + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse") +def test_abstract_pipeline_config_embedder_config_is_a_dict_with_params_( + mock_embedder_config: Mock, +) -> None: + mock_embedder_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + {"embedder_config": {"class_": "OpenAIEmbeddings", "params_": {}}} + ) + assert isinstance(config.embedder_config, dict) + assert "default" in config.embedder_config + config.parse() + mock_embedder_config.assert_called_once() + assert config._global_data["embedder_config"]["default"] == "text" + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse") +def test_abstract_pipeline_config_embedder_config_is_a_dict_with_names( + mock_embedder_config: Mock, +) -> None: + mock_embedder_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + { + "embedder_config": { + "my_embedder": {"class_": "OpenAIEmbeddings", "params_": {}} + } + } + ) + assert isinstance(config.embedder_config, dict) + assert "my_embedder" in config.embedder_config + config.parse() + mock_embedder_config.assert_called_once() + assert config._global_data["embedder_config"]["my_embedder"] == "text" + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse") +def test_abstract_pipeline_config_embedder_config_is_a_dict_with_llm( + mock_embedder_config: Mock, embedder: Embedder +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "embedder_config": { + "my_embedder": embedder, + } + } + ) + assert isinstance(config.embedder_config, dict) + assert "my_embedder" in config.embedder_config + config.parse() + assert not mock_embedder_config.called + assert config._global_data["embedder_config"]["my_embedder"] == embedder + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse") +def test_abstract_pipeline_config_embedder_config_is_an_embedder( + mock_embedder_config: Mock, embedder: Embedder +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "embedder_config": embedder, + } + ) + assert isinstance(config.embedder_config, dict) + assert "default" in config.embedder_config + config.parse() + assert not mock_embedder_config.called + assert config._global_data["embedder_config"]["default"] == embedder + + +def test_abstract_pipeline_config_parse_global_data_no_extras(driver: Mock) -> None: + config = AbstractPipelineConfig( + neo4j_config={"my_driver": Neo4jDriverType(driver)}, + ) + gd = config._parse_global_data() + assert gd == { + "extras": {}, + "neo4j_config": { + "my_driver": driver, + }, + "llm_config": {}, + "embedder_config": {}, + } + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.param_resolver.ParamFromEnvConfig.resolve" +) +def test_abstract_pipeline_config_parse_global_data_extras( + mock_param_resolver: Mock, +) -> None: + mock_param_resolver.return_value = "my value" + config = AbstractPipelineConfig( + extras={"my_extra_var": ParamFromEnvConfig(var_="some key")}, + ) + gd = config._parse_global_data() + assert gd == { + "extras": {"my_extra_var": "my value"}, + "neo4j_config": {}, + "llm_config": {}, + "embedder_config": {}, + } + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.param_resolver.ParamFromEnvConfig.resolve" +) +@patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverType.parse" +) +def test_abstract_pipeline_config_parse_global_data_use_extras_in_other_config( + mock_neo4j_parser: Mock, + mock_param_resolver: Mock, +) -> None: + """Parser is able to read variables in the 'extras' section of config + to instantiate another object (neo4j.Driver in this test case) + """ + mock_param_resolver.side_effect = ["bolt://myhost", "myuser", "mypwd"] + mock_neo4j_parser.return_value = "my driver" + config = AbstractPipelineConfig( + extras={ + "my_extra_uri": ParamFromEnvConfig(var_="some key"), + "my_extra_user": ParamFromEnvConfig(var_="some key"), + "my_extra_pwd": ParamFromEnvConfig(var_="some key"), + }, + neo4j_config={ + "my_driver": Neo4jDriverType( + Neo4jDriverConfig( + params_=dict( + uri=ParamFromKeyConfig(key_="extras.my_extra_uri"), + user=ParamFromKeyConfig(key_="extras.my_extra_user"), + password=ParamFromKeyConfig(key_="extras.my_extra_pwd"), + ) + ) + ) + }, + ) + gd = config._parse_global_data() + expected_extras = { + "my_extra_uri": "bolt://myhost", + "my_extra_user": "myuser", + "my_extra_pwd": "mypwd", + } + assert gd["extras"] == expected_extras + assert gd["neo4j_config"] == {"my_driver": "my driver"} + mock_neo4j_parser.assert_called_once_with({"extras": expected_extras}) + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse") +def test_abstract_pipeline_config_resolve_component_definition_no_run_params( + mock_component_parse: Mock, + component: Component, +) -> None: + mock_component_parse.return_value = component + config = AbstractPipelineConfig() + component_type = ComponentType(component) + component_definition = config._resolve_component_definition("name", component_type) + assert isinstance(component_definition, ComponentDefinition) + mock_component_parse.assert_called_once_with({}) + assert component_definition.name == "name" + assert component_definition.component == component + assert component_definition.run_params == {} + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.pipeline_config.AbstractPipelineConfig.resolve_params" +) +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse") +def test_abstract_pipeline_config_resolve_component_definition_with_run_params( + mock_component_parse: Mock, + mock_resolve_params: Mock, + component: Component, +) -> None: + mock_component_parse.return_value = component + mock_resolve_params.return_value = {"param": "resolver param result"} + config = AbstractPipelineConfig() + component_type = ComponentType( + ComponentConfig(class_="", params_={}, run_params_={"param1": "value1"}) + ) + component_definition = config._resolve_component_definition("name", component_type) + assert isinstance(component_definition, ComponentDefinition) + mock_component_parse.assert_called_once_with({}) + assert component_definition.name == "name" + assert component_definition.component == component + assert component_definition.run_params == {"param": "resolver param result"} + mock_resolve_params.assert_called_once_with({"param1": "value1"}) diff --git a/tests/unit/experimental/pipeline/config/test_runner.py b/tests/unit/experimental/pipeline/config/test_runner.py new file mode 100644 index 00000000..327b5221 --- /dev/null +++ b/tests/unit/experimental/pipeline/config/test_runner.py @@ -0,0 +1,56 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock, patch + +from neo4j_graphrag.experimental.pipeline import Pipeline +from neo4j_graphrag.experimental.pipeline.config.pipeline_config import PipelineConfig +from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner +from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition + + +@patch("neo4j_graphrag.experimental.pipeline.pipeline.Pipeline.from_definition") +def test_pipeline_runner_from_def_empty(mock_from_definition: Mock) -> None: + mock_from_definition.return_value = Pipeline() + runner = PipelineRunner( + pipeline_definition=PipelineDefinition(components=[], connections=[]) + ) + assert runner.config is None + assert runner.pipeline is not None + assert runner.pipeline._nodes == {} + assert runner.pipeline._edges == [] + assert runner.run_params == {} + mock_from_definition.assert_called_once() + + +def test_pipeline_runner_from_config() -> None: + config = PipelineConfig(component_config={}, connection_config=[]) + runner = PipelineRunner.from_config(config) + assert runner.config is not None + assert runner.pipeline is not None + assert runner.pipeline._nodes == {} + assert runner.pipeline._edges == [] + assert runner.run_params == {} + + +@patch("neo4j_graphrag.experimental.pipeline.config.runner.PipelineRunner.from_config") +@patch("neo4j_graphrag.experimental.pipeline.config.config_reader.ConfigReader.read") +def test_pipeline_runner_from_config_file( + mock_read: Mock, mock_from_config: Mock +) -> None: + mock_read.return_value = {"dict": "with data"} + PipelineRunner.from_config_file("file.yaml") + + mock_read.assert_called_once_with("file.yaml") + mock_from_config.assert_called_once_with({"dict": "with data"}, do_cleaning=True) diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py index b1b29151..bddbac50 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_kg_builder.py @@ -18,10 +18,8 @@ import neo4j import pytest from neo4j_graphrag.embeddings import Embedder -from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError from neo4j_graphrag.experimental.components.schema import ( SchemaEntity, - SchemaProperty, SchemaRelation, ) from neo4j_graphrag.experimental.components.types import LexicalGraphConfig @@ -31,136 +29,6 @@ from neo4j_graphrag.llm.base import LLMInterface -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", - return_value=(5, 23, 0), -) -@pytest.mark.asyncio -async def test_knowledge_graph_builder_init_with_text(_: Mock) -> None: - llm = MagicMock(spec=LLMInterface) - driver = MagicMock(spec=neo4j.Driver) - embedder = MagicMock(spec=Embedder) - - kg_builder = SimpleKGPipeline( - llm=llm, - driver=driver, - embedder=embedder, - from_pdf=False, - ) - - assert kg_builder.llm == llm - assert kg_builder.driver == driver - assert kg_builder.embedder == embedder - assert kg_builder.from_pdf is False - assert kg_builder.entities == [] - assert kg_builder.relations == [] - assert kg_builder.potential_schema == [] - assert "pdf_loader" not in kg_builder.pipeline - - text_input = "May thy knife chip and shatter." - - with patch.object( - kg_builder.pipeline, - "run", - return_value=PipelineResult(run_id="test_run", result=None), - ) as mock_run: - await kg_builder.run_async(text=text_input) - mock_run.assert_called_once() - pipe_inputs = mock_run.call_args[0][0] - assert pipe_inputs["splitter"]["text"] == text_input - - -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", - return_value=(5, 23, 0), -) -@pytest.mark.asyncio -async def test_knowledge_graph_builder_init_with_file_path(_: Mock) -> None: - llm = MagicMock(spec=LLMInterface) - driver = MagicMock(spec=neo4j.Driver) - embedder = MagicMock(spec=Embedder) - - kg_builder = SimpleKGPipeline( - llm=llm, - driver=driver, - embedder=embedder, - from_pdf=True, - ) - - assert kg_builder.llm == llm - assert kg_builder.driver == driver - assert kg_builder.from_pdf is True - assert kg_builder.entities == [] - assert kg_builder.relations == [] - assert kg_builder.potential_schema == [] - assert "pdf_loader" in kg_builder.pipeline - - file_path = "path/to/test.pdf" - - with patch.object( - kg_builder.pipeline, - "run", - return_value=PipelineResult(run_id="test_run", result=None), - ) as mock_run: - await kg_builder.run_async(file_path=file_path) - mock_run.assert_called_once() - pipe_inputs = mock_run.call_args[0][0] - assert pipe_inputs["pdf_loader"]["filepath"] == file_path - - -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", - return_value=(5, 23, 0), -) -@pytest.mark.asyncio -async def test_knowledge_graph_builder_run_with_both_inputs(_: Mock) -> None: - llm = MagicMock(spec=LLMInterface) - driver = MagicMock(spec=neo4j.Driver) - embedder = MagicMock(spec=Embedder) - - kg_builder = SimpleKGPipeline( - llm=llm, - driver=driver, - embedder=embedder, - from_pdf=True, - ) - - text_input = "May thy knife chip and shatter." - file_path = "path/to/test.pdf" - - with pytest.raises(PipelineDefinitionError) as exc_info: - await kg_builder.run_async(file_path=file_path, text=text_input) - - assert "Expected 'file_path' argument when 'from_pdf' is True." in str( - exc_info.value - ) or "Expected 'text' argument when 'from_pdf' is False." in str(exc_info.value) - - -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", - return_value=(5, 23, 0), -) -@pytest.mark.asyncio -async def test_knowledge_graph_builder_run_with_no_inputs(_: Mock) -> None: - llm = MagicMock(spec=LLMInterface) - driver = MagicMock(spec=neo4j.Driver) - embedder = MagicMock(spec=Embedder) - - kg_builder = SimpleKGPipeline( - llm=llm, - driver=driver, - embedder=embedder, - from_pdf=True, - ) - - with pytest.raises(PipelineDefinitionError) as exc_info: - await kg_builder.run_async() - - assert "Expected 'file_path' argument when 'from_pdf' is True." in str( - exc_info.value - ) or "Expected 'text' argument when 'from_pdf' is False." in str(exc_info.value) - - @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", return_value=(5, 23, 0), @@ -181,13 +49,13 @@ async def test_knowledge_graph_builder_document_info_with_file(_: Mock) -> None: file_path = "path/to/test.pdf" with patch.object( - kg_builder.pipeline, + kg_builder.runner.pipeline, "run", return_value=PipelineResult(run_id="test_run", result=None), ) as mock_run: await kg_builder.run_async(file_path=file_path) - pipe_inputs = mock_run.call_args[0][0] + pipe_inputs = mock_run.call_args[1]["data"] assert "pdf_loader" in pipe_inputs assert pipe_inputs["pdf_loader"] == {"filepath": file_path} assert "extractor" not in pipe_inputs @@ -213,13 +81,13 @@ async def test_knowledge_graph_builder_document_info_with_text(_: Mock) -> None: text_input = "May thy knife chip and shatter." with patch.object( - kg_builder.pipeline, + kg_builder.runner.pipeline, "run", return_value=PipelineResult(run_id="test_run", result=None), ) as mock_run: await kg_builder.run_async(text=text_input) - pipe_inputs = mock_run.call_args[0][0] + pipe_inputs = mock_run.call_args[1]["data"] assert "splitter" in pipe_inputs assert pipe_inputs["splitter"] == {"text": text_input} @@ -248,51 +116,33 @@ async def test_knowledge_graph_builder_with_entities_and_file(_: Mock) -> None: from_pdf=True, ) - internal_entities = [SchemaEntity(label=label) for label in entities] - internal_relations = [SchemaRelation(label=label) for label in relations] - assert kg_builder.entities == internal_entities - assert kg_builder.relations == internal_relations - assert kg_builder.potential_schema == potential_schema + # assert kg_builder.entities == entities + # assert kg_builder.relations == relations + # assert kg_builder.potential_schema == potential_schema file_path = "path/to/test.pdf" + internal_entities = [SchemaEntity(label=label) for label in entities] + internal_relations = [SchemaRelation(label=label) for label in relations] + with patch.object( - kg_builder.pipeline, + kg_builder.runner.pipeline, "run", return_value=PipelineResult(run_id="test_run", result=None), ) as mock_run: await kg_builder.run_async(file_path=file_path) - pipe_inputs = mock_run.call_args[0][0] + pipe_inputs = mock_run.call_args[1]["data"] assert pipe_inputs["schema"]["entities"] == internal_entities assert pipe_inputs["schema"]["relations"] == internal_relations assert pipe_inputs["schema"]["potential_schema"] == potential_schema -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", - return_value=(5, 23, 0), -) -def test_simple_kg_pipeline_on_error_conversion(_: Mock) -> None: - llm = MagicMock(spec=LLMInterface) - driver = MagicMock(spec=neo4j.Driver) - embedder = MagicMock(spec=Embedder) - - kg_builder = SimpleKGPipeline( - llm=llm, - driver=driver, - embedder=embedder, - on_error="RAISE", - ) - - assert kg_builder.on_error == OnError.RAISE - - def test_simple_kg_pipeline_on_error_invalid_value() -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) - with pytest.raises(PipelineDefinitionError) as exc_info: + with pytest.raises(PipelineDefinitionError): SimpleKGPipeline( llm=llm, driver=driver, @@ -300,50 +150,6 @@ def test_simple_kg_pipeline_on_error_invalid_value() -> None: on_error="INVALID_VALUE", ) - assert "Expected one of ['RAISE', 'IGNORE']" in str(exc_info.value) - - -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", - return_value=(5, 23, 0), -) -def test_simple_kg_pipeline_no_entity_resolution(_: Mock) -> None: - llm = MagicMock(spec=LLMInterface) - driver = MagicMock(spec=neo4j.Driver) - embedder = MagicMock(spec=Embedder) - - kg_builder = SimpleKGPipeline( - llm=llm, - driver=driver, - embedder=embedder, - on_error="IGNORE", - perform_entity_resolution=False, - ) - - assert "resolver" not in kg_builder.pipeline - - -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", - return_value=(5, 23, 0), -) -@pytest.mark.asyncio -def test_simple_kg_pipeline_lexical_graph_config_attribute(_: Mock) -> None: - llm = MagicMock(spec=LLMInterface) - driver = MagicMock(spec=neo4j.Driver) - embedder = MagicMock(spec=Embedder) - - lexical_graph_config = LexicalGraphConfig() - kg_builder = SimpleKGPipeline( - llm=llm, - driver=driver, - embedder=embedder, - on_error="IGNORE", - lexical_graph_config=lexical_graph_config, - ) - - assert kg_builder.lexical_graph_config == lexical_graph_config - @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", @@ -372,58 +178,14 @@ async def test_knowledge_graph_builder_with_lexical_graph_config(_: Mock) -> Non text_input = "May thy knife chip and shatter." with patch.object( - kg_builder.pipeline, + kg_builder.runner.pipeline, "run", return_value=PipelineResult(run_id="test_run", result=None), ) as mock_run: await kg_builder.run_async(text=text_input) - pipe_inputs = mock_run.call_args[0][0] + pipe_inputs = mock_run.call_args[1]["data"] assert "extractor" in pipe_inputs assert pipe_inputs["extractor"] == { "lexical_graph_config": lexical_graph_config } - - -def test_knowledge_graph_builder_to_schema_entity_method() -> None: - assert SimpleKGPipeline.to_schema_entity("EntityType") == SchemaEntity( - label="EntityType" - ) - assert SimpleKGPipeline.to_schema_entity({"label": "EntityType"}) == SchemaEntity( - label="EntityType" - ) - assert SimpleKGPipeline.to_schema_entity( - {"label": "EntityType", "description": "A special entity"} - ) == SchemaEntity(label="EntityType", description="A special entity") - assert SimpleKGPipeline.to_schema_entity( - {"label": "EntityType", "properties": []} - ) == SchemaEntity(label="EntityType") - assert SimpleKGPipeline.to_schema_entity( - { - "label": "EntityType", - "properties": [{"name": "entityProperty", "type": "DATE"}], - } - ) == SchemaEntity( - label="EntityType", - properties=[SchemaProperty(name="entityProperty", type="DATE")], - ) - - -def test_knowledge_graph_builder_to_schema_relation_method() -> None: - assert SimpleKGPipeline.to_schema_relation("REL_TYPE") == SchemaRelation( - label="REL_TYPE" - ) - assert SimpleKGPipeline.to_schema_relation({"label": "REL_TYPE"}) == SchemaRelation( - label="REL_TYPE" - ) - assert SimpleKGPipeline.to_schema_relation( - {"label": "REL_TYPE", "description": "A rel type"} - ) == SchemaRelation(label="REL_TYPE", description="A rel type") - assert SimpleKGPipeline.to_schema_relation( - {"label": "REL_TYPE", "properties": []} - ) == SchemaRelation(label="REL_TYPE") - assert SimpleKGPipeline.to_schema_relation( - {"label": "REL_TYPE", "properties": [{"name": "relProperty", "type": "DATE"}]} - ) == SchemaRelation( - label="REL_TYPE", properties=[SchemaProperty(name="relProperty", type="DATE")] - )