diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7254d570..2ac264f0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@
## Added
- Integrated json-repair package to handle and repair invalid JSON generated by LLMs.
- Introduced InvalidJSONError exception for handling cases where JSON repair fails.
+- Ability to create a Pipeline or SimpleKGPipeline from a config file. See [the example](examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py).
## Changed
- Updated LLM prompts to include stricter instructions for generating valid JSON.
diff --git a/docs/source/api.rst b/docs/source/api.rst
index d1b1066d..7d46b3f0 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -9,17 +9,18 @@ API Documentation
Components
**********
-KGWriter
-========
+DataLoader
+==========
-.. autoclass:: neo4j_graphrag.experimental.components.kg_writer.KGWriter
- :members: run
+.. autoclass:: neo4j_graphrag.experimental.components.pdf_loader.DataLoader
+ :members: run, get_document_metadata
-Neo4jWriter
-===========
-.. autoclass:: neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter
- :members: run
+PdfLoader
+=========
+
+.. autoclass:: neo4j_graphrag.experimental.components.pdf_loader.PdfLoader
+ :members: run, load_file
TextSplitter
============
@@ -85,6 +86,17 @@ LLMEntityRelationExtractor
.. autoclass:: neo4j_graphrag.experimental.components.entity_relation_extractor.LLMEntityRelationExtractor
:members: run
+KGWriter
+========
+
+.. autoclass:: neo4j_graphrag.experimental.components.kg_writer.KGWriter
+ :members: run
+
+Neo4jWriter
+===========
+
+.. autoclass:: neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter
+ :members: run
SinglePropertyExactMatchResolver
================================
@@ -112,6 +124,23 @@ SimpleKGPipeline
:members: run_async
+************
+Config files
+************
+
+
+SimpleKGPipelineConfig
+======================
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig
+
+
+PipelineRunner
+==============
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.config.runner.PipelineRunner
+
+
.. _retrievers-section:
**********
diff --git a/docs/source/images/kg_builder_pipeline.png b/docs/source/images/kg_builder_pipeline.png
index b8faf2ba..935c2a12 100644
Binary files a/docs/source/images/kg_builder_pipeline.png and b/docs/source/images/kg_builder_pipeline.png differ
diff --git a/docs/source/types.rst b/docs/source/types.rst
index 6322b29b..253994ad 100644
--- a/docs/source/types.rst
+++ b/docs/source/types.rst
@@ -82,3 +82,62 @@ SchemaConfig
============
.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaConfig
+
+LexicalGraphConfig
+===================
+
+.. autoclass:: neo4j_graphrag.experimental.components.types.LexicalGraphConfig
+
+
+Neo4jDriverType
+===============
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverType
+
+
+Neo4jDriverConfig
+=================
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig
+
+
+LLMType
+=======
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.LLMType
+
+
+LLMConfig
+=========
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig
+
+
+EmbedderType
+============
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderType
+
+
+EmbedderConfig
+==============
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig
+
+
+ComponentType
+=============
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType
+
+
+ComponentConfig
+===============
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.config.object_config.ComponentConfig
+
+
+ParamFromEnvConfig
+==================
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.config.param_resolver.ParamFromEnvConfig
diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst
index b758767a..87ce2f42 100644
--- a/docs/source/user_guide_kg_builder.rst
+++ b/docs/source/user_guide_kg_builder.rst
@@ -18,11 +18,11 @@ Pipeline Structure
A Knowledge Graph (KG) construction pipeline requires a few components (some of the below components are optional):
-- **Document parser**: extract text from files (PDFs, ...).
-- **Document chunker**: split the text into smaller pieces of text, manageable by the LLM context window (token limit).
+- **Data loader**: extract text from files (PDFs, ...).
+- **Text splitter**: split the text into smaller pieces of text (chunks), manageable by the LLM context window (token limit).
- **Chunk embedder** (optional): compute the chunk embeddings.
- **Schema builder**: provide a schema to ground the LLM extracted entities and relations and obtain an easily navigable KG.
-- **LexicalGraphBuilder**: build the lexical graph (Document, Chunk and their relationships) (optional).
+- **Lexical graph builder**: build the lexical graph (Document, Chunk and their relationships) (optional).
- **Entity and relation extractor**: extract relevant entities and relations from the text.
- **Knowledge Graph writer**: save the identified entities and relations.
- **Entity resolver**: merge similar entities into a single node.
@@ -34,7 +34,486 @@ A Knowledge Graph (KG) construction pipeline requires a few components (some of
This package contains the interface and implementations for each of these components, which are detailed in the following sections.
To see an end-to-end example of a Knowledge Graph construction pipeline,
-refer to the `example folder `_ in the project GitHub repository.
+visit the `example folder `_
+in the project's GitHub repository.
+
+
+******************
+Simple KG Pipeline
+******************
+
+The simplest way to begin building a KG from unstructured data using this package
+is to use the `SimpleKGPipeline` interface:
+
+.. code:: python
+
+ from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
+
+ kg_builder = SimpleKGPipeline(
+ llm=llm, # an LLMInterface for Entity and Relation extraction
+ driver=neo4j_driver, # a neo4j driver to write results to graph
+ embedder=embedder, # an Embedder for chunks
+ from_pdf=True, # set to False if parsing an already extracted text
+ )
+ await kg_builder.run_async(file_path=str(file_path))
+ # await kg_builder.run_async(text="my text") # if using from_pdf=False
+
+
+See:
+
+- :ref:`Using Another LLM Model` to learn how to instantiate the `llm`
+- :ref:`Embedders` to learn how to instantiate the `embedder`
+
+
+The following section outlines the configuration parameters for this class.
+
+Customizing the SimpleKGPipeline
+================================
+
+Graph Schema
+------------
+
+It is possible to guide the LLM by supplying a list of entities, relationships,
+and instructions on how to connect them. However, note that the extracted graph
+may not fully adhere to these guidelines. Entities and relationships can be
+represented as either simple strings (for their labels) or dictionaries. If using
+a dictionary, it must include a `label` key and can optionally include `description`
+and `properties` keys, as shown below:
+
+.. code:: python
+
+ ENTITIES = [
+ # entities can be defined with a simple label...
+ "Person",
+ # ... or with a dict if more details are needed,
+ # such as a description:
+ {"label": "House", "description": "Family the person belongs to"},
+ # or a list of properties the LLM will try to attach to the entity:
+ {"label": "Planet", "properties": [{"name": "weather", "type": "STRING"}]},
+ ]
+ # same thing for relationships:
+ RELATIONS = [
+ "PARENT_OF",
+ {
+ "label": "HEIR_OF",
+ "description": "Used for inheritor relationship between father and sons",
+ },
+ {"label": "RULES", "properties": [{"name": "fromYear", "type": "INTEGER"}]},
+ ]
+
+The `potential_schema` is defined by a list of triplets in the format:
+`(source_node_label, relationship_label, target_node_label)`. For instance:
+
+
+.. code:: python
+
+ POTENTIAL_SCHEMA = [
+ ("Person", "PARENT_OF", "Person"),
+ ("Person", "HEIR_OF", "House"),
+ ("House", "RULES", "Planet"),
+ ]
+
+This schema information can be provided to the `SimpleKGBuilder` as demonstrated below:
+
+.. code:: python
+
+ kg_builder = SimpleKGPipeline(
+ # ...
+ entities=ENTITIES,
+ relations=RELATIONS,
+ potential_schema=POTENTIAL_SCHEMA,
+ # ...
+ )
+
+Prompt Template, Lexical Graph Config and Error Behavior
+--------------------------------------------------------
+
+These parameters are part of the `LLMEntityRelationExtractor` component.
+For detailed information, refer to the section on :ref:`Entity and Relation Extractor`.
+They are also accessible via the `SimpleKGPipeline` interface.
+
+.. code:: python
+
+ kg_builder = SimpleKGPipeline(
+ # ...
+ prompt_template="",
+ lexical_graph_config=my_config,
+ on_error="RAISE",
+ # ...
+ )
+
+Skip Entity Resolution
+----------------------
+
+By default, after each run, an Entity Resolution step is performed to merge nodes
+that share the same label and name property. To disable this behavior, adjust
+the following parameter:
+
+.. code:: python
+
+ kg_builder = SimpleKGPipeline(
+ # ...
+ perform_entity_resolution=False,
+ # ...
+ )
+
+Neo4j Database
+--------------
+
+To write to a non-default Neo4j database, specify the database name using this parameter:
+
+.. code:: python
+
+ kg_builder = SimpleKGPipeline(
+ # ...
+ neo4j_database="myDb",
+ # ...
+ )
+
+Using Custom Components
+-----------------------
+
+For advanced customization or when using a custom implementation, you can pass
+instances of specific components to the `SimpleKGPipeline`. The components that
+can be customized at the moment are:
+
+- `text_splitter`: must be an instance of :ref:`TextSplitter`
+- `pdf_loader`: must be an instance of :ref:`PdfLoader`
+- `kg_writer`: must be an instance of :ref:`KGWriter`
+
+For instance, the following code can be used to customize the chunk size and
+chunk overlap in the text splitter component:
+
+.. code:: python
+
+ from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
+ FixedSizeSplitter,
+ )
+
+ text_splitter = FixedSizeSplitter(chunk_size=500, chunk_overlap=100)
+
+ kg_builder = SimpleKGPipeline(
+ # ...
+ text_splitter=text_splitter,
+ # ...
+ )
+
+
+Using a Config File
+===================
+
+Pipelines can also be created from a configuration file:
+
+.. code:: python
+
+ from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner
+
+ file_path = "my_config.json"
+
+ pipeline = PipelineRunner.from_config_file(file_path)
+ await pipeline.run({"text": "my text"})
+
+
+The config file can be written in either JSON or YAML format.
+
+Here is an example of a base configuration file in JSON format:
+
+.. code:: json
+
+ {
+ "version_": 1,
+ "template_": "SimpleKGPipeline",
+ "neo4j_config": {},
+ "llm_config": {},
+ "embedder_config": {}
+ }
+
+And like this in YAML:
+
+.. code:: yaml
+
+ version_: 1
+ template_: SimpleKGPipeline
+ neo4j_config:
+ llm_config:
+ embedder_config:
+
+
+Defining a Neo4j Driver
+-----------------------
+
+Below is an example of configuring a Neo4j driver in a JSON configuration file:
+
+.. code:: json
+
+ {
+ "neo4j_config": {
+ "params_": {
+ "uri": "bolt://...",
+ "user": "neo4j",
+ "password": "password"
+ }
+ }
+ }
+
+Same for YAML:
+
+.. code:: yaml
+
+ neo4j_config:
+ params_:
+ uri: bolt://
+ user: neo4j
+ password: password
+
+In some cases, it may be necessary to avoid hard-coding sensitive values,
+such as passwords or API keys, to ensure security. To address this, the configuration
+parser supports parameter resolution methods.
+
+Parameter resolution
+--------------------
+
+To instruct the configuration parser to read a parameter from an environment variable,
+use the following syntax:
+
+.. code:: json
+
+ {
+ "neo4j_config": {
+ "params_": {
+ "uri": "bolt://...",
+ "user": "neo4j",
+ "password": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_PASSWORD"
+ }
+ }
+ }
+ }
+
+And for YAML:
+
+.. code:: yaml
+
+ neo4j_config:
+ params_:
+ uri: bolt://
+ user: neo4j
+ password:
+ resolver_: ENV
+ var_: NEO4J_PASSWORD
+
+- The `resolver_` key is mandatory and its value must be `ENV`.
+- The `var_` key specifies the name of the environment variable to be read.
+
+This syntax can be applied to all parameters.
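To illustrate the idea, here is a minimal sketch of how an `ENV` parameter resolver can work. The actual resolution is handled internally by the package (see `ParamFromEnvConfig`); the `resolve_param` helper below is a hypothetical illustration, not the package's API:

```python
import os
from typing import Any


def resolve_param(value: Any) -> Any:
    """Return a literal value unchanged, or read it from an environment
    variable when it is an {"resolver_": "ENV", "var_": ...} mapping."""
    if isinstance(value, dict) and value.get("resolver_") == "ENV":
        return os.environ[value["var_"]]
    return value


os.environ["NEO4J_PASSWORD"] = "secret"  # normally set outside the program
params = {
    "uri": "bolt://localhost:7687",
    "password": {"resolver_": "ENV", "var_": "NEO4J_PASSWORD"},
}
# literal values pass through, ENV mappings are replaced with the var's value
resolved = {key: resolve_param(value) for key, value in params.items()}
```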
+
+
+Defining an LLM
+----------------
+
+Below is an example of configuring an LLM in a JSON configuration file:
+
+.. code:: json
+
+ {
+ "llm_config": {
+ "class_": "OpenAILLM",
+ "params_": {
+        "model_name": "gpt-4o",
+ "api_key": {
+ "resolver_": "ENV",
+          "var_": "OPENAI_API_KEY"
+ },
+ "model_params": {
+ "temperature": 0,
+ "max_tokens": 2000,
+ "response_format": {"type": "json_object"}
+ }
+ }
+ }
+ }
+
+And the equivalent YAML:
+
+.. code:: yaml
+
+ llm_config:
+ class_: OpenAILLM
+ params_:
+ model_name: gpt-4o
+ api_key:
+ resolver_: ENV
+ var_: OPENAI_API_KEY
+ model_params:
+ temperature: 0
+ max_tokens: 2000
+ response_format:
+ type: json_object
+
+- The `class_` key specifies the path to the class to be instantiated.
+- The `params_` key contains the parameters to be passed to the class constructor.
+
+When using an LLM implementation provided by this package, the full path in the `class_` key
+can be omitted (the parser will automatically import from `neo4j_graphrag.llm`).
+For custom implementations, the full path must be explicitly specified,
+for example: `my_package.my_llm.MyLLM`.
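This short-name fallback can be sketched with `importlib`. The parser's real logic lives inside `neo4j_graphrag`; the `resolve_class` helper below is only a hypothetical illustration of the idea, demonstrated with a standard-library class so it runs anywhere:

```python
import importlib


def resolve_class(path: str, default_module: str = "neo4j_graphrag.llm"):
    """Import and return the class named by `path`. A dotted path is used
    as-is; a bare class name is looked up in `default_module`."""
    if "." in path:
        # full path, e.g. "my_package.my_llm.MyLLM"
        module_name, class_name = path.rsplit(".", 1)
    else:
        # short name, resolved against the default module
        module_name, class_name = default_module, path
    return getattr(importlib.import_module(module_name), class_name)


# Demo with a dotted standard-library path:
cls = resolve_class("collections.OrderedDict")
```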
+
+Defining an Embedder
+--------------------
+
+The same principles apply to `embedder_config`:
+
+.. code:: json
+
+ {
+ "embedder_config": {
+ "class_": "OpenAIEmbeddings",
+ "params_": {
+        "model": "text-embedding-ada-002",
+ "api_key": {
+ "resolver_": "ENV",
+          "var_": "OPENAI_API_KEY"
+ }
+ }
+ }
+ }
+
+Or the YAML version:
+
+.. code:: yaml
+
+ embedder_config:
+ class_: OpenAIEmbeddings
+ params_:
+ api_key:
+ resolver_: ENV
+ var_: OPENAI_API_KEY
+
+- For embedder implementations from this package, the full path can be omitted in the `class_` key (the parser will import from `neo4j_graphrag.embeddings`).
+- For custom implementations, the full path must be provided, for example: `my_package.my_embedding.MyEmbedding`.
+
+
+Other configuration
+-------------------
+
+The other parameters exposed in the :ref:`SimpleKGPipeline` can also be configured
+within the configuration file.
+
+.. code:: json
+
+ {
+ "from_pdf": false,
+ "perform_entity_resolution": true,
+ "neo4j_database": "myDb",
+ "on_error": "IGNORE",
+ "prompt_template": "...",
+ "entities": [
+ "Person",
+ {
+ "label": "House",
+ "description": "Family the person belongs to",
+ "properties": [
+ {"name": "name", "type": "STRING"}
+ ]
+ },
+ {
+ "label": "Planet",
+ "properties": [
+ {"name": "name", "type": "STRING"},
+ {"name": "weather", "type": "STRING"}
+ ]
+ }
+ ],
+ "relations": [
+ "PARENT_OF",
+ {
+ "label": "HEIR_OF",
+ "description": "Used for inheritor relationship between father and sons"
+ },
+ {
+ "label": "RULES",
+ "properties": [
+ {"name": "fromYear", "type": "INTEGER"}
+ ]
+ }
+ ],
+ "potential_schema": [
+ ["Person", "PARENT_OF", "Person"],
+ ["Person", "HEIR_OF", "House"],
+ ["House", "RULES", "Planet"]
+ ],
+ "lexical_graph_config": {
+ "chunk_node_label": "TextPart"
+ }
+ }
+
+
+or in YAML:
+
+.. code:: yaml
+
+ from_pdf: false
+ perform_entity_resolution: true
+ neo4j_database: myDb
+ on_error: IGNORE
+ prompt_template: ...
+ entities:
+ - label: Person
+ - label: House
+ description: Family the person belongs to
+ properties:
+ - name: name
+ type: STRING
+ - label: Planet
+ properties:
+ - name: name
+ type: STRING
+ - name: weather
+ type: STRING
+ relations:
+ - label: PARENT_OF
+ - label: HEIR_OF
+ description: Used for inheritor relationship between father and sons
+ - label: RULES
+ properties:
+ - name: fromYear
+ type: INTEGER
+ potential_schema:
+ - ["Person", "PARENT_OF", "Person"]
+ - ["Person", "HEIR_OF", "House"]
+ - ["House", "RULES", "Planet"]
+ lexical_graph_config:
+ chunk_node_label: TextPart
+
+
+It is also possible to further customize components, with a syntax similar to the one
+used for `llm_config` or `embedder_config`:
+
+.. code:: json
+
+ {
+ "text_splitter": {
+        "class_": "text_splitters.fixed_size_splitter.FixedSizeSplitter",
+ "params_": {
+ "chunk_size": 500,
+ "chunk_overlap": 100
+ }
+ }
+}
+
+The YAML equivalent:
+
+.. code:: yaml
+
+ text_splitter:
+ class_: text_splitters.fixed_size_splitter.FixedSizeSplitter
+ params_:
+    chunk_size: 500
+    chunk_overlap: 100
+
+The `neo4j_graphrag.experimental.components` prefix is prepended automatically
+if needed.
+
**********************************
Knowledge Graph Builder Components
@@ -63,10 +542,10 @@ They can also be used within a pipeline:
pipeline.add_component(my_component, "component_name")
-Document Parser
-===============
+Data Loader
+============
-Document parsers start from a file path and return the text extracted from this file.
+Data loaders start from a file path and return the text extracted from this file.
This package currently supports text extraction from PDFs:
@@ -92,8 +571,8 @@ To implement your own loader, use the `DataLoader` interface:
-Document Splitter
-=================
+Text Splitter
+==============
Document splitters, as the name indicates, split documents into smaller chunks
that can be processed within the LLM token limits:
diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst
index d939ed85..2b2b528b 100644
--- a/docs/source/user_guide_rag.rst
+++ b/docs/source/user_guide_rag.rst
@@ -397,7 +397,7 @@ However, in most cases, a text (from the user) will be provided instead of a vec
In this scenario, an `Embedder` is required.
Search Similar Text
------------------------------
+--------------------
When searching for a text, specifying how the retriever transforms (embeds) the text
into a vector is required. Therefore, the retriever requires knowledge of an embedder:
@@ -418,7 +418,7 @@ into a vector is required. Therefore, the retriever requires knowledge of an emb
Embedders
------------------------------
+---------
Currently, this package supports the following embedders:
@@ -485,7 +485,7 @@ using the `return_properties` parameter:
Pre-Filters
------------------------------
+-----------
When performing a similarity search, one may have constraints to apply.
For instance, filtering out movies released before 2000. This can be achieved
diff --git a/examples/README.md b/examples/README.md
index 95f9443a..2faed5f8 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -14,6 +14,7 @@ are listed in [the last section of this file](#customize).
- [End to end PDF to graph simple pipeline](build_graph/simple_kg_builder_from_pdf.py)
- [End to end text to graph simple pipeline](build_graph/simple_kg_builder_from_text.py)
+- [Build KG pipeline from config file](build_graph/from_config_files/simple_kg_pipeline_from_config_file.py)
## Retrieve
@@ -93,6 +94,7 @@ are listed in [the last section of this file](#customize).
- [End to end example with explicit components and PDF input](./customize/build_graph/pipeline/kg_builder_from_pdf.py)
- [Process multiple documents](./customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py)
- [Export lexical graph creation into another pipeline](./customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py)
+- [Build pipeline from config file](customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py)
#### Components
diff --git a/examples/build_graph/from_config_files/simple_kg_pipeline_config.json b/examples/build_graph/from_config_files/simple_kg_pipeline_config.json
new file mode 100644
index 00000000..ef251624
--- /dev/null
+++ b/examples/build_graph/from_config_files/simple_kg_pipeline_config.json
@@ -0,0 +1,112 @@
+{
+ "version_": "1",
+ "template_": "SimpleKGPipeline",
+ "neo4j_config": {
+ "params_": {
+ "uri": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_URI"
+ },
+ "user": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_USER"
+ },
+ "password": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_PASSWORD"
+ }
+ }
+ },
+ "llm_config": {
+ "class_": "OpenAILLM",
+ "params_": {
+ "api_key": {
+ "resolver_": "ENV",
+ "var_": "OPENAI_API_KEY"
+ },
+ "model_name": "gpt-4o",
+ "model_params": {
+ "temperature": 0,
+ "max_tokens": 2000,
+ "response_format": {"type": "json_object"}
+ }
+ }
+ },
+ "embedder_config": {
+ "class_": "OpenAIEmbeddings",
+ "params_": {
+ "api_key": {
+ "resolver_": "ENV",
+ "var_": "OPENAI_API_KEY"
+ }
+ }
+ },
+ "from_pdf": false,
+ "entities": [
+ "Person",
+ {
+ "label": "House",
+ "description": "Family the person belongs to",
+ "properties": [
+ {
+ "name": "name",
+ "type": "STRING"
+ }
+ ]
+ },
+ {
+ "label": "Planet",
+ "properties": [
+ {
+ "name": "name",
+ "type": "STRING"
+ },
+ {
+ "name": "weather",
+ "type": "STRING"
+ }
+ ]
+ }
+ ],
+ "relations": [
+ "PARENT_OF",
+ {
+ "label": "HEIR_OF",
+ "description": "Used for inheritor relationship between father and sons"
+ },
+ {
+ "label": "RULES",
+ "properties": [
+ {
+ "name": "fromYear",
+ "type": "INTEGER"
+ }
+ ]
+ }
+ ],
+ "potential_schema": [
+ [
+ "Person",
+ "PARENT_OF",
+ "Person"
+ ],
+ [
+ "Person",
+ "HEIR_OF",
+ "House"
+ ],
+ [
+ "House",
+ "RULES",
+ "Planet"
+ ]
+ ],
+ "text_splitter": {
+ "class_": "text_splitters.fixed_size_splitter.FixedSizeSplitter",
+ "params_": {
+ "chunk_size": 100,
+ "chunk_overlap": 10
+ }
+ },
+ "perform_entity_resolution": true
+}
diff --git a/examples/build_graph/from_config_files/simple_kg_pipeline_config.yaml b/examples/build_graph/from_config_files/simple_kg_pipeline_config.yaml
new file mode 100644
index 00000000..8917e8ca
--- /dev/null
+++ b/examples/build_graph/from_config_files/simple_kg_pipeline_config.yaml
@@ -0,0 +1,63 @@
+version_: "1"
+template_: SimpleKGPipeline
+neo4j_config:
+ params_:
+ uri:
+ resolver_: ENV
+ var_: NEO4J_URI
+ user:
+ resolver_: ENV
+ var_: NEO4J_USER
+ password:
+ resolver_: ENV
+ var_: NEO4J_PASSWORD
+llm_config:
+ class_: OpenAILLM
+ params_:
+ api_key:
+ resolver_: ENV
+ var_: OPENAI_API_KEY
+ model_name: gpt-4o
+ model_params:
+ temperature: 0
+ max_tokens: 2000
+ response_format:
+ type: json_object
+embedder_config:
+ class_: OpenAIEmbeddings
+ params_:
+ api_key:
+ resolver_: ENV
+ var_: OPENAI_API_KEY
+from_pdf: false
+entities:
+ - label: Person
+ - label: House
+ description: Family the person belongs to
+ properties:
+ - name: name
+ type: STRING
+ - label: Planet
+ properties:
+ - name: name
+ type: STRING
+ - name: weather
+ type: STRING
+relations:
+ - label: PARENT_OF
+ - label: HEIR_OF
+ description: Used for inheritor relationship between father and sons
+ - label: RULES
+ properties:
+ - name: fromYear
+ type: INTEGER
+potential_schema:
+ - ["Person", "PARENT_OF", "Person"]
+ - ["Person", "HEIR_OF", "House"]
+ - ["House", "RULES", "Planet"]
+text_splitter:
+ class_: text_splitters.fixed_size_splitter.FixedSizeSplitter
+ params_:
+ chunk_size: 100
+ chunk_overlap: 10
+perform_entity_resolution: true
diff --git a/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py b/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py
new file mode 100644
index 00000000..62ba6c85
--- /dev/null
+++ b/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py
@@ -0,0 +1,47 @@
+"""In this example, the pipeline is defined in a JSON ('simple_kg_pipeline_config.json')
+or YAML ('simple_kg_pipeline_config.yaml') file.
+
+According to the configuration file, some parameters will be read from the env vars
+(Neo4j credentials and the OpenAI API key).
+"""
+
+import asyncio
+import logging
+
+## If env vars are in a .env file, uncomment:
+## (requires pip install python-dotenv)
+# from dotenv import load_dotenv
+# load_dotenv()
+# env vars manually set for testing:
+import os
+from pathlib import Path
+
+from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner
+from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
+
+logging.basicConfig()
+logging.getLogger("neo4j_graphrag").setLevel(logging.DEBUG)
+
+os.environ["NEO4J_URI"] = "bolt://localhost:7687"
+os.environ["NEO4J_USER"] = "neo4j"
+os.environ["NEO4J_PASSWORD"] = "password"
+# os.environ["OPENAI_API_KEY"] = "sk-..."
+
+
+root_dir = Path(__file__).parent
+file_path = root_dir / "simple_kg_pipeline_config.yaml"
+# file_path = root_dir / "simple_kg_pipeline_config.json"
+
+
+# Text to process
+TEXT = """The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House Atreides,
+an aristocratic family that rules the planet Caladan, the rainy planet, since 10191."""
+
+
+async def main() -> PipelineResult:
+ pipeline = PipelineRunner.from_config_file(file_path)
+ return await pipeline.run({"text": TEXT})
+
+
+if __name__ == "__main__":
+ print(asyncio.run(main()))
diff --git a/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.json b/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.json
new file mode 100644
index 00000000..ce815412
--- /dev/null
+++ b/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.json
@@ -0,0 +1,68 @@
+{
+ "version_": "1",
+ "template_": "none",
+ "name": "",
+ "neo4j_config": {
+ "params_": {
+ "uri": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_URI"
+ },
+ "user": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_USER"
+ },
+ "password": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_PASSWORD"
+ }
+ }
+ },
+ "extras": {
+ "database": "neo4j"
+ },
+ "component_config": {
+ "splitter": {
+ "class_": "text_splitters.fixed_size_splitter.FixedSizeSplitter"
+ },
+ "builder": {
+ "class_": "lexical_graph.LexicalGraphBuilder",
+ "params_": {
+ "config": {
+ "chunk_node_label": "TextPart"
+ }
+ }
+ },
+ "writer": {
+ "name_": "writer",
+ "class_": "kg_writer.Neo4jWriter",
+ "params_": {
+ "driver": {
+ "resolver_": "CONFIG_KEY",
+ "key_": "neo4j_config.default"
+ },
+ "neo4j_database": {
+ "resolver_": "CONFIG_KEY",
+ "key_": "extras.database"
+ }
+ }
+ }
+ },
+ "connection_config": [
+ {
+ "start": "splitter",
+ "end": "builder",
+ "input_config": {
+ "text_chunks": "splitter"
+ }
+ },
+ {
+ "start": "builder",
+ "end": "writer",
+ "input_config": {
+ "graph": "builder.graph",
+ "lexical_graph_config": "builder.config"
+ }
+ }
+ ]
+}
diff --git a/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.yaml b/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.yaml
new file mode 100644
index 00000000..87ac905e
--- /dev/null
+++ b/examples/customize/build_graph/pipeline/from_config_files/pipeline_config.yaml
@@ -0,0 +1,45 @@
+version_: "1"
+template_: none
+neo4j_config:
+ params_:
+ uri:
+ resolver_: ENV
+ var_: NEO4J_URI
+ user:
+ resolver_: ENV
+ var_: NEO4J_USER
+ password:
+ resolver_: ENV
+ var_: NEO4J_PASSWORD
+extras:
+ database: neo4j
+component_config:
+ splitter:
+ class_: text_splitters.fixed_size_splitter.FixedSizeSplitter
+ params_:
+ chunk_size: 100
+ chunk_overlap: 10
+ builder:
+ class_: lexical_graph.LexicalGraphBuilder
+ params_:
+ config:
+ chunk_node_label: TextPart
+ writer:
+ class_: kg_writer.Neo4jWriter
+ params_:
+ driver:
+ resolver_: CONFIG_KEY
+ key_: neo4j_config.default
+ neo4j_database:
+ resolver_: CONFIG_KEY
+ key_: extras.database
+connection_config:
+ - start: splitter
+ end: builder
+ input_config:
+ text_chunks: splitter
+ - start: builder
+ end: writer
+ input_config:
+ graph: builder.graph
+ lexical_graph_config: builder.config
diff --git a/examples/customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py b/examples/customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py
new file mode 100644
index 00000000..9a2a1680
--- /dev/null
+++ b/examples/customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py
@@ -0,0 +1,40 @@
+"""In this example, the pipeline is defined in a JSON file 'pipeline_config.json'.
+According to the configuration file, some parameters will be read from the env vars
+(Neo4j credentials and the OpenAI API key).
+"""
+
+import asyncio
+
+## If env vars are in a .env file, uncomment:
+## (requires pip install python-dotenv)
+# from dotenv import load_dotenv
+# load_dotenv()
+# env vars manually set for testing:
+import os
+from pathlib import Path
+
+from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner
+from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
+
+os.environ["NEO4J_URI"] = "bolt://localhost:7687"
+os.environ["NEO4J_USER"] = "neo4j"
+os.environ["NEO4J_PASSWORD"] = "password"
+# os.environ["OPENAI_API_KEY"] = "sk-..."
+
+
+root_dir = Path(__file__).parent
+# file_path = root_dir / "pipeline_config.json"
+file_path = root_dir / "pipeline_config.yaml"
+
+# Text to process
+TEXT = """The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House Atreides,
+an aristocratic family that rules the planet Caladan, the rainy planet, since 10191."""
+
+
+async def main() -> PipelineResult:
+ pipeline = PipelineRunner.from_config_file(file_path)
+ return await pipeline.run({"splitter": {"text": TEXT}})
+
+
+if __name__ == "__main__":
+ print(asyncio.run(main()))
diff --git a/poetry.lock b/poetry.lock
index e7bf3849..a662e717 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1905,18 +1905,18 @@ files = [
[[package]]
name = "langchain-core"
-version = "0.3.23"
+version = "0.3.24"
description = "Building applications with LLMs through composability"
optional = true
python-versions = "<4.0,>=3.9"
files = [
- {file = "langchain_core-0.3.23-py3-none-any.whl", hash = "sha256:550c0b996990830fa6515a71a1192a8a0343367999afc36d4ede14222941e420"},
- {file = "langchain_core-0.3.23.tar.gz", hash = "sha256:f9e175e3b82063cc3b160c2ca2b155832e1c6f915312e1204828f97d4aabf6e1"},
+ {file = "langchain_core-0.3.24-py3-none-any.whl", hash = "sha256:97192552ef882a3dd6ae3b870a180a743801d0137a1159173f51ac555eeb7eec"},
+ {file = "langchain_core-0.3.24.tar.gz", hash = "sha256:460851e8145327f70b70aad7dce2cdbd285e144d14af82b677256b941fc99656"},
]
[package.dependencies]
jsonpatch = ">=1.33,<2.0"
-langsmith = ">=0.1.125,<0.2.0"
+langsmith = ">=0.1.125,<0.3"
packaging = ">=23.2,<25"
pydantic = [
{version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""},
@@ -1976,13 +1976,13 @@ langchain-core = ">=0.3.15,<0.4.0"
[[package]]
name = "langsmith"
-version = "0.1.147"
+version = "0.2.2"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = true
-python-versions = "<4.0,>=3.8.1"
+python-versions = "<4.0,>=3.9"
files = [
- {file = "langsmith-0.1.147-py3-none-any.whl", hash = "sha256:7166fc23b965ccf839d64945a78e9f1157757add228b086141eb03a60d699a15"},
- {file = "langsmith-0.1.147.tar.gz", hash = "sha256:2e933220318a4e73034657103b3b1a3a6109cc5db3566a7e8e03be8d6d7def7a"},
+ {file = "langsmith-0.2.2-py3-none-any.whl", hash = "sha256:4786d7dcdbc25e43d4a1bf70bbe12938a9eb2364feec8f6fc4d967162519b367"},
+ {file = "langsmith-0.2.2.tar.gz", hash = "sha256:6f515ee41ae80968a7d552be1154414ccde57a0a534c960c8c3cd1835734095f"},
]
[package.dependencies]
@@ -2866,13 +2866,13 @@ files = [
[[package]]
name = "openai"
-version = "1.57.1"
+version = "1.57.2"
description = "The official Python library for the openai API"
optional = true
python-versions = ">=3.8"
files = [
- {file = "openai-1.57.1-py3-none-any.whl", hash = "sha256:3865686c927e93492d1145938d4a24b634951531c4b2769d43ca5dbd4b25d8fd"},
- {file = "openai-1.57.1.tar.gz", hash = "sha256:a95f22e04ab3df26e64a15d958342265e802314131275908b3b3e36f8c5d4377"},
+ {file = "openai-1.57.2-py3-none-any.whl", hash = "sha256:f7326283c156fdee875746e7e54d36959fb198eadc683952ee05e3302fbd638d"},
+ {file = "openai-1.57.2.tar.gz", hash = "sha256:5f49fd0f38e9f2131cda7deb45dafdd1aee4f52a637e190ce0ecf40147ce8cee"},
]
[package.dependencies]
@@ -4928,6 +4928,17 @@ build = ["cmake (>=3.20)", "lit"]
tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"]
tutorials = ["matplotlib", "pandas", "tabulate"]
+[[package]]
+name = "types-pyyaml"
+version = "6.0.12.20240917"
+description = "Typing stubs for PyYAML"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "types-PyYAML-6.0.12.20240917.tar.gz", hash = "sha256:d1405a86f9576682234ef83bcb4e6fff7c9305c8b1fbad5e0bcd4f7dbdc9c587"},
+ {file = "types_PyYAML-6.0.12.20240917-py3-none-any.whl", hash = "sha256:392b267f1c0fe6022952462bf5d6523f31e37f6cea49b14cee7ad634b6301570"},
+]
+
[[package]]
name = "types-requests"
version = "2.31.0.6"
@@ -5267,4 +5278,4 @@ weaviate = ["weaviate-client"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9.0"
-content-hash = "eb706cb5d3d47a4d3ac32bc7f228bda27de77c87398999d1e610552ff19af93f"
+content-hash = "5f729b5f7f31021258d04fcf26e2310f685f1b97113e888ba346df5c7393d4e4"
diff --git a/pyproject.toml b/pyproject.toml
index 6007709d..604fde99 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,6 +49,7 @@ openai = {version = "^1.51.1", optional = true }
anthropic = { version = "^0.36.0", optional = true}
sentence-transformers = {version = "^3.0.0", optional = true }
json-repair = "^0.30.2"
+types-pyyaml = "^6.0.12.20240917"
[tool.poetry.group.dev.dependencies]
urllib3 = "<2"
diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py
index 225b4c0f..d4070aea 100644
--- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py
+++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py
@@ -23,7 +23,6 @@
from typing import Any, List, Optional, Union, cast
import json_repair
-
from pydantic import ValidationError, validate_call
from neo4j_graphrag.exceptions import LLMGenerationError
diff --git a/src/neo4j_graphrag/experimental/components/lexical_graph.py b/src/neo4j_graphrag/experimental/components/lexical_graph.py
index 92681a8b..ce96b9fd 100644
--- a/src/neo4j_graphrag/experimental/components/lexical_graph.py
+++ b/src/neo4j_graphrag/experimental/components/lexical_graph.py
@@ -42,6 +42,7 @@ class LexicalGraphBuilder(Component):
- A relationship between a chunk and the next one in the document
"""
+ @validate_call
def __init__(
self,
config: LexicalGraphConfig = LexicalGraphConfig(),
diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py
index 64e908ed..0f118545 100644
--- a/src/neo4j_graphrag/experimental/components/schema.py
+++ b/src/neo4j_graphrag/experimental/components/schema.py
@@ -17,9 +17,14 @@
from typing import Any, Dict, List, Literal, Optional, Tuple
from pydantic import BaseModel, ValidationError, model_validator, validate_call
+from typing_extensions import Self
from neo4j_graphrag.exceptions import SchemaValidationError
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
+from neo4j_graphrag.experimental.pipeline.types import (
+ EntityInputType,
+ RelationInputType,
+)
class SchemaProperty(BaseModel):
@@ -55,6 +60,14 @@ class SchemaEntity(BaseModel):
description: str = ""
properties: List[SchemaProperty] = []
+ @classmethod
+ def from_text_or_dict(cls, input: EntityInputType) -> Self:
+ if isinstance(input, SchemaEntity):
+ return input
+ if isinstance(input, str):
+ return cls(label=input)
+ return cls.model_validate(input)
+
class SchemaRelation(BaseModel):
"""
@@ -65,6 +78,14 @@ class SchemaRelation(BaseModel):
description: str = ""
properties: List[SchemaProperty] = []
+ @classmethod
+ def from_text_or_dict(cls, input: RelationInputType) -> Self:
+ if isinstance(input, SchemaRelation):
+ return input
+ if isinstance(input, str):
+ return cls(label=input)
+ return cls.model_validate(input)
+
class SchemaConfig(DataModel):
"""
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/__init__.py b/src/neo4j_graphrag/experimental/pipeline/config/__init__.py
new file mode 100644
index 00000000..c0199c14
--- /dev/null
+++ b/src/neo4j_graphrag/experimental/pipeline/config/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/base.py b/src/neo4j_graphrag/experimental/pipeline/config/base.py
new file mode 100644
index 00000000..665a56d0
--- /dev/null
+++ b/src/neo4j_graphrag/experimental/pipeline/config/base.py
@@ -0,0 +1,62 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Abstract class for all pipeline configs."""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from pydantic import BaseModel, PrivateAttr
+
+from neo4j_graphrag.experimental.pipeline.config.param_resolver import (
+ ParamConfig,
+ ParamToResolveConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class AbstractConfig(BaseModel):
+ """Base class for all configs.
+ Provides methods to get a class from a string and resolve a parameter defined by
+ a dict with a 'resolver_' key.
+
+ Each subclass must implement a 'parse' method that returns the relevant object.
+ """
+
+ _global_data: dict[str, Any] = PrivateAttr({})
+ """Additional parameter ignored by all Pydantic model_* methods."""
+
+ def resolve_param(self, param: ParamConfig) -> Any:
+ """Finds the parameter value from its definition."""
+ if not isinstance(param, ParamToResolveConfig):
+ # some parameters do not have to be resolved, real
+ # values are already provided
+ return param
+ return param.resolve(self._global_data)
+
+ def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]:
+ """Resolve all parameters
+
+ Returning dict[str, Any] because parameters can be anything (str, float, list, dict...)
+ """
+ return {
+ param_name: self.resolve_param(param)
+ for param_name, param in params.items()
+ }
+
+ def parse(self, resolved_data: dict[str, Any] | None = None) -> Any:
+ raise NotImplementedError()
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py b/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py
new file mode 100644
index 00000000..c3df28d9
--- /dev/null
+++ b/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py
@@ -0,0 +1,85 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Reads a JSON or YAML file and returns a dict.
+No data validation is performed at this stage.
+"""
+
+import json
+import logging
+from pathlib import Path
+from typing import Any, Optional
+
+import fsspec
+import yaml
+from fsspec.implementations.local import LocalFileSystem
+
+logger = logging.getLogger(__name__)
+
+
+class ConfigReader:
+ """Reads config from a file (JSON or YAML format)
+ and returns a dict.
+
+ File format is guessed from the extension. Supported extensions are
+ (lower or upper case):
+
+ - .json
+ - .yaml, .yml
+
+ Example:
+
+ .. code-block:: python
+
+ from pathlib import Path
+        from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader
+ reader = ConfigReader()
+ reader.read(Path("my_file.json"))
+
+    To read a file with a different extension but still in JSON or YAML format,
+    call the `read_json` or `read_yaml` methods directly:
+
+ .. code-block:: python
+
+ reader.read_yaml(Path("my_file.txt"))
+
+ """
+
+ def __init__(self, fs: Optional[fsspec.AbstractFileSystem] = None) -> None:
+ self.fs = fs or LocalFileSystem()
+
+ def read_json(self, file_path: str) -> Any:
+ logger.debug(f"CONFIG_READER: read from json {file_path}")
+ with self.fs.open(file_path, "r") as f:
+ return json.load(f)
+
+ def read_yaml(self, file_path: str) -> Any:
+ logger.debug(f"CONFIG_READER: read from yaml {file_path}")
+ with self.fs.open(file_path, "r") as f:
+ return yaml.safe_load(f)
+
+ def _guess_format_and_read(self, file_path: str) -> dict[str, Any]:
+ p = Path(file_path)
+ extension = p.suffix.lower()
+ # Note: .suffix returns an empty string if Path has no extension
+ # if not returning a dict, parsing will fail later on
+ if extension in [".json"]:
+ return self.read_json(file_path) # type: ignore[no-any-return]
+ if extension in [".yaml", ".yml"]:
+ return self.read_yaml(file_path) # type: ignore[no-any-return]
+ raise ValueError(f"Unsupported extension: {extension}")
+
+ def read(self, file_path: str) -> dict[str, Any]:
+ data = self._guess_format_and_read(file_path)
+ return data
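The `ConfigReader` above guesses the format from the file extension. A minimal stdlib sketch of that dispatch (JSON branch only; the YAML branch is stubbed here since `yaml` is a third-party dependency):

```python
import json
import tempfile
from pathlib import Path
from typing import Any

def read_config(file_path: str) -> dict[str, Any]:
    # Guess the format from the (lower-cased) extension, as ConfigReader does.
    extension = Path(file_path).suffix.lower()
    if extension == ".json":
        with open(file_path) as f:
            return json.load(f)
    if extension in (".yaml", ".yml"):
        raise NotImplementedError("would call yaml.safe_load(f) here")
    raise ValueError(f"Unsupported extension: {extension}")

with tempfile.TemporaryDirectory() as tmp:
    # Upper-case extensions are accepted because of the .lower() call.
    path = Path(tmp) / "pipeline_config.JSON"
    path.write_text(json.dumps({"template_": "SimpleKGPipeline"}))
    config = read_config(str(path))

print(config)  # {'template_': 'SimpleKGPipeline'}
```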
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/object_config.py b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py
new file mode 100644
index 00000000..cb47b380
--- /dev/null
+++ b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py
@@ -0,0 +1,266 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Config for all parameters that can be provided either as an object instance
+or as a config dict with 'class_' and 'params_' keys.
+
+Nomenclature in this file:
+
+- `*Config` models are used to represent "things" as dict to be used in a config file.
+ e.g.:
+ - neo4j.Driver => {"uri": "", "user": "", "password": ""}
+ - LLMInterface => {"class_": "OpenAI", "params_": {"model_name": "gpt-4o"}}
+- `*Type` models are wrappers around an object and a 'Config' the object can be created
+ from. They are used to allow the instantiation of "PipelineConfig" either from
+  instantiated objects (when used in code) or from a config dict (when used to
+ load config from file).
+"""
+
+from __future__ import annotations
+
+import importlib
+import logging
+from typing import (
+ Any,
+ ClassVar,
+ Generic,
+ Optional,
+ TypeVar,
+ Union,
+ cast,
+)
+
+import neo4j
+from pydantic import (
+ ConfigDict,
+ Field,
+ RootModel,
+ field_validator,
+)
+
+from neo4j_graphrag.embeddings import Embedder
+from neo4j_graphrag.experimental.pipeline import Component
+from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig
+from neo4j_graphrag.experimental.pipeline.config.param_resolver import (
+ ParamConfig,
+)
+from neo4j_graphrag.llm import LLMInterface
+
+logger = logging.getLogger(__name__)
+
+
+T = TypeVar("T")
+"""Generic type to help mypy with the parse method when we know the exact
+expected return type (e.g. for the Neo4jDriverConfig below).
+"""
+
+
+class ObjectConfig(AbstractConfig, Generic[T]):
+ """A config class to represent an object from a class name
+ and its constructor parameters.
+ """
+
+ class_: str | None = Field(default=None, validate_default=True)
+ """Path to class to be instantiated."""
+ params_: dict[str, ParamConfig] = {}
+ """Initialization parameters."""
+
+ DEFAULT_MODULE: ClassVar[str] = "."
+ """Default module to import the class from."""
+ INTERFACE: ClassVar[type] = object
+ """Constraint on the class (must be a subclass of)."""
+ REQUIRED_PARAMS: ClassVar[list[str]] = []
+ """List of required parameters for this object constructor."""
+
+ @field_validator("params_")
+ @classmethod
+ def validate_params(cls, params_: dict[str, Any]) -> dict[str, Any]:
+ """Make sure all required parameters are provided."""
+ for p in cls.REQUIRED_PARAMS:
+ if p not in params_:
+ raise ValueError(f"Missing parameter {p}")
+ return params_
+
+ def get_module(self) -> str:
+ return self.DEFAULT_MODULE
+
+ def get_interface(self) -> type:
+ return self.INTERFACE
+
+ @classmethod
+ def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> type:
+ """Get class from string and an optional module
+
+        Will first try to import the class from `class_path` alone. If that raises an
+        ImportError or AttributeError, will try to import from `f'{optional_module}.{class_path}'`.
+
+ Args:
+ class_path (str): Class path with format 'my_module.MyClass'.
+ optional_module (Optional[str]): Optional module path. Used to provide a default path for some known objects and simplify the notation.
+
+ Raises:
+ ValueError: if the class can't be imported, even using the optional module.
+ """
+ *modules, class_name = class_path.rsplit(".", 1)
+ module_name = modules[0] if modules else optional_module
+ if module_name is None:
+ raise ValueError("Must specify a module to import class from")
+ try:
+ module = importlib.import_module(module_name)
+ klass = getattr(module, class_name)
+ except (ImportError, AttributeError):
+ if optional_module and module_name != optional_module:
+ full_klass_path = optional_module + "." + class_path
+ return cls._get_class(full_klass_path)
+ raise ValueError(f"Could not find {class_name} in {module_name}")
+ return cast(type, klass)
+
+ def parse(self, resolved_data: dict[str, Any] | None = None) -> T:
+ """Import `class_`, resolve `params_` and instantiate object."""
+ self._global_data = resolved_data or {}
+ logger.debug(f"OBJECT_CONFIG: parsing {self} using {resolved_data}")
+ if self.class_ is None:
+            raise ValueError(f"`class_` is required to parse object {self}")
+ klass = self._get_class(self.class_, self.get_module())
+ if not issubclass(klass, self.get_interface()):
+ raise ValueError(
+ f"Invalid class '{klass}'. Expected a subclass of '{self.get_interface()}'"
+ )
+ params = self.resolve_params(self.params_)
+ try:
+ obj = klass(**params)
+ except TypeError as e:
+ logger.error(
+ "OBJECT_CONFIG: failed to instantiate object due to improperly configured parameters"
+ )
+ raise e
+ return cast(T, obj)
+
+
+class Neo4jDriverConfig(ObjectConfig[neo4j.Driver]):
+ REQUIRED_PARAMS = ["uri", "user", "password"]
+
+ @field_validator("class_", mode="before")
+ @classmethod
+ def validate_class(cls, class_: Any) -> str:
+ """`class_` parameter is not used because we're always using the sync driver."""
+ if class_:
+ logger.info("Parameter class_ is not used for Neo4jDriverConfig")
+ # not used
+ return "not used"
+
+ def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver:
+ params = self.resolve_params(self.params_)
+ # we know these params are there because of the required params validator
+ uri = params.pop("uri")
+ user = params.pop("user")
+ password = params.pop("password")
+ driver = neo4j.GraphDatabase.driver(uri, auth=(user, password), **params)
+ return driver
+
+
+# note: using the notation with RootModel + root: field
+# instead of RootModel[] for clarity
+# but this requires the type: ignore comment below
+class Neo4jDriverType(RootModel): # type: ignore[type-arg]
+ """A model to wrap neo4j.Driver and Neo4jDriverConfig objects.
+
+ The `parse` method always returns a neo4j.Driver.
+ """
+
+ root: Union[neo4j.Driver, Neo4jDriverConfig]
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver:
+ if isinstance(self.root, neo4j.Driver):
+ return self.root
+ # self.root is a Neo4jDriverConfig object
+ return self.root.parse(resolved_data)
+
+
+class LLMConfig(ObjectConfig[LLMInterface]):
+ """Configuration for any LLMInterface object.
+
+ By default, will try to import from `neo4j_graphrag.llm`.
+ """
+
+ DEFAULT_MODULE = "neo4j_graphrag.llm"
+ INTERFACE = LLMInterface
+
+
+class LLMType(RootModel): # type: ignore[type-arg]
+ """A model to wrap LLMInterface and LLMConfig objects.
+
+ The `parse` method always returns an object inheriting from LLMInterface.
+ """
+
+ root: Union[LLMInterface, LLMConfig]
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ def parse(self, resolved_data: dict[str, Any] | None = None) -> LLMInterface:
+ if isinstance(self.root, LLMInterface):
+ return self.root
+ return self.root.parse(resolved_data)
+
+
+class EmbedderConfig(ObjectConfig[Embedder]):
+ """Configuration for any Embedder object.
+
+ By default, will try to import from `neo4j_graphrag.embeddings`.
+ """
+
+ DEFAULT_MODULE = "neo4j_graphrag.embeddings"
+ INTERFACE = Embedder
+
+
+class EmbedderType(RootModel): # type: ignore[type-arg]
+ """A model to wrap Embedder and EmbedderConfig objects.
+
+ The `parse` method always returns an object inheriting from Embedder.
+ """
+
+ root: Union[Embedder, EmbedderConfig]
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder:
+ if isinstance(self.root, Embedder):
+ return self.root
+ return self.root.parse(resolved_data)
+
+
+class ComponentConfig(ObjectConfig[Component]):
+ """A config model for all components.
+
+ In addition to the object config, components can have pre-defined parameters
+    that will be passed to the `run` method, i.e. `run_params_`.
+ """
+
+ run_params_: dict[str, ParamConfig] = {}
+
+ DEFAULT_MODULE = "neo4j_graphrag.experimental.components"
+ INTERFACE = Component
+
+
+class ComponentType(RootModel): # type: ignore[type-arg]
+ root: Union[Component, ComponentConfig]
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ def parse(self, resolved_data: dict[str, Any] | None = None) -> Component:
+ if isinstance(self.root, Component):
+ return self.root
+ return self.root.parse(resolved_data)
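`ObjectConfig._get_class` resolves a dotted path to a class object, retrying against the default module when the bare name alone cannot be imported. A standalone sketch of the same lookup logic, exercised with a stdlib class:

```python
import importlib
from typing import Optional

def get_class(class_path: str, optional_module: Optional[str] = None) -> type:
    """Resolve 'pkg.mod.Class' to a class; fall back to optional_module for bare names."""
    *modules, class_name = class_path.rsplit(".", 1)
    module_name = modules[0] if modules else optional_module
    if module_name is None:
        raise ValueError("Must specify a module to import class from")
    try:
        module = importlib.import_module(module_name)
        return getattr(module, class_name)
    except (ImportError, AttributeError):
        if optional_module and module_name != optional_module:
            # retry with the default module prepended
            return get_class(optional_module + "." + class_path)
        raise ValueError(f"Could not find {class_name} in {module_name}")

# Fully qualified path:
print(get_class("collections.OrderedDict").__name__)  # OrderedDict
# Bare class name, resolved against the default module:
print(get_class("OrderedDict", optional_module="collections").__name__)  # OrderedDict
```

This is why a config file can say `"class_": "OpenAILLM"` instead of the full `neo4j_graphrag.llm.OpenAILLM` path: `LLMConfig` sets `DEFAULT_MODULE = "neo4j_graphrag.llm"`.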
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/param_resolver.py b/src/neo4j_graphrag/experimental/pipeline/config/param_resolver.py
new file mode 100644
index 00000000..24190add
--- /dev/null
+++ b/src/neo4j_graphrag/experimental/pipeline/config/param_resolver.py
@@ -0,0 +1,60 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+import os
+from typing import Any, ClassVar, Literal, Union
+
+from pydantic import BaseModel
+
+
+class ParamResolverEnum(str, enum.Enum):
+ ENV = "ENV"
+ CONFIG_KEY = "CONFIG_KEY"
+
+
+class ParamToResolveConfig(BaseModel):
+ def resolve(self, data: dict[str, Any]) -> Any:
+ raise NotImplementedError
+
+
+class ParamFromEnvConfig(ParamToResolveConfig):
+ resolver_: Literal[ParamResolverEnum.ENV] = ParamResolverEnum.ENV
+ var_: str
+
+ def resolve(self, data: dict[str, Any]) -> Any:
+ return os.environ.get(self.var_)
+
+
+class ParamFromKeyConfig(ParamToResolveConfig):
+ resolver_: Literal[ParamResolverEnum.CONFIG_KEY] = ParamResolverEnum.CONFIG_KEY
+ key_: str
+
+ KEY_SEP: ClassVar[str] = "."
+
+ def resolve(self, data: dict[str, Any]) -> Any:
+ d = data
+ for k in self.key_.split(self.KEY_SEP):
+ d = d[k]
+ return d
+
+
+ParamConfig = Union[
+ float,
+ str,
+ ParamFromEnvConfig,
+ ParamFromKeyConfig,
+ dict[str, Any],
+]
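The two resolvers above turn a config placeholder into a concrete value: `ENV` reads an environment variable, `CONFIG_KEY` walks the global data dict along a dotted key. A stripped-down sketch of both `resolve` methods (variable and key names here are illustrative):

```python
import os

KEY_SEP = "."

def resolve_env(var_: str):
    # ENV resolver: read the value from an environment variable.
    return os.environ.get(var_)

def resolve_config_key(key_: str, data: dict):
    # CONFIG_KEY resolver: walk the global data dict along a dotted key.
    d = data
    for k in key_.split(KEY_SEP):
        d = d[k]
    return d

os.environ["MY_API_KEY"] = "secret"
global_data = {"extras": {"api_key": "secret-from-extras"}}

print(resolve_env("MY_API_KEY"))                          # secret
print(resolve_config_key("extras.api_key", global_data))  # secret-from-extras
```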
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py
new file mode 100644
index 00000000..c3871179
--- /dev/null
+++ b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py
@@ -0,0 +1,199 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, ClassVar, Literal, Optional, Union
+
+import neo4j
+from pydantic import field_validator
+
+from neo4j_graphrag.embeddings import Embedder
+from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig
+from neo4j_graphrag.experimental.pipeline.config.object_config import (
+ ComponentType,
+ EmbedderType,
+ LLMType,
+ Neo4jDriverType,
+)
+from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
+from neo4j_graphrag.experimental.pipeline.types import (
+ ComponentDefinition,
+ ConnectionDefinition,
+ PipelineDefinition,
+)
+from neo4j_graphrag.llm import LLMInterface
+
+logger = logging.getLogger(__name__)
+
+
+class AbstractPipelineConfig(AbstractConfig):
+    """This class defines the fields shared by all pipelines: Neo4j drivers, LLMs...
+    neo4j_config, llm_config and embedder_config can be provided by the user as a single item or a dict of items.
+    Validators handle the type conversion so that, in all instances, each field ends up as a dict of items.
+ """
+
+ neo4j_config: dict[str, Neo4jDriverType] = {}
+ llm_config: dict[str, LLMType] = {}
+ embedder_config: dict[str, EmbedderType] = {}
+ # extra parameters values that can be used in different places of the config file
+ extras: dict[str, Any] = {}
+
+ DEFAULT_NAME: ClassVar[str] = "default"
+ """Name of the default item in dict
+ """
+
+ @field_validator("neo4j_config", mode="before")
+ @classmethod
+ def validate_drivers(
+ cls, drivers: Union[Neo4jDriverType, dict[str, Any]]
+ ) -> dict[str, Any]:
+ if not isinstance(drivers, dict) or "params_" in drivers:
+ return {cls.DEFAULT_NAME: drivers}
+ return drivers
+
+ @field_validator("llm_config", mode="before")
+ @classmethod
+ def validate_llms(cls, llms: Union[LLMType, dict[str, Any]]) -> dict[str, Any]:
+ if not isinstance(llms, dict) or "class_" in llms:
+ return {cls.DEFAULT_NAME: llms}
+ return llms
+
+ @field_validator("embedder_config", mode="before")
+ @classmethod
+ def validate_embedders(
+ cls, embedders: Union[EmbedderType, dict[str, Any]]
+ ) -> dict[str, Any]:
+ if not isinstance(embedders, dict) or "class_" in embedders:
+ return {cls.DEFAULT_NAME: embedders}
+ return embedders
+
+ def _resolve_component_definition(
+ self, name: str, config: ComponentType
+ ) -> ComponentDefinition:
+ component = config.parse(self._global_data)
+ if hasattr(config.root, "run_params_"):
+ component_run_params = self.resolve_params(config.root.run_params_)
+ else:
+ component_run_params = {}
+ component_def = ComponentDefinition(
+ name=name,
+ component=component,
+ run_params=component_run_params,
+ )
+ logger.debug(f"PIPELINE_CONFIG: resolved component {component_def}")
+ return component_def
+
+ def _parse_global_data(self) -> dict[str, Any]:
+ """Global data contains data that can be referenced in other parts of the
+ config.
+
+ Typically, neo4j drivers, LLMs and embedders can be referenced in component
+ input parameters.
+ """
+ # 'extras' parameters can be referenced in other configs,
+ # that's why they are parsed before the others
+ # e.g., an API key used for both LLM and Embedder can be stored only
+ # once in extras.
+ extra_data = {
+ "extras": self.resolve_params(self.extras),
+ }
+ logger.debug(f"PIPELINE_CONFIG: resolved 'extras': {extra_data}")
+ drivers: dict[str, neo4j.Driver] = {
+ driver_name: driver_config.parse(extra_data)
+ for driver_name, driver_config in self.neo4j_config.items()
+ }
+ llms: dict[str, LLMInterface] = {
+ llm_name: llm_config.parse(extra_data)
+ for llm_name, llm_config in self.llm_config.items()
+ }
+ embedders: dict[str, Embedder] = {
+ embedder_name: embedder_config.parse(extra_data)
+ for embedder_name, embedder_config in self.embedder_config.items()
+ }
+ global_data = {
+ **extra_data,
+ "neo4j_config": drivers,
+ "llm_config": llms,
+ "embedder_config": embedders,
+ }
+ logger.debug(f"PIPELINE_CONFIG: resolved globals: {global_data}")
+ return global_data
+
+ def _get_components(self) -> list[ComponentDefinition]:
+ return []
+
+ def _get_connections(self) -> list[ConnectionDefinition]:
+ return []
+
+ def parse(
+ self, resolved_data: Optional[dict[str, Any]] = None
+ ) -> PipelineDefinition:
+        """Parse the full config and return a PipelineDefinition object containing the
+        instantiated components and the list of connections.
+ """
+ self._global_data = self._parse_global_data()
+ return PipelineDefinition(
+ components=self._get_components(),
+ connections=self._get_connections(),
+ )
+
+ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
+ return user_input
+
+ async def close(self) -> None:
+ drivers = self._global_data.get("neo4j_config", {})
+ for driver_name in drivers:
+ driver = drivers[driver_name]
+ logger.debug(f"PIPELINE_CONFIG: closing driver {driver_name}: {driver}")
+ driver.close()
+
+ def get_neo4j_driver_by_name(self, name: str) -> neo4j.Driver:
+ drivers: dict[str, neo4j.Driver] = self._global_data.get("neo4j_config", {})
+ return drivers[name]
+
+ def get_default_neo4j_driver(self) -> neo4j.Driver:
+ return self.get_neo4j_driver_by_name(self.DEFAULT_NAME)
+
+ def get_llm_by_name(self, name: str) -> LLMInterface:
+ llms: dict[str, LLMInterface] = self._global_data.get("llm_config", {})
+ return llms[name]
+
+ def get_default_llm(self) -> LLMInterface:
+ return self.get_llm_by_name(self.DEFAULT_NAME)
+
+ def get_embedder_by_name(self, name: str) -> Embedder:
+ embedders: dict[str, Embedder] = self._global_data.get("embedder_config", {})
+ return embedders[name]
+
+ def get_default_embedder(self) -> Embedder:
+ return self.get_embedder_by_name(self.DEFAULT_NAME)
+
+
+class PipelineConfig(AbstractPipelineConfig):
+ """Configuration class for raw pipelines.
+ Config must contain all components and connections."""
+
+ component_config: dict[str, ComponentType]
+ connection_config: list[ConnectionDefinition]
+ template_: Literal[PipelineType.NONE] = PipelineType.NONE
+
+ def _get_connections(self) -> list[ConnectionDefinition]:
+ return self.connection_config
+
+ def _get_components(self) -> list[ComponentDefinition]:
+ return [
+ self._resolve_component_definition(name, component_config)
+ for name, component_config in self.component_config.items()
+ ]
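The three `mode="before"` validators in `AbstractPipelineConfig` share one normalization rule: a bare item, or a dict that itself looks like an object config (i.e. contains `class_` or `params_`), is wrapped under the `default` key. A sketch of that rule in isolation:

```python
from typing import Any, Union

DEFAULT_NAME = "default"

def normalize(value: Union[dict, Any], marker_key: str) -> dict:
    # A dict containing the marker key ('class_' or 'params_') is itself a
    # single object config, not a mapping of named items, so wrap it too.
    if not isinstance(value, dict) or marker_key in value:
        return {DEFAULT_NAME: value}
    return value

single = {"class_": "OpenAILLM", "params_": {"model_name": "gpt-4o"}}
named = {"extraction": single, "summarization": single}

print(normalize(single, "class_"))   # wrapped under 'default'
print(list(normalize(named, "class_")))  # ['extraction', 'summarization']
```

After normalization, `get_default_llm()` and friends can simply look up the `default` key regardless of how the user wrote the config.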
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py
new file mode 100644
index 00000000..a1a22585
--- /dev/null
+++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py
@@ -0,0 +1,132 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pipeline config wrapper (router based on 'template_' key)
+and pipeline runner.
+"""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import (
+ Annotated,
+ Any,
+ Union,
+)
+
+from pydantic import (
+ BaseModel,
+ Discriminator,
+ Field,
+ Tag,
+)
+from pydantic.v1.utils import deep_update
+from typing_extensions import Self
+
+from neo4j_graphrag.experimental.pipeline import Pipeline
+from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader
+from neo4j_graphrag.experimental.pipeline.config.pipeline_config import (
+ AbstractPipelineConfig,
+ PipelineConfig,
+)
+from neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder import (
+ SimpleKGPipelineConfig,
+)
+from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
+from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
+from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition
+
+logger = logging.getLogger(__name__)
+
+
+def _get_discriminator_value(model: Any) -> PipelineType:
+ template_ = None
+ if "template_" in model:
+ template_ = model["template_"]
+ if hasattr(model, "template_"):
+ template_ = model.template_
+    if template_ is None:
+        return PipelineType.NONE
+    return PipelineType(template_)
+
+
+class PipelineConfigWrapper(BaseModel):
+ """The pipeline config wrapper will parse the right pipeline config based on the `template_` field."""
+
+ config: Union[
+ Annotated[PipelineConfig, Tag(PipelineType.NONE)],
+ Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)],
+ ] = Field(discriminator=Discriminator(_get_discriminator_value))
+
+ def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition:
+ return self.config.parse(resolved_data)
+
+ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
+ return self.config.get_run_params(user_input)
+
+
+class PipelineRunner:
+    """Pipeline runner builds a pipeline from different objects and exposes a run method to execute the pipeline.
+
+ Pipeline can be built from:
+ - A PipelineDefinition (`__init__` method)
+ - A PipelineConfig (`from_config` method)
+ - A config file (`from_config_file` method)
+ """
+
+ def __init__(
+ self,
+ pipeline_definition: PipelineDefinition,
+ config: AbstractPipelineConfig | None = None,
+ do_cleaning: bool = False,
+ ) -> None:
+ self.config = config
+ self.pipeline = Pipeline.from_definition(pipeline_definition)
+ self.run_params = pipeline_definition.get_run_params()
+ self.do_cleaning = do_cleaning
+
+ @classmethod
+ def from_config(
+ cls, config: AbstractPipelineConfig | dict[str, Any], do_cleaning: bool = False
+ ) -> Self:
+ wrapper = PipelineConfigWrapper.model_validate({"config": config})
+ return cls(wrapper.parse(), config=wrapper.config, do_cleaning=do_cleaning)
+
+ @classmethod
+ def from_config_file(cls, file_path: Union[str, Path]) -> Self:
+ if not isinstance(file_path, str):
+ file_path = str(file_path)
+ data = ConfigReader().read(file_path)
+ return cls.from_config(data, do_cleaning=True)
+
+ async def run(self, user_input: dict[str, Any]) -> PipelineResult:
+ if self.config:
+ run_param = deep_update(
+ self.run_params, self.config.get_run_params(user_input)
+ )
+ else:
+ run_param = deep_update(self.run_params, user_input)
+ logger.info(
+ f"PIPELINE_RUNNER: starting pipeline {self.pipeline} with run_params={run_param}"
+ )
+ result = await self.pipeline.run(data=run_param)
+ if self.do_cleaning:
+ await self.close()
+ return result
+
+ async def close(self) -> None:
+ logger.debug("PIPELINE_RUNNER: cleaning up (closing instantiated drivers...)")
+ if self.config:
+ await self.config.close()
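The `run` method relies on `deep_update` to merge user input into the pre-defined run parameters. A minimal sketch of that merge behavior (the real helper is imported from elsewhere in the package; this re-implementation is for illustration only):

```python
from typing import Any


def deep_update(base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
    """Recursively merge `overrides` into a copy of `base`.

    Nested dicts are merged key-by-key instead of being replaced wholesale,
    so user input can override a single component parameter without wiping
    the template's other defaults.
    """
    result = dict(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = deep_update(result[key], value)
        else:
            result[key] = value
    return result


# hypothetical component run params, for illustration
defaults = {"extractor": {"on_error": "IGNORE"}}
user_input = {"extractor": {"document_info": {"path": "doc.pdf"}}, "splitter": {"text": "hi"}}
merged = deep_update(defaults, user_input)
```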
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/__init__.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/__init__.py
new file mode 100644
index 00000000..125a1c87
--- /dev/null
+++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .simple_kg_builder import SimpleKGPipelineConfig
+
+__all__ = [
+ "SimpleKGPipelineConfig",
+]
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py
new file mode 100644
index 00000000..69fbc751
--- /dev/null
+++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py
@@ -0,0 +1,63 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Any, ClassVar, Optional
+
+from neo4j_graphrag.experimental.pipeline.config.pipeline_config import (
+ AbstractPipelineConfig,
+)
+from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition
+
+logger = logging.getLogger(__name__)
+
+
+class TemplatePipelineConfig(AbstractPipelineConfig):
+    """This class represents a 'template' pipeline, i.e. a pipeline with pre-defined
+    default components and fixed connections.
+
+    Component names are defined in the COMPONENTS class var. For each of them,
+    a `_get_<component_name>` method must be implemented that returns the proper
+    component. Optionally, `_get_run_params_for_<component_name>` can be implemented
+    to deal with parameters required by the component's run method and predefined on
+    template initialization.
+    """
+
+ COMPONENTS: ClassVar[list[str]] = []
+
+ def _get_component(self, component_name: str) -> Optional[ComponentDefinition]:
+ method = getattr(self, f"_get_{component_name}")
+ component = method()
+ if component is None:
+ return None
+ method = getattr(self, f"_get_run_params_for_{component_name}", None)
+ run_params = method() if method else {}
+ component_definition = ComponentDefinition(
+ name=component_name,
+ component=component,
+ run_params=run_params,
+ )
+ logger.debug(f"TEMPLATE_PIPELINE: resolved component {component_definition}")
+ return component_definition
+
+ def _get_components(self) -> list[ComponentDefinition]:
+ components = []
+ for component_name in self.COMPONENTS:
+ comp = self._get_component(component_name)
+ if comp is not None:
+ components.append(comp)
+ return components
+
+ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
+ return {}
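The `_get_<component_name>` naming convention above can be illustrated with a toy class (hypothetical names, detached from the library's `Component`/`ComponentDefinition` types):

```python
from typing import Any, Callable, Optional


class ToyTemplateConfig:
    """Toy illustration of the `_get_<component_name>` convention."""

    COMPONENTS = ["loader", "writer"]

    def _get_loader(self) -> Optional[str]:
        return "loader-instance"

    def _get_run_params_for_loader(self) -> dict[str, Any]:
        # parameters predefined at template initialization
        return {"path": "doc.txt"}

    def _get_writer(self) -> Optional[str]:
        return None  # returning None drops the component from the pipeline

    def components(self) -> list:
        resolved = []
        for name in self.COMPONENTS:
            component = getattr(self, f"_get_{name}")()
            if component is None:
                continue
            params_method: Optional[Callable[[], dict]] = getattr(
                self, f"_get_run_params_for_{name}", None
            )
            resolved.append((name, component, params_method() if params_method else {}))
        return resolved
```

A subclass only lists its component names and supplies the matching getters; the base class wires everything together by reflection, exactly as `_get_components` does above.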
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py
new file mode 100644
index 00000000..73edfd9a
--- /dev/null
+++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py
@@ -0,0 +1,228 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, ClassVar, Literal, Optional, Sequence, Union
+
+from pydantic import ConfigDict
+
+from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
+from neo4j_graphrag.experimental.components.entity_relation_extractor import (
+ EntityRelationExtractor,
+ LLMEntityRelationExtractor,
+ OnError,
+)
+from neo4j_graphrag.experimental.components.kg_writer import KGWriter, Neo4jWriter
+from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
+from neo4j_graphrag.experimental.components.resolver import (
+ EntityResolver,
+ SinglePropertyExactMatchResolver,
+)
+from neo4j_graphrag.experimental.components.schema import (
+ SchemaBuilder,
+ SchemaEntity,
+ SchemaRelation,
+)
+from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
+from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
+ FixedSizeSplitter,
+)
+from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
+from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType
+from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import (
+ TemplatePipelineConfig,
+)
+from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
+from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
+from neo4j_graphrag.experimental.pipeline.types import (
+ ConnectionDefinition,
+ EntityInputType,
+ RelationInputType,
+)
+from neo4j_graphrag.generation.prompts import ERExtractionTemplate
+
+
+class SimpleKGPipelineConfig(TemplatePipelineConfig):
+ COMPONENTS: ClassVar[list[str]] = [
+ "pdf_loader",
+ "splitter",
+ "chunk_embedder",
+ "schema",
+ "extractor",
+ "writer",
+ "resolver",
+ ]
+
+ template_: Literal[PipelineType.SIMPLE_KG_PIPELINE] = (
+ PipelineType.SIMPLE_KG_PIPELINE
+ )
+
+ from_pdf: bool = False
+ entities: Sequence[EntityInputType] = []
+ relations: Sequence[RelationInputType] = []
+ potential_schema: Optional[list[tuple[str, str, str]]] = None
+ on_error: OnError = OnError.IGNORE
+ prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
+ perform_entity_resolution: bool = True
+ lexical_graph_config: Optional[LexicalGraphConfig] = None
+ neo4j_database: Optional[str] = None
+
+ pdf_loader: Optional[ComponentType] = None
+ kg_writer: Optional[ComponentType] = None
+ text_splitter: Optional[ComponentType] = None
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ def _get_pdf_loader(self) -> Optional[PdfLoader]:
+ if not self.from_pdf:
+ return None
+ if self.pdf_loader:
+ return self.pdf_loader.parse(self._global_data) # type: ignore
+ return PdfLoader()
+
+ def _get_splitter(self) -> TextSplitter:
+ if self.text_splitter:
+ return self.text_splitter.parse(self._global_data) # type: ignore
+ return FixedSizeSplitter()
+
+ def _get_chunk_embedder(self) -> TextChunkEmbedder:
+ return TextChunkEmbedder(embedder=self.get_default_embedder())
+
+ def _get_schema(self) -> SchemaBuilder:
+ return SchemaBuilder()
+
+ def _get_run_params_for_schema(self) -> dict[str, Any]:
+ return {
+ "entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities],
+ "relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations],
+ "potential_schema": self.potential_schema,
+ }
+
+ def _get_extractor(self) -> EntityRelationExtractor:
+ return LLMEntityRelationExtractor(
+ llm=self.get_default_llm(),
+ prompt_template=self.prompt_template,
+ on_error=self.on_error,
+ )
+
+ def _get_writer(self) -> KGWriter:
+ if self.kg_writer:
+ return self.kg_writer.parse(self._global_data) # type: ignore
+ return Neo4jWriter(
+ driver=self.get_default_neo4j_driver(),
+ neo4j_database=self.neo4j_database,
+ )
+
+ def _get_resolver(self) -> Optional[EntityResolver]:
+ if not self.perform_entity_resolution:
+ return None
+ return SinglePropertyExactMatchResolver(
+ driver=self.get_default_neo4j_driver(),
+ neo4j_database=self.neo4j_database,
+ )
+
+ def _get_connections(self) -> list[ConnectionDefinition]:
+ connections = []
+ if self.from_pdf:
+ connections.append(
+ ConnectionDefinition(
+ start="pdf_loader",
+ end="splitter",
+ input_config={"text": "pdf_loader.text"},
+ )
+ )
+ connections.append(
+ ConnectionDefinition(
+ start="schema",
+ end="extractor",
+ input_config={
+ "schema": "schema",
+ "document_info": "pdf_loader.document_info",
+ },
+ )
+ )
+ else:
+ connections.append(
+ ConnectionDefinition(
+ start="schema",
+ end="extractor",
+ input_config={
+ "schema": "schema",
+ },
+ )
+ )
+ connections.append(
+ ConnectionDefinition(
+ start="splitter",
+ end="chunk_embedder",
+ input_config={
+ "text_chunks": "splitter",
+ },
+ )
+ )
+ connections.append(
+ ConnectionDefinition(
+ start="chunk_embedder",
+ end="extractor",
+ input_config={
+ "chunks": "chunk_embedder",
+ },
+ )
+ )
+ connections.append(
+ ConnectionDefinition(
+ start="extractor",
+ end="writer",
+ input_config={
+ "graph": "extractor",
+ },
+ )
+ )
+
+ if self.perform_entity_resolution:
+ connections.append(
+ ConnectionDefinition(
+ start="writer",
+ end="resolver",
+ input_config={},
+ )
+ )
+
+ return connections
+
+ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
+ run_params = {}
+ if self.lexical_graph_config:
+ run_params["extractor"] = {
+ "lexical_graph_config": self.lexical_graph_config
+ }
+ text = user_input.get("text")
+ file_path = user_input.get("file_path")
+ if not ((text is None) ^ (file_path is None)):
+ # exactly one of 'text' or 'file_path' must be set
+ raise PipelineDefinitionError(
+ "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument."
+ )
+ if self.from_pdf:
+ if not file_path:
+ raise PipelineDefinitionError(
+ "Expected 'file_path' argument when 'from_pdf' is True."
+ )
+ run_params["pdf_loader"] = {"filepath": file_path}
+ else:
+ if not text:
+ raise PipelineDefinitionError(
+ "Expected 'text' argument when 'from_pdf' is False."
+ )
+ run_params["splitter"] = {"text": text}
+ return run_params
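The mutual-exclusion check in `get_run_params` uses a boolean XOR over the two inputs. A standalone sketch of that predicate (hypothetical function name):

```python
from typing import Optional


def has_exactly_one_input(text: Optional[str], file_path: Optional[str]) -> bool:
    """True only when exactly one of `text` or `file_path` is provided.

    `^` on the two `is None` booleans rejects both the "neither given" and
    the "both given" cases in a single expression.
    """
    return (text is None) ^ (file_path is None)
```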
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/types.py b/src/neo4j_graphrag/experimental/pipeline/config/types.py
new file mode 100644
index 00000000..48f91f48
--- /dev/null
+++ b/src/neo4j_graphrag/experimental/pipeline/config/types.py
@@ -0,0 +1,26 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import enum
+
+
+class PipelineType(str, enum.Enum):
+    """Pipeline type:
+
+    NONE => Pipeline
+    SIMPLE_KG_PIPELINE => SimpleKGPipeline
+    """
+
+ NONE = "none"
+ SIMPLE_KG_PIPELINE = "SimpleKGPipeline"
diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py
index 58868cb3..3fca0215 100644
--- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py
+++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py
@@ -18,30 +18,16 @@
from typing import Any, List, Optional, Sequence, Union
import neo4j
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import ValidationError
from neo4j_graphrag.embeddings import Embedder
-from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
-from neo4j_graphrag.experimental.components.entity_relation_extractor import (
- LLMEntityRelationExtractor,
- OnError,
-)
-from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
-from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
-from neo4j_graphrag.experimental.components.resolver import (
- SinglePropertyExactMatchResolver,
-)
-from neo4j_graphrag.experimental.components.schema import (
- SchemaBuilder,
- SchemaEntity,
- SchemaRelation,
-)
-from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
- FixedSizeSplitter,
-)
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
+from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner
+from neo4j_graphrag.experimental.pipeline.config.template_pipeline import (
+ SimpleKGPipelineConfig,
+)
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
-from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult
+from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types import (
EntityInputType,
RelationInputType,
@@ -50,26 +36,6 @@
from neo4j_graphrag.llm.base import LLMInterface
-class SimpleKGPipelineConfig(BaseModel):
- llm: LLMInterface
- driver: neo4j.Driver
- from_pdf: bool
- embedder: Embedder
- entities: list[SchemaEntity] = Field(default_factory=list)
- relations: list[SchemaRelation] = Field(default_factory=list)
- potential_schema: list[tuple[str, str, str]] = Field(default_factory=list)
- pdf_loader: Any = None
- kg_writer: Any = None
- text_splitter: Any = None
- on_error: OnError = OnError.RAISE
- prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
- perform_entity_resolution: bool = True
- lexical_graph_config: Optional[LexicalGraphConfig] = None
- neo4j_database: Optional[str] = None
-
- model_config = ConfigDict(arbitrary_types_allowed=True)
-
-
class SimpleKGPipeline:
"""
A class to simplify the process of building a knowledge graph from text documents.
@@ -120,133 +86,29 @@ def __init__(
lexical_graph_config: Optional[LexicalGraphConfig] = None,
neo4j_database: Optional[str] = None,
):
- self.potential_schema = potential_schema or []
- self.entities = [self.to_schema_entity(e) for e in entities or []]
- self.relations = [self.to_schema_relation(r) for r in relations or []]
-
try:
- on_error_enum = OnError(on_error)
- except ValueError:
- raise PipelineDefinitionError(
- f"Invalid value for on_error: {on_error}. Expected one of {OnError.possible_values()}."
- )
-
- config = SimpleKGPipelineConfig(
- llm=llm,
- driver=driver,
- entities=self.entities,
- relations=self.relations,
- potential_schema=self.potential_schema,
- from_pdf=from_pdf,
- pdf_loader=pdf_loader,
- kg_writer=kg_writer,
- text_splitter=text_splitter,
- on_error=on_error_enum,
- prompt_template=prompt_template,
- embedder=embedder,
- perform_entity_resolution=perform_entity_resolution,
- lexical_graph_config=lexical_graph_config,
- neo4j_database=neo4j_database,
- )
-
- self.from_pdf = config.from_pdf
- self.llm = config.llm
- self.driver = config.driver
- self.embedder = config.embedder
- self.text_splitter = config.text_splitter or FixedSizeSplitter()
- self.on_error = config.on_error
- self.pdf_loader = config.pdf_loader if pdf_loader is not None else PdfLoader()
- self.kg_writer = (
- config.kg_writer
- if kg_writer is not None
- else Neo4jWriter(driver, neo4j_database=config.neo4j_database)
- )
- self.prompt_template = config.prompt_template
- self.perform_entity_resolution = config.perform_entity_resolution
- self.lexical_graph_config = config.lexical_graph_config
- self.neo4j_database = config.neo4j_database
-
- self.pipeline = self._build_pipeline()
-
- @staticmethod
- def to_schema_entity(entity: EntityInputType) -> SchemaEntity:
- if isinstance(entity, dict):
- return SchemaEntity.model_validate(entity)
- return SchemaEntity(label=entity)
-
- @staticmethod
- def to_schema_relation(relation: RelationInputType) -> SchemaRelation:
- if isinstance(relation, dict):
- return SchemaRelation.model_validate(relation)
- return SchemaRelation(label=relation)
-
- def _build_pipeline(self) -> Pipeline:
- pipe = Pipeline()
-
- pipe.add_component(self.text_splitter, "splitter")
- pipe.add_component(SchemaBuilder(), "schema")
- pipe.add_component(
- LLMEntityRelationExtractor(
- llm=self.llm,
- on_error=self.on_error,
- prompt_template=self.prompt_template,
- ),
- "extractor",
- )
- pipe.add_component(TextChunkEmbedder(embedder=self.embedder), "chunk_embedder")
- pipe.add_component(self.kg_writer, "writer")
-
- if self.from_pdf:
- pipe.add_component(self.pdf_loader, "pdf_loader")
-
- pipe.connect(
- "pdf_loader",
- "splitter",
- input_config={"text": "pdf_loader.text"},
- )
-
- pipe.connect(
- "schema",
- "extractor",
- input_config={
- "schema": "schema",
- "document_info": "pdf_loader.document_info",
- },
- )
- else:
- pipe.connect(
- "schema",
- "extractor",
- input_config={
- "schema": "schema",
- },
- )
-
- pipe.connect(
- "splitter", "chunk_embedder", input_config={"text_chunks": "splitter"}
- )
-
- pipe.connect(
- "chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"}
- )
-
- # Connect extractor to writer
- pipe.connect(
- "extractor",
- "writer",
- input_config={"graph": "extractor"},
- )
-
- if self.perform_entity_resolution:
- pipe.add_component(
- SinglePropertyExactMatchResolver(
- self.driver, neo4j_database=self.neo4j_database
- ),
- "resolver",
+ config = SimpleKGPipelineConfig(
+ # argument types are fixed in the Config object
+ llm_config=llm, # type: ignore[arg-type]
+ neo4j_config=driver, # type: ignore[arg-type]
+ embedder_config=embedder, # type: ignore[arg-type]
+ entities=entities or [],
+ relations=relations or [],
+ potential_schema=potential_schema,
+ from_pdf=from_pdf,
+ pdf_loader=pdf_loader,
+ kg_writer=kg_writer,
+ text_splitter=text_splitter,
+ on_error=on_error, # type: ignore[arg-type]
+ prompt_template=prompt_template,
+ perform_entity_resolution=perform_entity_resolution,
+ lexical_graph_config=lexical_graph_config,
+ neo4j_database=neo4j_database,
)
- pipe.connect("writer", "resolver", {})
+ except ValidationError as e:
+ raise PipelineDefinitionError() from e
- return pipe
+ self.runner = PipelineRunner.from_config(config)
async def run_async(
self, file_path: Optional[str] = None, text: Optional[str] = None
@@ -261,39 +123,4 @@ async def run_async(
Returns:
PipelineResult: The result of the pipeline execution.
"""
- pipe_inputs = self._prepare_inputs(file_path=file_path, text=text)
- return await self.pipeline.run(pipe_inputs)
-
- def _prepare_inputs(
- self, file_path: Optional[str], text: Optional[str]
- ) -> dict[str, Any]:
- if self.from_pdf:
- if file_path is None or text is not None:
- raise PipelineDefinitionError(
- "Expected 'file_path' argument when 'from_pdf' is True."
- )
- else:
- if text is None or file_path is not None:
- raise PipelineDefinitionError(
- "Expected 'text' argument when 'from_pdf' is False."
- )
-
- pipe_inputs: dict[str, Any] = {
- "schema": {
- "entities": self.entities,
- "relations": self.relations,
- "potential_schema": self.potential_schema,
- },
- }
-
- if self.from_pdf:
- pipe_inputs["pdf_loader"] = {"filepath": file_path}
- else:
- pipe_inputs["splitter"] = {"text": text}
-
- if self.lexical_graph_config:
- pipe_inputs["extractor"] = {
- "lexical_graph_config": self.lexical_graph_config
- }
-
- return pipe_inputs
+ return await self.runner.run({"file_path": file_path, "text": text})
diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py
index 5edc2783..e3ded494 100644
--- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py
+++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py
@@ -44,9 +44,9 @@
)
from neo4j_graphrag.experimental.pipeline.stores import InMemoryStore, ResultStore
from neo4j_graphrag.experimental.pipeline.types import (
- ComponentConfig,
- ConnectionConfig,
- PipelineConfig,
+ ComponentDefinition,
+ ConnectionDefinition,
+ PipelineDefinition,
)
logger = logging.getLogger(__name__)
@@ -349,16 +349,34 @@ def __init__(self, store: Optional[ResultStore] = None) -> None:
@classmethod
def from_template(
- cls, pipeline_template: PipelineConfig, store: Optional[ResultStore] = None
+ cls, pipeline_template: PipelineDefinition, store: Optional[ResultStore] = None
) -> Pipeline:
- """Create a Pipeline from a pydantic model defining the components and their connections"""
+ warnings.warn(
+ "from_template is deprecated, use from_definition instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return cls.from_definition(pipeline_template, store)
+
+ @classmethod
+ def from_definition(
+ cls,
+ pipeline_definition: PipelineDefinition,
+ store: Optional[ResultStore] = None,
+ ) -> Pipeline:
+ """Create a Pipeline from a pydantic model defining the components and their connections
+
+ Args:
+ pipeline_definition (PipelineDefinition): An object defining components and how they are connected to each other.
+ store (Optional[ResultStore]): Where the results are stored. By default, uses the InMemoryStore.
+ """
pipeline = Pipeline(store=store)
- for component in pipeline_template.components:
+ for component in pipeline_definition.components:
pipeline.add_component(
component.component,
component.name,
)
- for edge in pipeline_template.connections:
+ for edge in pipeline_definition.connections:
pipeline_edge = PipelineEdge(
edge.start, edge.end, data={"input_config": edge.input_config}
)
@@ -369,18 +387,18 @@ def show_as_dict(self) -> dict[str, Any]:
component_config = []
for name, task in self._nodes.items():
component_config.append(
- ComponentConfig(name=name, component=task.component)
+ ComponentDefinition(name=name, component=task.component)
)
connection_config = []
for edge in self._edges:
connection_config.append(
- ConnectionConfig(
+ ConnectionDefinition(
start=edge.start,
end=edge.end,
input_config=edge.data["input_config"] if edge.data else {},
)
)
- pipeline_config = PipelineConfig(
+ pipeline_config = PipelineDefinition(
components=component_config, connections=connection_config
)
return pipeline_config.model_dump()
diff --git a/src/neo4j_graphrag/experimental/pipeline/types.py b/src/neo4j_graphrag/experimental/pipeline/types.py
index ebdf141d..47aafd8b 100644
--- a/src/neo4j_graphrag/experimental/pipeline/types.py
+++ b/src/neo4j_graphrag/experimental/pipeline/types.py
@@ -14,29 +14,36 @@
# limitations under the License.
from __future__ import annotations
-from typing import Union
+from collections import defaultdict
+from typing import Any, Union
from pydantic import BaseModel, ConfigDict
from neo4j_graphrag.experimental.pipeline.component import Component
-class ComponentConfig(BaseModel):
+class ComponentDefinition(BaseModel):
name: str
component: Component
+ run_params: dict[str, Any] = {}
model_config = ConfigDict(arbitrary_types_allowed=True)
-class ConnectionConfig(BaseModel):
+class ConnectionDefinition(BaseModel):
start: str
end: str
input_config: dict[str, str]
-class PipelineConfig(BaseModel):
- components: list[ComponentConfig]
- connections: list[ConnectionConfig]
+class PipelineDefinition(BaseModel):
+ components: list[ComponentDefinition]
+ connections: list[ConnectionDefinition]
+
+ def get_run_params(self) -> defaultdict[str, dict[str, Any]]:
+ return defaultdict(
+ dict, {c.name: c.run_params for c in self.components if c.run_params}
+ )
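Returning a `defaultdict(dict, ...)` rather than a plain dict means callers can read or merge into any component's run params without a `KeyError`. A small sketch of that behavior (hypothetical component names):

```python
from collections import defaultdict

# mirror of what PipelineDefinition.get_run_params builds: only components
# with non-empty run_params contribute an entry up front
components = [("splitter", {"chunk_size": 100}), ("writer", {})]
run_params = defaultdict(dict, {name: params for name, params in components if params})
```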
EntityInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]]
diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py
index 70ca0193..42c21cf8 100644
--- a/tests/e2e/conftest.py
+++ b/tests/e2e/conftest.py
@@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations
+import os
import random
import string
import uuid
@@ -33,6 +34,8 @@
from ..e2e.utils import EMBEDDING_BIOLOGY
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+
@pytest.fixture(scope="module")
def driver() -> Generator[Any, Any, Any]:
@@ -48,6 +51,12 @@ def llm() -> MagicMock:
return MagicMock(spec=LLMInterface)
+@pytest.fixture
+def embedder() -> Embedder:
+ embedder = MagicMock(spec=Embedder)
+ return embedder
+
+
class RandomEmbedder(Embedder):
def embed_query(self, text: str) -> list[float]:
return [random.random() for _ in range(1536)]
@@ -75,6 +84,31 @@ def retriever_mock() -> MagicMock:
return MagicMock(spec=VectorRetriever)
+@pytest.fixture
+def harry_potter_text() -> str:
+ with open(os.path.join(BASE_DIR, "data/documents/harry_potter.txt"), "r") as f:
+ text = f.read()
+ return text
+
+
+@pytest.fixture
+def harry_potter_text_part1() -> str:
+ with open(
+ os.path.join(BASE_DIR, "data/documents/harry_potter_part1.txt"), "r"
+ ) as f:
+ text = f.read()
+ return text
+
+
+@pytest.fixture
+def harry_potter_text_part2() -> str:
+ with open(
+ os.path.join(BASE_DIR, "data/documents/harry_potter_part2.txt"), "r"
+ ) as f:
+ text = f.read()
+ return text
+
+
@pytest.fixture(scope="module")
def setup_neo4j_for_retrieval(driver: Driver) -> None:
vector_index_name = "vector-index-name"
diff --git a/tests/e2e/data/config_files/pipeline_config.json b/tests/e2e/data/config_files/pipeline_config.json
new file mode 100644
index 00000000..fe36624d
--- /dev/null
+++ b/tests/e2e/data/config_files/pipeline_config.json
@@ -0,0 +1,72 @@
+{
+ "version_": "1",
+ "template_": "none",
+ "name": "",
+ "neo4j_config": {
+ "params_": {
+ "uri": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_URI"
+ },
+ "user": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_USER"
+ },
+ "password": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_PASSWORD"
+ }
+ }
+ },
+ "extras": {
+ "database": "neo4j"
+ },
+ "component_config": {
+ "splitter": {
+ "class_": "text_splitters.fixed_size_splitter.FixedSizeSplitter",
+ "params_": {
+ "chunk_size": 100,
+ "chunk_overlap": 10
+ }
+ },
+ "builder": {
+ "class_": "lexical_graph.LexicalGraphBuilder",
+ "params_": {
+ "config": {
+ "chunk_node_label": "TextPart"
+ }
+ }
+ },
+ "writer": {
+ "name_": "writer",
+ "class_": "kg_writer.Neo4jWriter",
+ "params_": {
+ "driver": {
+ "resolver_": "CONFIG_KEY",
+ "key_": "neo4j_config.default"
+ },
+ "neo4j_database": {
+ "resolver_": "CONFIG_KEY",
+ "key_": "extras.database"
+ }
+ }
+ }
+ },
+ "connection_config": [
+ {
+ "start": "splitter",
+ "end": "builder",
+ "input_config": {
+ "text_chunks": "splitter"
+ }
+ },
+ {
+ "start": "builder",
+ "end": "writer",
+ "input_config": {
+ "graph": "builder.graph",
+ "lexical_graph_config": "builder.config"
+ }
+ }
+ ]
+}
diff --git a/tests/e2e/data/config_files/pipeline_config.yaml b/tests/e2e/data/config_files/pipeline_config.yaml
new file mode 100644
index 00000000..87ac905e
--- /dev/null
+++ b/tests/e2e/data/config_files/pipeline_config.yaml
@@ -0,0 +1,45 @@
+version_: "1"
+template_: none
+neo4j_config:
+ params_:
+ uri:
+ resolver_: ENV
+ var_: NEO4J_URI
+ user:
+ resolver_: ENV
+ var_: NEO4J_USER
+ password:
+ resolver_: ENV
+ var_: NEO4J_PASSWORD
+extras:
+ database: neo4j
+component_config:
+ splitter:
+ class_: text_splitters.fixed_size_splitter.FixedSizeSplitter
+ params_:
+ chunk_size: 100
+ chunk_overlap: 10
+ builder:
+ class_: lexical_graph.LexicalGraphBuilder
+ params_:
+ config:
+ chunk_node_label: TextPart
+ writer:
+ class_: kg_writer.Neo4jWriter
+ params_:
+ driver:
+ resolver_: CONFIG_KEY
+ key_: neo4j_config.default
+ neo4j_database:
+ resolver_: CONFIG_KEY
+ key_: extras.database
+connection_config:
+ - start: splitter
+ end: builder
+ input_config:
+ text_chunks: splitter
+ - start: builder
+ end: writer
+ input_config:
+ graph: builder.graph
+ lexical_graph_config: builder.config
diff --git a/tests/e2e/data/config_files/simple_kg_pipeline_config.json b/tests/e2e/data/config_files/simple_kg_pipeline_config.json
new file mode 100644
index 00000000..c2d629fb
--- /dev/null
+++ b/tests/e2e/data/config_files/simple_kg_pipeline_config.json
@@ -0,0 +1,64 @@
+{
+ "version_": "1",
+ "template_": "SimpleKGPipeline",
+ "neo4j_config": {
+ "params_": {
+ "uri": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_URI"
+ },
+ "user": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_USER"
+ },
+ "password": {
+ "resolver_": "ENV",
+ "var_": "NEO4J_PASSWORD"
+ }
+ }
+ },
+ "llm_config": {
+ "class_": "OpenAILLM",
+ "params_": {
+ "api_key": {
+ "resolver_": "ENV",
+ "var_": "OPENAI_API_KEY"
+ },
+ "model_name": "gpt-4o",
+ "model_params": {
+ "temperature": 0,
+ "max_tokens": 2000,
+ "response_format": {"type": "json_object"}
+ }
+ }
+ },
+ "embedder_config": {
+ "class_": "OpenAIEmbeddings",
+ "params_": {
+ "api_key": {
+ "resolver_": "ENV",
+ "var_": "OPENAI_API_KEY"
+ }
+ }
+ },
+ "from_pdf": true,
+ "entities": [
+ "Person",
+ "Organization",
+ "Horcrux",
+ "Location"
+ ],
+ "relations": [
+ "SITUATED_AT",
+ "INTERACTS",
+ "OWNS",
+ "LED_BY"
+ ],
+ "potential_schema": [
+ ["Person", "SITUATED_AT", "Location"],
+ ["Person", "INTERACTS", "Person"],
+ ["Person", "OWNS", "Horcrux"],
+ ["Organization", "LED_BY", "Person"]
+ ],
+ "perform_entity_resolution": true
+}
diff --git a/tests/e2e/data/config_files/simple_kg_pipeline_config.yaml b/tests/e2e/data/config_files/simple_kg_pipeline_config.yaml
new file mode 100644
index 00000000..abcf9632
--- /dev/null
+++ b/tests/e2e/data/config_files/simple_kg_pipeline_config.yaml
@@ -0,0 +1,50 @@
+version_: "1"
+template_: SimpleKGPipeline
+neo4j_config:
+ params_:
+ uri:
+ resolver_: ENV
+ var_: NEO4J_URI
+ user:
+ resolver_: ENV
+ var_: NEO4J_USER
+ password:
+ resolver_: ENV
+ var_: NEO4J_PASSWORD
+llm_config:
+ class_: OpenAILLM
+ params_:
+ api_key:
+ resolver_: ENV
+ var_: OPENAI_API_KEY
+ model_name: gpt-4o
+ model_params:
+ temperature: 0
+ max_tokens: 2000
+ response_format:
+ type: json_object
+embedder_config:
+ class_: OpenAIEmbeddings
+ params_:
+ api_key:
+ resolver_: ENV
+ var_: OPENAI_API_KEY
+from_pdf: true
+entities:
+ - Person
+ - Organization
+ - Location
+ - Horcrux
+relations:
+ - SITUATED_AT
+ - INTERACTS
+ - OWNS
+ - LED_BY
+potential_schema:
+ - ["Person", "SITUATED_AT", "Location"]
+ - ["Person", "INTERACTS", "Person"]
+ - ["Person", "OWNS", "Horcrux"]
+ - ["Organization", "LED_BY", "Person"]
+text_splitter:
+ class_: text_splitters.fixed_size_splitter.FixedSizeSplitter
+perform_entity_resolution: true
diff --git a/tests/e2e/data/documents/harry_potter.pdf b/tests/e2e/data/documents/harry_potter.pdf
new file mode 100644
index 00000000..49f0f6bd
Binary files /dev/null and b/tests/e2e/data/documents/harry_potter.pdf differ
diff --git a/tests/e2e/data/harry_potter.txt b/tests/e2e/data/documents/harry_potter.txt
similarity index 100%
rename from tests/e2e/data/harry_potter.txt
rename to tests/e2e/data/documents/harry_potter.txt
diff --git a/tests/e2e/experimental/__init__.py b/tests/e2e/experimental/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/e2e/experimental/pipeline/__init__.py b/tests/e2e/experimental/pipeline/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/e2e/experimental/pipeline/config/__init__.py b/tests/e2e/experimental/pipeline/config/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/e2e/experimental/pipeline/config/test_pipeline_runner_e2e.py b/tests/e2e/experimental/pipeline/config/test_pipeline_runner_e2e.py
new file mode 100644
index 00000000..9122a935
--- /dev/null
+++ b/tests/e2e/experimental/pipeline/config/test_pipeline_runner_e2e.py
@@ -0,0 +1,208 @@
+import os
+from typing import Any
+from unittest.mock import AsyncMock, Mock, patch
+
+import neo4j
+import pytest
+from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner
+from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
+from neo4j_graphrag.llm import LLMResponse
+
+
+@pytest.fixture(scope="function", autouse=True)
+def clear_db(driver: neo4j.Driver) -> Any:
+ driver.execute_query("MATCH (n) DETACH DELETE n")
+ yield
+
+
+@pytest.mark.asyncio
+async def test_pipeline_from_json_config(harry_potter_text: str, driver: neo4j.Driver) -> None:
+ os.environ["NEO4J_URI"] = "neo4j://localhost:7687"
+ os.environ["NEO4J_USER"] = "neo4j"
+ os.environ["NEO4J_PASSWORD"] = "password"
+
+ runner = PipelineRunner.from_config_file(
+ "tests/e2e/data/config_files/pipeline_config.json"
+ )
+ res = await runner.run({"splitter": {"text": harry_potter_text}})
+ assert isinstance(res, PipelineResult)
+ assert res.result["writer"]["metadata"] == {
+ "node_count": 11,
+ "relationship_count": 10,
+ }
+ nodes = driver.execute_query("MATCH (n) RETURN n")
+ assert len(nodes.records) == 11
+
+
+@pytest.mark.asyncio
+async def test_pipeline_from_yaml_config(harry_potter_text: str, driver: neo4j.Driver) -> None:
+ os.environ["NEO4J_URI"] = "neo4j://localhost:7687"
+ os.environ["NEO4J_USER"] = "neo4j"
+ os.environ["NEO4J_PASSWORD"] = "password"
+
+ runner = PipelineRunner.from_config_file(
+ "tests/e2e/data/config_files/pipeline_config.yaml"
+ )
+ res = await runner.run({"splitter": {"text": harry_potter_text}})
+ assert isinstance(res, PipelineResult)
+ assert res.result["writer"]["metadata"] == {
+ "node_count": 11,
+ "relationship_count": 10,
+ }
+
+ nodes = driver.execute_query("MATCH (n) RETURN n")
+ assert len(nodes.records) == 11
+
+
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.runner.SimpleKGPipelineConfig.get_default_embedder"
+)
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.runner.SimpleKGPipelineConfig.get_default_llm"
+)
+@pytest.mark.asyncio
+async def test_simple_kg_pipeline_from_json_config(
+    mock_llm: Mock, mock_embedder: Mock, harry_potter_text: str, driver: neo4j.Driver
+) -> None:
+ mock_llm.return_value.ainvoke = AsyncMock(
+ side_effect=[
+ LLMResponse(
+ content="""{
+ "nodes": [
+ {
+ "id": "0",
+ "label": "Person",
+ "properties": {
+ "name": "Harry Potter"
+ }
+ },
+ {
+ "id": "1",
+ "label": "Person",
+ "properties": {
+ "name": "Alastor Mad-Eye Moody"
+ }
+ },
+ {
+ "id": "2",
+ "label": "Organization",
+ "properties": {
+ "name": "The Order of the Phoenix"
+ }
+ }
+ ],
+ "relationships": [
+ {
+ "type": "KNOWS",
+ "start_node_id": "0",
+ "end_node_id": "1"
+ },
+ {
+ "type": "LED_BY",
+ "start_node_id": "2",
+ "end_node_id": "1"
+ }
+ ]
+ }"""
+ ),
+ ]
+ )
+ mock_embedder.return_value.embed_query.side_effect = [
+ [1.0, 2.0],
+ ]
+
+ os.environ["NEO4J_URI"] = "neo4j://localhost:7687"
+ os.environ["NEO4J_USER"] = "neo4j"
+ os.environ["NEO4J_PASSWORD"] = "password"
+ os.environ["OPENAI_API_KEY"] = "sk-my-secret-key"
+
+ runner = PipelineRunner.from_config_file(
+ "tests/e2e/data/config_files/simple_kg_pipeline_config.json"
+ )
+ res = await runner.run({"file_path": "tests/e2e/data/documents/harry_potter.pdf"})
+ assert isinstance(res, PipelineResult)
+ assert res.result["resolver"] == {
+ "number_of_nodes_to_resolve": 3,
+ "number_of_created_nodes": 3,
+ }
+ nodes = driver.execute_query("MATCH (n) RETURN n")
+ # 1 chunk + 1 document + 3 nodes
+ assert len(nodes.records) == 5
+
+
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.runner.SimpleKGPipelineConfig.get_default_embedder"
+)
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.runner.SimpleKGPipelineConfig.get_default_llm"
+)
+@pytest.mark.asyncio
+async def test_simple_kg_pipeline_from_yaml_config(
+    mock_llm: Mock, mock_embedder: Mock, harry_potter_text: str, driver: neo4j.Driver
+) -> None:
+ mock_llm.return_value.ainvoke = AsyncMock(
+ side_effect=[
+ LLMResponse(
+ content="""{
+ "nodes": [
+ {
+ "id": "0",
+ "label": "Person",
+ "properties": {
+ "name": "Harry Potter"
+ }
+ },
+ {
+ "id": "1",
+ "label": "Person",
+ "properties": {
+ "name": "Alastor Mad-Eye Moody"
+ }
+ },
+ {
+ "id": "2",
+ "label": "Organization",
+ "properties": {
+ "name": "The Order of the Phoenix"
+ }
+ }
+ ],
+ "relationships": [
+ {
+ "type": "KNOWS",
+ "start_node_id": "0",
+ "end_node_id": "1"
+ },
+ {
+ "type": "LED_BY",
+ "start_node_id": "2",
+ "end_node_id": "1"
+ }
+ ]
+ }"""
+ ),
+ ]
+ )
+ mock_embedder.return_value.embed_query.side_effect = [
+ [1.0, 2.0],
+ ]
+
+ os.environ["NEO4J_URI"] = "neo4j://localhost:7687"
+ os.environ["NEO4J_USER"] = "neo4j"
+ os.environ["NEO4J_PASSWORD"] = "password"
+ os.environ["OPENAI_API_KEY"] = "sk-my-secret-key"
+
+ runner = PipelineRunner.from_config_file(
+ "tests/e2e/data/config_files/simple_kg_pipeline_config.yaml"
+ )
+ res = await runner.run({"file_path": "tests/e2e/data/documents/harry_potter.pdf"})
+ assert isinstance(res, PipelineResult)
+ assert res.result["resolver"] == {
+ "number_of_nodes_to_resolve": 3,
+ "number_of_created_nodes": 3,
+ }
+ nodes = driver.execute_query("MATCH (n) RETURN n")
+ # 1 chunk + 1 document + 3 nodes
+ assert len(nodes.records) == 5
diff --git a/tests/e2e/test_kg_builder_pipeline_e2e.py b/tests/e2e/test_kg_builder_pipeline_e2e.py
index 1713098b..bf74470c 100644
--- a/tests/e2e/test_kg_builder_pipeline_e2e.py
+++ b/tests/e2e/test_kg_builder_pipeline_e2e.py
@@ -124,31 +124,6 @@ def kg_builder_pipeline(
return pipe
-@pytest.fixture
-def harry_potter_text() -> str:
- with open(os.path.join(BASE_DIR, "data/harry_potter.txt"), "r") as f:
- text = f.read()
- return text
-
-
-@pytest.fixture
-def harry_potter_text_part1() -> str:
- with open(
- os.path.join(BASE_DIR, "data/documents/harry_potter_part1.txt"), "r"
- ) as f:
- text = f.read()
- return text
-
-
-@pytest.fixture
-def harry_potter_text_part2() -> str:
- with open(
- os.path.join(BASE_DIR, "data/documents/harry_potter_part2.txt"), "r"
- ) as f:
- text = f.read()
- return text
-
-
@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_neo4j_for_kg_construction")
async def test_pipeline_builder_happy_path(
diff --git a/tests/e2e/test_simplekgpipeline_e2e.py b/tests/e2e/test_simplekgpipeline_e2e.py
index def867f5..4d3059c7 100644
--- a/tests/e2e/test_simplekgpipeline_e2e.py
+++ b/tests/e2e/test_simplekgpipeline_e2e.py
@@ -14,36 +14,13 @@
# limitations under the License.
from __future__ import annotations
-import os
from unittest.mock import MagicMock
import neo4j
import pytest
-from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
-from neo4j_graphrag.llm import LLMInterface, LLMResponse
-
-BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-
-
-@pytest.fixture
-def llm() -> LLMInterface:
- llm = MagicMock(spec=LLMInterface)
- return llm
-
-
-@pytest.fixture
-def embedder() -> Embedder:
- embedder = MagicMock(spec=Embedder)
- return embedder
-
-
-@pytest.fixture
-def harry_potter_text() -> str:
- with open(os.path.join(BASE_DIR, "data/harry_potter.txt"), "r") as f:
- text = f.read()
- return text
+from neo4j_graphrag.llm import LLMResponse
@pytest.mark.asyncio
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 829cad23..5069ab6e 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -19,6 +19,7 @@
import neo4j
import pytest
from neo4j_graphrag.embeddings.base import Embedder
+from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.retrievers import (
HybridRetriever,
@@ -98,3 +99,8 @@ def format_function(record: neo4j.Record) -> RetrieverResultItem:
)
return format_function
+
+
+@pytest.fixture(scope="function")
+def component() -> MagicMock:
+ return MagicMock(spec=Component)
diff --git a/tests/unit/experimental/pipeline/config/__init__.py b/tests/unit/experimental/pipeline/config/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/__init__.py b/tests/unit/experimental/pipeline/config/template_pipeline/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_base.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_base.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py
new file mode 100644
index 00000000..e8d12095
--- /dev/null
+++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py
@@ -0,0 +1,301 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock, patch
+
+import neo4j
+import pytest
+from neo4j_graphrag.embeddings import Embedder
+from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
+from neo4j_graphrag.experimental.components.entity_relation_extractor import (
+ LLMEntityRelationExtractor,
+ OnError,
+)
+from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
+from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
+from neo4j_graphrag.experimental.components.schema import (
+ SchemaBuilder,
+ SchemaEntity,
+ SchemaRelation,
+)
+from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
+ FixedSizeSplitter,
+)
+from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentConfig
+from neo4j_graphrag.experimental.pipeline.config.template_pipeline import (
+ SimpleKGPipelineConfig,
+)
+from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
+from neo4j_graphrag.generation.prompts import ERExtractionTemplate
+from neo4j_graphrag.llm import LLMInterface
+
+
+def test_simple_kg_pipeline_config_pdf_loader_from_pdf_is_false() -> None:
+ config = SimpleKGPipelineConfig(from_pdf=False)
+ assert config._get_pdf_loader() is None
+
+
+def test_simple_kg_pipeline_config_pdf_loader_from_pdf_is_true() -> None:
+ config = SimpleKGPipelineConfig(from_pdf=True)
+ assert isinstance(config._get_pdf_loader(), PdfLoader)
+
+
+def test_simple_kg_pipeline_config_pdf_loader_from_pdf_is_true_class_overwrite() -> (
+ None
+):
+ my_pdf_loader = PdfLoader()
+ config = SimpleKGPipelineConfig(from_pdf=True, pdf_loader=my_pdf_loader) # type: ignore
+ assert config._get_pdf_loader() == my_pdf_loader
+
+
+def test_simple_kg_pipeline_config_pdf_loader_class_overwrite_but_from_pdf_is_false() -> (
+ None
+):
+ my_pdf_loader = PdfLoader()
+ config = SimpleKGPipelineConfig(from_pdf=False, pdf_loader=my_pdf_loader) # type: ignore
+ assert config._get_pdf_loader() is None
+
+
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse")
+def test_simple_kg_pipeline_config_pdf_loader_from_pdf_is_true_class_overwrite_from_config(
+ mock_component_parse: Mock,
+) -> None:
+ my_pdf_loader_config = ComponentConfig(
+ class_="",
+ )
+ my_pdf_loader = PdfLoader()
+ mock_component_parse.return_value = my_pdf_loader
+ config = SimpleKGPipelineConfig(
+ from_pdf=True,
+ pdf_loader=my_pdf_loader_config, # type: ignore
+ )
+ assert config._get_pdf_loader() == my_pdf_loader
+
+
+def test_simple_kg_pipeline_config_text_splitter() -> None:
+ config = SimpleKGPipelineConfig()
+ assert isinstance(config._get_splitter(), FixedSizeSplitter)
+
+
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse")
+def test_simple_kg_pipeline_config_text_splitter_overwrite(
+ mock_component_parse: Mock,
+) -> None:
+ my_text_splitter_config = ComponentConfig(
+ class_="",
+ )
+ my_text_splitter = FixedSizeSplitter()
+ mock_component_parse.return_value = my_text_splitter
+ config = SimpleKGPipelineConfig(
+ text_splitter=my_text_splitter_config, # type: ignore
+ )
+ assert config._get_splitter() == my_text_splitter
+
+
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_embedder"
+)
+def test_simple_kg_pipeline_config_chunk_embedder(
+ mock_embedder: Mock, embedder: Embedder
+) -> None:
+ mock_embedder.return_value = embedder
+ config = SimpleKGPipelineConfig()
+ chunk_embedder = config._get_chunk_embedder()
+ assert isinstance(chunk_embedder, TextChunkEmbedder)
+ assert chunk_embedder._embedder == embedder
+
+
+def test_simple_kg_pipeline_config_schema() -> None:
+ config = SimpleKGPipelineConfig()
+ assert isinstance(config._get_schema(), SchemaBuilder)
+
+
+def test_simple_kg_pipeline_config_schema_run_params() -> None:
+ config = SimpleKGPipelineConfig(
+ entities=["Person"],
+ relations=["KNOWS"],
+ potential_schema=[("Person", "KNOWS", "Person")],
+ )
+ assert config._get_run_params_for_schema() == {
+ "entities": [SchemaEntity(label="Person")],
+ "relations": [SchemaRelation(label="KNOWS")],
+ "potential_schema": [
+ ("Person", "KNOWS", "Person"),
+ ],
+ }
+
+
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
+)
+def test_simple_kg_pipeline_config_extractor(mock_llm: Mock, llm: LLMInterface) -> None:
+ mock_llm.return_value = llm
+ config = SimpleKGPipelineConfig(
+ on_error="IGNORE", # type: ignore
+ prompt_template=ERExtractionTemplate(template="my template {text}"),
+ )
+ extractor = config._get_extractor()
+ assert isinstance(extractor, LLMEntityRelationExtractor)
+ assert extractor.llm == llm
+ assert extractor.on_error == OnError.IGNORE
+ assert extractor.prompt_template.template == "my template {text}"
+
+
+@patch(
+ "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
+ return_value=(5, 23, 0),
+)
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_neo4j_driver"
+)
+def test_simple_kg_pipeline_config_writer(
+ mock_driver: Mock,
+ _: Mock,
+ driver: neo4j.Driver,
+) -> None:
+ mock_driver.return_value = driver
+ config = SimpleKGPipelineConfig(
+ neo4j_database="my_db",
+ )
+ writer = config._get_writer()
+ assert isinstance(writer, Neo4jWriter)
+ assert writer.driver == driver
+ assert writer.neo4j_database == "my_db"
+
+
+@patch(
+ "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
+ return_value=(5, 23, 0),
+)
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse")
+def test_simple_kg_pipeline_config_writer_overwrite(
+ mock_component_parse: Mock,
+ _: Mock,
+ driver: neo4j.Driver,
+) -> None:
+ my_writer_config = ComponentConfig(
+ class_="",
+ )
+ my_writer = Neo4jWriter(driver, neo4j_database="my_db")
+ mock_component_parse.return_value = my_writer
+ config = SimpleKGPipelineConfig(
+ kg_writer=my_writer_config, # type: ignore
+ neo4j_database="my_other_db",
+ )
+ writer: Neo4jWriter = config._get_writer() # type: ignore
+ assert writer == my_writer
+ # database not changed:
+ assert writer.neo4j_database == "my_db"
+
+
+def test_simple_kg_pipeline_config_connections_from_pdf() -> None:
+ config = SimpleKGPipelineConfig(
+ from_pdf=True,
+ perform_entity_resolution=False,
+ )
+ connections = config._get_connections()
+ assert len(connections) == 5
+ expected_connections = [
+ ("pdf_loader", "splitter"),
+ ("schema", "extractor"),
+ ("splitter", "chunk_embedder"),
+ ("chunk_embedder", "extractor"),
+ ("extractor", "writer"),
+ ]
+ for actual, expected in zip(connections, expected_connections):
+ assert (actual.start, actual.end) == expected
+
+
+def test_simple_kg_pipeline_config_connections_from_text() -> None:
+ config = SimpleKGPipelineConfig(
+ from_pdf=False,
+ perform_entity_resolution=False,
+ )
+ connections = config._get_connections()
+ assert len(connections) == 4
+ expected_connections = [
+ ("schema", "extractor"),
+ ("splitter", "chunk_embedder"),
+ ("chunk_embedder", "extractor"),
+ ("extractor", "writer"),
+ ]
+ for actual, expected in zip(connections, expected_connections):
+ assert (actual.start, actual.end) == expected
+
+
+def test_simple_kg_pipeline_config_connections_with_er() -> None:
+ config = SimpleKGPipelineConfig(
+ from_pdf=True,
+ perform_entity_resolution=True,
+ )
+ connections = config._get_connections()
+ assert len(connections) == 6
+ expected_connections = [
+ ("pdf_loader", "splitter"),
+ ("schema", "extractor"),
+ ("splitter", "chunk_embedder"),
+ ("chunk_embedder", "extractor"),
+ ("extractor", "writer"),
+ ("writer", "resolver"),
+ ]
+ for actual, expected in zip(connections, expected_connections):
+ assert (actual.start, actual.end) == expected
+
+
+def test_simple_kg_pipeline_config_run_params_from_pdf_file_path() -> None:
+ config = SimpleKGPipelineConfig(from_pdf=True)
+ assert config.get_run_params({"file_path": "my_file"}) == {
+ "pdf_loader": {"filepath": "my_file"}
+ }
+
+
+def test_simple_kg_pipeline_config_run_params_from_text_text() -> None:
+ config = SimpleKGPipelineConfig(from_pdf=False)
+ assert config.get_run_params({"text": "my text"}) == {
+ "splitter": {"text": "my text"}
+ }
+
+
+def test_simple_kg_pipeline_config_run_params_from_pdf_text() -> None:
+ config = SimpleKGPipelineConfig(from_pdf=True)
+ with pytest.raises(PipelineDefinitionError) as excinfo:
+ config.get_run_params({"text": "my text"})
+ assert "Expected 'file_path' argument when 'from_pdf' is True" in str(excinfo)
+
+
+def test_simple_kg_pipeline_config_run_params_from_text_file_path() -> None:
+ config = SimpleKGPipelineConfig(from_pdf=False)
+ with pytest.raises(PipelineDefinitionError) as excinfo:
+ config.get_run_params({"file_path": "my file"})
+ assert "Expected 'text' argument when 'from_pdf' is False" in str(excinfo)
+
+
+def test_simple_kg_pipeline_config_run_params_no_file_no_text() -> None:
+ config = SimpleKGPipelineConfig(from_pdf=False)
+ with pytest.raises(PipelineDefinitionError) as excinfo:
+ config.get_run_params({})
+ assert (
+ "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument."
+ in str(excinfo)
+ )
+
+
+def test_simple_kg_pipeline_config_run_params_both_file_and_text() -> None:
+ config = SimpleKGPipelineConfig(from_pdf=False)
+ with pytest.raises(PipelineDefinitionError) as excinfo:
+ config.get_run_params({"text": "my text", "file_path": "my file"})
+ assert (
+ "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument."
+ in str(excinfo)
+ )
diff --git a/tests/unit/experimental/pipeline/config/test_base.py b/tests/unit/experimental/pipeline/config/test_base.py
new file mode 100644
index 00000000..20332a79
--- /dev/null
+++ b/tests/unit/experimental/pipeline/config/test_base.py
@@ -0,0 +1,37 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import patch
+
+from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig
+from neo4j_graphrag.experimental.pipeline.config.param_resolver import (
+ ParamToResolveConfig,
+)
+
+
+def test_resolve_param_with_param_to_resolve_object() -> None:
+ c = AbstractConfig()
+ with patch(
+ "neo4j_graphrag.experimental.pipeline.config.param_resolver.ParamToResolveConfig",
+ spec=ParamToResolveConfig,
+ ) as mock_param_class:
+ mock_param = mock_param_class.return_value
+ mock_param.resolve.return_value = 1
+ assert c.resolve_param(mock_param) == 1
+ mock_param.resolve.assert_called_once_with({})
+
+
+def test_resolve_param_with_other_object() -> None:
+ c = AbstractConfig()
+ assert c.resolve_param("value") == "value"
diff --git a/tests/unit/experimental/pipeline/config/test_object_config.py b/tests/unit/experimental/pipeline/config/test_object_config.py
new file mode 100644
index 00000000..baa3c767
--- /dev/null
+++ b/tests/unit/experimental/pipeline/config/test_object_config.py
@@ -0,0 +1,164 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import patch
+
+import neo4j
+import pytest
+from neo4j_graphrag.embeddings import Embedder, OpenAIEmbeddings
+from neo4j_graphrag.experimental.pipeline import Pipeline
+from neo4j_graphrag.experimental.pipeline.config.object_config import (
+ EmbedderConfig,
+ EmbedderType,
+ LLMConfig,
+ LLMType,
+ Neo4jDriverConfig,
+ Neo4jDriverType,
+ ObjectConfig,
+)
+from neo4j_graphrag.llm import LLMInterface, OpenAILLM
+
+
+def test_get_class_no_optional_module() -> None:
+ c: ObjectConfig[object] = ObjectConfig()
+ klass = c._get_class("neo4j_graphrag.experimental.pipeline.Pipeline")
+ assert klass == Pipeline
+
+
+def test_get_class_optional_module() -> None:
+ c: ObjectConfig[object] = ObjectConfig()
+ klass = c._get_class(
+ "Pipeline", optional_module="neo4j_graphrag.experimental.pipeline"
+ )
+ assert klass == Pipeline
+
+
+def test_get_class_path_and_optional_module() -> None:
+ c: ObjectConfig[object] = ObjectConfig()
+ klass = c._get_class(
+ "pipeline.Pipeline", optional_module="neo4j_graphrag.experimental"
+ )
+ assert klass == Pipeline
+
+
+def test_get_class_wrong_path() -> None:
+ c: ObjectConfig[object] = ObjectConfig()
+ with pytest.raises(ValueError):
+ c._get_class("MyClass")
+
+
+def test_neo4j_driver_config() -> None:
+ config = Neo4jDriverConfig.model_validate(
+ {
+ "params_": {
+ "uri": "bolt://",
+ "user": "a user",
+ "password": "a password",
+ }
+ }
+ )
+ assert config.class_ == "not used"
+ assert config.params_ == {
+ "uri": "bolt://",
+ "user": "a user",
+ "password": "a password",
+ }
+ with patch(
+ "neo4j_graphrag.experimental.pipeline.config.object_config.neo4j.GraphDatabase.driver"
+ ) as driver_mock:
+ driver_mock.return_value = "a driver"
+ d = config.parse()
+ driver_mock.assert_called_once_with("bolt://", auth=("a user", "a password"))
+ assert d == "a driver" # type: ignore
+
+
+def test_neo4j_driver_type_with_driver(driver: neo4j.Driver) -> None:
+ driver_type = Neo4jDriverType(driver)
+ assert driver_type.parse() == driver
+
+
+def test_neo4j_driver_type_with_config() -> None:
+ driver_type = Neo4jDriverType(
+ Neo4jDriverConfig(
+ params_={
+ "uri": "bolt://",
+ "user": "",
+ "password": "",
+ }
+ )
+ )
+ driver = driver_type.parse()
+ assert isinstance(driver, neo4j.Driver)
+
+
+def test_llm_config() -> None:
+ config = LLMConfig.model_validate(
+ {
+ "class_": "OpenAILLM",
+ "params_": {"model_name": "gpt-4o", "api_key": "my-api-key"},
+ }
+ )
+ assert config.class_ == "OpenAILLM"
+ assert config.get_module() == "neo4j_graphrag.llm"
+ assert config.get_interface() == LLMInterface
+ assert config.params_ == {"model_name": "gpt-4o", "api_key": "my-api-key"}
+ d = config.parse()
+ assert isinstance(d, OpenAILLM)
+
+
+def test_llm_type_with_driver(llm: LLMInterface) -> None:
+ llm_type = LLMType(llm)
+ assert llm_type.parse() == llm
+
+
+def test_llm_type_with_config() -> None:
+ llm_type = LLMType(
+ LLMConfig(
+ class_="OpenAILLM",
+ params_={"model_name": "gpt-4o", "api_key": "my-api-key"},
+ )
+ )
+ llm = llm_type.parse()
+ assert isinstance(llm, OpenAILLM)
+
+
+def test_embedder_config() -> None:
+ config = EmbedderConfig.model_validate(
+ {
+ "class_": "OpenAIEmbeddings",
+ "params_": {"api_key": "my-api-key"},
+ }
+ )
+ assert config.class_ == "OpenAIEmbeddings"
+ assert config.get_module() == "neo4j_graphrag.embeddings"
+ assert config.get_interface() == Embedder
+ assert config.params_ == {"api_key": "my-api-key"}
+ d = config.parse()
+ assert isinstance(d, OpenAIEmbeddings)
+
+
+def test_embedder_type_with_embedder(embedder: Embedder) -> None:
+ embedder_type = EmbedderType(embedder)
+ assert embedder_type.parse() == embedder
+
+
+def test_embedder_type_with_config() -> None:
+ embedder_type = EmbedderType(
+ EmbedderConfig(
+ class_="OpenAIEmbeddings",
+ params_={"api_key": "my-api-key"},
+ )
+ )
+ embedder = embedder_type.parse()
+ assert isinstance(embedder, OpenAIEmbeddings)
diff --git a/tests/unit/experimental/pipeline/config/test_param_resolver.py b/tests/unit/experimental/pipeline/config/test_param_resolver.py
new file mode 100644
index 00000000..efd4d7e9
--- /dev/null
+++ b/tests/unit/experimental/pipeline/config/test_param_resolver.py
@@ -0,0 +1,56 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from unittest.mock import patch
+
+import pytest
+from neo4j_graphrag.experimental.pipeline.config.param_resolver import (
+ ParamFromEnvConfig,
+ ParamFromKeyConfig,
+)
+
+
+@patch.dict(os.environ, {"MY_KEY": "my_value"}, clear=True)
+def test_env_param_config_happy_path() -> None:
+ resolver = ParamFromEnvConfig(var_="MY_KEY")
+ assert resolver.resolve({}) == "my_value"
+
+
+@patch.dict(os.environ, {}, clear=True)
+def test_env_param_config_missing_env_var() -> None:
+ resolver = ParamFromEnvConfig(var_="MY_KEY")
+ assert resolver.resolve({}) is None
+
+
+def test_config_key_param_simple_key() -> None:
+ resolver = ParamFromKeyConfig(key_="my_key")
+ assert resolver.resolve({"my_key": "my_value"}) == "my_value"
+
+
+def test_config_key_param_missing_key() -> None:
+ resolver = ParamFromKeyConfig(key_="my_key")
+ with pytest.raises(KeyError):
+ resolver.resolve({})
+
+
+def test_config_complex_key_param() -> None:
+ resolver = ParamFromKeyConfig(key_="my_key.my_sub_key")
+ assert resolver.resolve({"my_key": {"my_sub_key": "value"}}) == "value"
+
+
+def test_config_complex_key_param_missing_subkey() -> None:
+ resolver = ParamFromKeyConfig(key_="my_key.my_sub_key")
+ with pytest.raises(KeyError):
+ resolver.resolve({"my_key": {}})
diff --git a/tests/unit/experimental/pipeline/config/test_pipeline_config.py b/tests/unit/experimental/pipeline/config/test_pipeline_config.py
new file mode 100644
index 00000000..4de5874b
--- /dev/null
+++ b/tests/unit/experimental/pipeline/config/test_pipeline_config.py
@@ -0,0 +1,378 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock, patch
+
+import neo4j
+from neo4j_graphrag.embeddings import Embedder
+from neo4j_graphrag.experimental.pipeline import Component
+from neo4j_graphrag.experimental.pipeline.config.object_config import (
+ ComponentConfig,
+ ComponentType,
+ Neo4jDriverConfig,
+ Neo4jDriverType,
+)
+from neo4j_graphrag.experimental.pipeline.config.param_resolver import (
+ ParamFromEnvConfig,
+ ParamFromKeyConfig,
+)
+from neo4j_graphrag.experimental.pipeline.config.pipeline_config import (
+ AbstractPipelineConfig,
+)
+from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition
+from neo4j_graphrag.llm import LLMInterface
+
+
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse"
+)
+def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_params_(
+ mock_neo4j_config: Mock,
+) -> None:
+ mock_neo4j_config.return_value = "text"
+ config = AbstractPipelineConfig.model_validate(
+ {
+ "neo4j_config": {
+ "params_": {
+ "uri": "bolt://",
+ "user": "",
+ "password": "",
+ }
+ }
+ }
+ )
+ assert isinstance(config.neo4j_config, dict)
+ assert "default" in config.neo4j_config
+ config.parse()
+ mock_neo4j_config.assert_called_once()
+ assert config._global_data["neo4j_config"]["default"] == "text"
+
+
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse"
+)
+def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_names(
+ mock_neo4j_config: Mock,
+) -> None:
+ mock_neo4j_config.return_value = "text"
+ config = AbstractPipelineConfig.model_validate(
+ {
+ "neo4j_config": {
+ "my_driver": {
+ "params_": {
+ "uri": "bolt://",
+ "user": "",
+ "password": "",
+ }
+ }
+ }
+ }
+ )
+ assert isinstance(config.neo4j_config, dict)
+ assert "my_driver" in config.neo4j_config
+ config.parse()
+ mock_neo4j_config.assert_called_once()
+ assert config._global_data["neo4j_config"]["my_driver"] == "text"
+
+
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse"
+)
+def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_driver(
+ mock_neo4j_config: Mock, driver: neo4j.Driver
+) -> None:
+ config = AbstractPipelineConfig.model_validate(
+ {
+ "neo4j_config": {
+ "my_driver": driver,
+ }
+ }
+ )
+ assert isinstance(config.neo4j_config, dict)
+ assert "my_driver" in config.neo4j_config
+ config.parse()
+ assert not mock_neo4j_config.called
+ assert config._global_data["neo4j_config"]["my_driver"] == driver
+
+
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse"
+)
+def test_abstract_pipeline_config_neo4j_config_is_a_driver(
+ mock_neo4j_config: Mock, driver: neo4j.Driver
+) -> None:
+ config = AbstractPipelineConfig.model_validate(
+ {
+ "neo4j_config": driver,
+ }
+ )
+ assert isinstance(config.neo4j_config, dict)
+ assert "default" in config.neo4j_config
+ config.parse()
+ assert not mock_neo4j_config.called
+ assert config._global_data["neo4j_config"]["default"] == driver
+
+
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse")
+def test_abstract_pipeline_config_llm_config_is_a_dict_with_params_(
+ mock_llm_config: Mock,
+) -> None:
+ mock_llm_config.return_value = "text"
+ config = AbstractPipelineConfig.model_validate(
+ {"llm_config": {"class_": "OpenAILLM", "params_": {"model_name": "gpt-4o"}}}
+ )
+ assert isinstance(config.llm_config, dict)
+ assert "default" in config.llm_config
+ config.parse()
+ mock_llm_config.assert_called_once()
+ assert config._global_data["llm_config"]["default"] == "text"
+
+
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse")
+def test_abstract_pipeline_config_llm_config_is_a_dict_with_names(
+ mock_llm_config: Mock,
+) -> None:
+ mock_llm_config.return_value = "text"
+ config = AbstractPipelineConfig.model_validate(
+ {
+ "llm_config": {
+ "my_llm": {"class_": "OpenAILLM", "params_": {"model_name": "gpt-4o"}}
+ }
+ }
+ )
+ assert isinstance(config.llm_config, dict)
+ assert "my_llm" in config.llm_config
+ config.parse()
+ mock_llm_config.assert_called_once()
+ assert config._global_data["llm_config"]["my_llm"] == "text"
+
+
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse")
+def test_abstract_pipeline_config_llm_config_is_a_dict_with_llm(
+ mock_llm_config: Mock, llm: LLMInterface
+) -> None:
+ config = AbstractPipelineConfig.model_validate(
+ {
+ "llm_config": {
+ "my_llm": llm,
+ }
+ }
+ )
+ assert isinstance(config.llm_config, dict)
+ assert "my_llm" in config.llm_config
+ config.parse()
+ assert not mock_llm_config.called
+ assert config._global_data["llm_config"]["my_llm"] == llm
+
+
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse")
+def test_abstract_pipeline_config_llm_config_is_an_llm(
+ mock_llm_config: Mock, llm: LLMInterface
+) -> None:
+ config = AbstractPipelineConfig.model_validate(
+ {
+ "llm_config": llm,
+ }
+ )
+ assert isinstance(config.llm_config, dict)
+ assert "default" in config.llm_config
+ config.parse()
+ assert not mock_llm_config.called
+ assert config._global_data["llm_config"]["default"] == llm
+
+
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse")
+def test_abstract_pipeline_config_embedder_config_is_a_dict_with_params_(
+ mock_embedder_config: Mock,
+) -> None:
+ mock_embedder_config.return_value = "text"
+ config = AbstractPipelineConfig.model_validate(
+ {"embedder_config": {"class_": "OpenAIEmbeddings", "params_": {}}}
+ )
+ assert isinstance(config.embedder_config, dict)
+ assert "default" in config.embedder_config
+ config.parse()
+ mock_embedder_config.assert_called_once()
+ assert config._global_data["embedder_config"]["default"] == "text"
+
+
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse")
+def test_abstract_pipeline_config_embedder_config_is_a_dict_with_names(
+ mock_embedder_config: Mock,
+) -> None:
+ mock_embedder_config.return_value = "text"
+ config = AbstractPipelineConfig.model_validate(
+ {
+ "embedder_config": {
+ "my_embedder": {"class_": "OpenAIEmbeddings", "params_": {}}
+ }
+ }
+ )
+ assert isinstance(config.embedder_config, dict)
+ assert "my_embedder" in config.embedder_config
+ config.parse()
+ mock_embedder_config.assert_called_once()
+ assert config._global_data["embedder_config"]["my_embedder"] == "text"
+
+
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse")
+def test_abstract_pipeline_config_embedder_config_is_a_dict_with_embedder(
+ mock_embedder_config: Mock, embedder: Embedder
+) -> None:
+ config = AbstractPipelineConfig.model_validate(
+ {
+ "embedder_config": {
+ "my_embedder": embedder,
+ }
+ }
+ )
+ assert isinstance(config.embedder_config, dict)
+ assert "my_embedder" in config.embedder_config
+ config.parse()
+ assert not mock_embedder_config.called
+ assert config._global_data["embedder_config"]["my_embedder"] == embedder
+
+
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse")
+def test_abstract_pipeline_config_embedder_config_is_an_embedder(
+ mock_embedder_config: Mock, embedder: Embedder
+) -> None:
+ config = AbstractPipelineConfig.model_validate(
+ {
+ "embedder_config": embedder,
+ }
+ )
+ assert isinstance(config.embedder_config, dict)
+ assert "default" in config.embedder_config
+ config.parse()
+ assert not mock_embedder_config.called
+ assert config._global_data["embedder_config"]["default"] == embedder
+
+
+def test_abstract_pipeline_config_parse_global_data_no_extras(driver: Mock) -> None:
+ config = AbstractPipelineConfig(
+ neo4j_config={"my_driver": Neo4jDriverType(driver)},
+ )
+ gd = config._parse_global_data()
+ assert gd == {
+ "extras": {},
+ "neo4j_config": {
+ "my_driver": driver,
+ },
+ "llm_config": {},
+ "embedder_config": {},
+ }
+
+
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.param_resolver.ParamFromEnvConfig.resolve"
+)
+def test_abstract_pipeline_config_parse_global_data_extras(
+ mock_param_resolver: Mock,
+) -> None:
+ mock_param_resolver.return_value = "my value"
+ config = AbstractPipelineConfig(
+ extras={"my_extra_var": ParamFromEnvConfig(var_="some key")},
+ )
+ gd = config._parse_global_data()
+ assert gd == {
+ "extras": {"my_extra_var": "my value"},
+ "neo4j_config": {},
+ "llm_config": {},
+ "embedder_config": {},
+ }
+
+
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.param_resolver.ParamFromEnvConfig.resolve"
+)
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverType.parse"
+)
+def test_abstract_pipeline_config_parse_global_data_use_extras_in_other_config(
+ mock_neo4j_parser: Mock,
+ mock_param_resolver: Mock,
+) -> None:
+ """Parser is able to read variables in the 'extras' section of config
+ to instantiate another object (neo4j.Driver in this test case)
+ """
+ mock_param_resolver.side_effect = ["bolt://myhost", "myuser", "mypwd"]
+ mock_neo4j_parser.return_value = "my driver"
+ config = AbstractPipelineConfig(
+ extras={
+ "my_extra_uri": ParamFromEnvConfig(var_="some key"),
+ "my_extra_user": ParamFromEnvConfig(var_="some key"),
+ "my_extra_pwd": ParamFromEnvConfig(var_="some key"),
+ },
+ neo4j_config={
+ "my_driver": Neo4jDriverType(
+ Neo4jDriverConfig(
+ params_=dict(
+ uri=ParamFromKeyConfig(key_="extras.my_extra_uri"),
+ user=ParamFromKeyConfig(key_="extras.my_extra_user"),
+ password=ParamFromKeyConfig(key_="extras.my_extra_pwd"),
+ )
+ )
+ )
+ },
+ )
+ gd = config._parse_global_data()
+ expected_extras = {
+ "my_extra_uri": "bolt://myhost",
+ "my_extra_user": "myuser",
+ "my_extra_pwd": "mypwd",
+ }
+ assert gd["extras"] == expected_extras
+ assert gd["neo4j_config"] == {"my_driver": "my driver"}
+ mock_neo4j_parser.assert_called_once_with({"extras": expected_extras})
+
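The test above relies on `ParamFromKeyConfig(key_="extras.my_extra_uri")` resolving a dotted path against the already-parsed global data. A minimal pure-Python sketch of that lookup, assuming nothing beyond what the test shows (the helper name `resolve_dotted_key` is hypothetical, not part of the library):

```python
from typing import Any


def resolve_dotted_key(key: str, global_data: dict[str, Any]) -> Any:
    """Walk a nested dict following dot-separated path segments."""
    value: Any = global_data
    for segment in key.split("."):
        value = value[segment]  # raises KeyError if the path is missing
    return value


global_data = {
    "extras": {
        "my_extra_uri": "bolt://myhost",
        "my_extra_user": "myuser",
        "my_extra_pwd": "mypwd",
    }
}
assert resolve_dotted_key("extras.my_extra_uri", global_data) == "bolt://myhost"
assert resolve_dotted_key("extras.my_extra_user", global_data) == "myuser"
```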
+
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse")
+def test_abstract_pipeline_config_resolve_component_definition_no_run_params(
+ mock_component_parse: Mock,
+ component: Component,
+) -> None:
+ mock_component_parse.return_value = component
+ config = AbstractPipelineConfig()
+ component_type = ComponentType(component)
+ component_definition = config._resolve_component_definition("name", component_type)
+ assert isinstance(component_definition, ComponentDefinition)
+ mock_component_parse.assert_called_once_with({})
+ assert component_definition.name == "name"
+ assert component_definition.component == component
+ assert component_definition.run_params == {}
+
+
+@patch(
+ "neo4j_graphrag.experimental.pipeline.config.pipeline_config.AbstractPipelineConfig.resolve_params"
+)
+@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse")
+def test_abstract_pipeline_config_resolve_component_definition_with_run_params(
+ mock_component_parse: Mock,
+ mock_resolve_params: Mock,
+ component: Component,
+) -> None:
+ mock_component_parse.return_value = component
+ mock_resolve_params.return_value = {"param": "resolver param result"}
+ config = AbstractPipelineConfig()
+ component_type = ComponentType(
+ ComponentConfig(class_="", params_={}, run_params_={"param1": "value1"})
+ )
+ component_definition = config._resolve_component_definition("name", component_type)
+ assert isinstance(component_definition, ComponentDefinition)
+ mock_component_parse.assert_called_once_with({})
+ assert component_definition.name == "name"
+ assert component_definition.component == component
+ assert component_definition.run_params == {"param": "resolver param result"}
+ mock_resolve_params.assert_called_once_with({"param1": "value1"})
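All of the `neo4j_config`/`llm_config`/`embedder_config` tests above exercise the same normalization rule: a bare object, or a dict carrying the reserved `class_`/`params_` keys, gets wrapped under the `"default"` name, while any other dict is treated as a name-to-config mapping. A minimal pure-Python sketch of that rule (the helper name `normalize_config` and the reserved-key set are assumptions inferred from the tests, not the library's actual implementation):

```python
from typing import Any

# Assumption: these keys mark a dict as a single config rather than a mapping.
RESERVED_KEYS = {"class_", "params_"}


def normalize_config(value: Any) -> dict[str, Any]:
    """Wrap a bare object or single-config dict under the 'default' name."""
    if isinstance(value, dict) and value and not (RESERVED_KEYS & value.keys()):
        return value  # already a name -> config mapping
    return {"default": value}


# Mirrors the "is_a_dict_with_params_" / "is_a_dict_with_names" / bare-object cases:
assert "default" in normalize_config({"params_": {"uri": "bolt://"}})
assert "my_driver" in normalize_config({"my_driver": {"params_": {"uri": "bolt://"}}})
assert "default" in normalize_config(object())
```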
diff --git a/tests/unit/experimental/pipeline/config/test_runner.py b/tests/unit/experimental/pipeline/config/test_runner.py
new file mode 100644
index 00000000..327b5221
--- /dev/null
+++ b/tests/unit/experimental/pipeline/config/test_runner.py
@@ -0,0 +1,56 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock, patch
+
+from neo4j_graphrag.experimental.pipeline import Pipeline
+from neo4j_graphrag.experimental.pipeline.config.pipeline_config import PipelineConfig
+from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner
+from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition
+
+
+@patch("neo4j_graphrag.experimental.pipeline.pipeline.Pipeline.from_definition")
+def test_pipeline_runner_from_def_empty(mock_from_definition: Mock) -> None:
+ mock_from_definition.return_value = Pipeline()
+ runner = PipelineRunner(
+ pipeline_definition=PipelineDefinition(components=[], connections=[])
+ )
+ assert runner.config is None
+ assert runner.pipeline is not None
+ assert runner.pipeline._nodes == {}
+ assert runner.pipeline._edges == []
+ assert runner.run_params == {}
+ mock_from_definition.assert_called_once()
+
+
+def test_pipeline_runner_from_config() -> None:
+ config = PipelineConfig(component_config={}, connection_config=[])
+ runner = PipelineRunner.from_config(config)
+ assert runner.config is not None
+ assert runner.pipeline is not None
+ assert runner.pipeline._nodes == {}
+ assert runner.pipeline._edges == []
+ assert runner.run_params == {}
+
+
+@patch("neo4j_graphrag.experimental.pipeline.config.runner.PipelineRunner.from_config")
+@patch("neo4j_graphrag.experimental.pipeline.config.config_reader.ConfigReader.read")
+def test_pipeline_runner_from_config_file(
+ mock_read: Mock, mock_from_config: Mock
+) -> None:
+ mock_read.return_value = {"dict": "with data"}
+ PipelineRunner.from_config_file("file.yaml")
+
+ mock_read.assert_called_once_with("file.yaml")
+ mock_from_config.assert_called_once_with({"dict": "with data"}, do_cleaning=True)
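The `from_config_file` flow asserted by the last test is a simple chain: read the file into a dict, then hand it to `from_config` with `do_cleaning=True`. A sketch of that dispatch with both collaborators stubbed out (the `Stub*` classes are hypothetical stand-ins for illustration, not the library's `ConfigReader`/`PipelineRunner`):

```python
from typing import Any


class StubConfigReader:
    """Stands in for ConfigReader: pretends every file parses to a dict."""

    def read(self, path: str) -> dict[str, Any]:
        return {"source": path}


class StubPipelineRunner:
    @classmethod
    def from_config(
        cls, data: dict[str, Any], do_cleaning: bool = False
    ) -> dict[str, Any]:
        # Stand-in: just record what would be used to build the pipeline.
        return {"data": data, "do_cleaning": do_cleaning}

    @classmethod
    def from_config_file(cls, path: str) -> dict[str, Any]:
        # The chain the test verifies: read(path) -> from_config(..., do_cleaning=True)
        data = StubConfigReader().read(path)
        return cls.from_config(data, do_cleaning=True)


result = StubPipelineRunner.from_config_file("file.yaml")
assert result == {"data": {"source": "file.yaml"}, "do_cleaning": True}
```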
diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py
index b1b29151..bddbac50 100644
--- a/tests/unit/experimental/pipeline/test_kg_builder.py
+++ b/tests/unit/experimental/pipeline/test_kg_builder.py
@@ -18,10 +18,8 @@
import neo4j
import pytest
from neo4j_graphrag.embeddings import Embedder
-from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError
from neo4j_graphrag.experimental.components.schema import (
SchemaEntity,
- SchemaProperty,
SchemaRelation,
)
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
@@ -31,136 +29,6 @@
from neo4j_graphrag.llm.base import LLMInterface
-@mock.patch(
- "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
- return_value=(5, 23, 0),
-)
-@pytest.mark.asyncio
-async def test_knowledge_graph_builder_init_with_text(_: Mock) -> None:
- llm = MagicMock(spec=LLMInterface)
- driver = MagicMock(spec=neo4j.Driver)
- embedder = MagicMock(spec=Embedder)
-
- kg_builder = SimpleKGPipeline(
- llm=llm,
- driver=driver,
- embedder=embedder,
- from_pdf=False,
- )
-
- assert kg_builder.llm == llm
- assert kg_builder.driver == driver
- assert kg_builder.embedder == embedder
- assert kg_builder.from_pdf is False
- assert kg_builder.entities == []
- assert kg_builder.relations == []
- assert kg_builder.potential_schema == []
- assert "pdf_loader" not in kg_builder.pipeline
-
- text_input = "May thy knife chip and shatter."
-
- with patch.object(
- kg_builder.pipeline,
- "run",
- return_value=PipelineResult(run_id="test_run", result=None),
- ) as mock_run:
- await kg_builder.run_async(text=text_input)
- mock_run.assert_called_once()
- pipe_inputs = mock_run.call_args[0][0]
- assert pipe_inputs["splitter"]["text"] == text_input
-
-
-@mock.patch(
- "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
- return_value=(5, 23, 0),
-)
-@pytest.mark.asyncio
-async def test_knowledge_graph_builder_init_with_file_path(_: Mock) -> None:
- llm = MagicMock(spec=LLMInterface)
- driver = MagicMock(spec=neo4j.Driver)
- embedder = MagicMock(spec=Embedder)
-
- kg_builder = SimpleKGPipeline(
- llm=llm,
- driver=driver,
- embedder=embedder,
- from_pdf=True,
- )
-
- assert kg_builder.llm == llm
- assert kg_builder.driver == driver
- assert kg_builder.from_pdf is True
- assert kg_builder.entities == []
- assert kg_builder.relations == []
- assert kg_builder.potential_schema == []
- assert "pdf_loader" in kg_builder.pipeline
-
- file_path = "path/to/test.pdf"
-
- with patch.object(
- kg_builder.pipeline,
- "run",
- return_value=PipelineResult(run_id="test_run", result=None),
- ) as mock_run:
- await kg_builder.run_async(file_path=file_path)
- mock_run.assert_called_once()
- pipe_inputs = mock_run.call_args[0][0]
- assert pipe_inputs["pdf_loader"]["filepath"] == file_path
-
-
-@mock.patch(
- "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
- return_value=(5, 23, 0),
-)
-@pytest.mark.asyncio
-async def test_knowledge_graph_builder_run_with_both_inputs(_: Mock) -> None:
- llm = MagicMock(spec=LLMInterface)
- driver = MagicMock(spec=neo4j.Driver)
- embedder = MagicMock(spec=Embedder)
-
- kg_builder = SimpleKGPipeline(
- llm=llm,
- driver=driver,
- embedder=embedder,
- from_pdf=True,
- )
-
- text_input = "May thy knife chip and shatter."
- file_path = "path/to/test.pdf"
-
- with pytest.raises(PipelineDefinitionError) as exc_info:
- await kg_builder.run_async(file_path=file_path, text=text_input)
-
- assert "Expected 'file_path' argument when 'from_pdf' is True." in str(
- exc_info.value
- ) or "Expected 'text' argument when 'from_pdf' is False." in str(exc_info.value)
-
-
-@mock.patch(
- "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
- return_value=(5, 23, 0),
-)
-@pytest.mark.asyncio
-async def test_knowledge_graph_builder_run_with_no_inputs(_: Mock) -> None:
- llm = MagicMock(spec=LLMInterface)
- driver = MagicMock(spec=neo4j.Driver)
- embedder = MagicMock(spec=Embedder)
-
- kg_builder = SimpleKGPipeline(
- llm=llm,
- driver=driver,
- embedder=embedder,
- from_pdf=True,
- )
-
- with pytest.raises(PipelineDefinitionError) as exc_info:
- await kg_builder.run_async()
-
- assert "Expected 'file_path' argument when 'from_pdf' is True." in str(
- exc_info.value
- ) or "Expected 'text' argument when 'from_pdf' is False." in str(exc_info.value)
-
-
@mock.patch(
"neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
return_value=(5, 23, 0),
@@ -181,13 +49,13 @@ async def test_knowledge_graph_builder_document_info_with_file(_: Mock) -> None:
file_path = "path/to/test.pdf"
with patch.object(
- kg_builder.pipeline,
+ kg_builder.runner.pipeline,
"run",
return_value=PipelineResult(run_id="test_run", result=None),
) as mock_run:
await kg_builder.run_async(file_path=file_path)
- pipe_inputs = mock_run.call_args[0][0]
+ pipe_inputs = mock_run.call_args[1]["data"]
assert "pdf_loader" in pipe_inputs
assert pipe_inputs["pdf_loader"] == {"filepath": file_path}
assert "extractor" not in pipe_inputs
@@ -213,13 +81,13 @@ async def test_knowledge_graph_builder_document_info_with_text(_: Mock) -> None:
text_input = "May thy knife chip and shatter."
with patch.object(
- kg_builder.pipeline,
+ kg_builder.runner.pipeline,
"run",
return_value=PipelineResult(run_id="test_run", result=None),
) as mock_run:
await kg_builder.run_async(text=text_input)
- pipe_inputs = mock_run.call_args[0][0]
+ pipe_inputs = mock_run.call_args[1]["data"]
assert "splitter" in pipe_inputs
assert pipe_inputs["splitter"] == {"text": text_input}
@@ -248,51 +116,33 @@ async def test_knowledge_graph_builder_with_entities_and_file(_: Mock) -> None:
from_pdf=True,
)
- internal_entities = [SchemaEntity(label=label) for label in entities]
- internal_relations = [SchemaRelation(label=label) for label in relations]
- assert kg_builder.entities == internal_entities
- assert kg_builder.relations == internal_relations
- assert kg_builder.potential_schema == potential_schema
file_path = "path/to/test.pdf"
+ internal_entities = [SchemaEntity(label=label) for label in entities]
+ internal_relations = [SchemaRelation(label=label) for label in relations]
+
with patch.object(
- kg_builder.pipeline,
+ kg_builder.runner.pipeline,
"run",
return_value=PipelineResult(run_id="test_run", result=None),
) as mock_run:
await kg_builder.run_async(file_path=file_path)
- pipe_inputs = mock_run.call_args[0][0]
+ pipe_inputs = mock_run.call_args[1]["data"]
assert pipe_inputs["schema"]["entities"] == internal_entities
assert pipe_inputs["schema"]["relations"] == internal_relations
assert pipe_inputs["schema"]["potential_schema"] == potential_schema
-@mock.patch(
- "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
- return_value=(5, 23, 0),
-)
-def test_simple_kg_pipeline_on_error_conversion(_: Mock) -> None:
- llm = MagicMock(spec=LLMInterface)
- driver = MagicMock(spec=neo4j.Driver)
- embedder = MagicMock(spec=Embedder)
-
- kg_builder = SimpleKGPipeline(
- llm=llm,
- driver=driver,
- embedder=embedder,
- on_error="RAISE",
- )
-
- assert kg_builder.on_error == OnError.RAISE
-
-
def test_simple_kg_pipeline_on_error_invalid_value() -> None:
llm = MagicMock(spec=LLMInterface)
driver = MagicMock(spec=neo4j.Driver)
embedder = MagicMock(spec=Embedder)
- with pytest.raises(PipelineDefinitionError) as exc_info:
+ with pytest.raises(PipelineDefinitionError):
SimpleKGPipeline(
llm=llm,
driver=driver,
@@ -300,50 +150,6 @@ def test_simple_kg_pipeline_on_error_invalid_value() -> None:
on_error="INVALID_VALUE",
)
- assert "Expected one of ['RAISE', 'IGNORE']" in str(exc_info.value)
-
-
-@mock.patch(
- "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
- return_value=(5, 23, 0),
-)
-def test_simple_kg_pipeline_no_entity_resolution(_: Mock) -> None:
- llm = MagicMock(spec=LLMInterface)
- driver = MagicMock(spec=neo4j.Driver)
- embedder = MagicMock(spec=Embedder)
-
- kg_builder = SimpleKGPipeline(
- llm=llm,
- driver=driver,
- embedder=embedder,
- on_error="IGNORE",
- perform_entity_resolution=False,
- )
-
- assert "resolver" not in kg_builder.pipeline
-
-
-@mock.patch(
- "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
- return_value=(5, 23, 0),
-)
-@pytest.mark.asyncio
-def test_simple_kg_pipeline_lexical_graph_config_attribute(_: Mock) -> None:
- llm = MagicMock(spec=LLMInterface)
- driver = MagicMock(spec=neo4j.Driver)
- embedder = MagicMock(spec=Embedder)
-
- lexical_graph_config = LexicalGraphConfig()
- kg_builder = SimpleKGPipeline(
- llm=llm,
- driver=driver,
- embedder=embedder,
- on_error="IGNORE",
- lexical_graph_config=lexical_graph_config,
- )
-
- assert kg_builder.lexical_graph_config == lexical_graph_config
-
@mock.patch(
"neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
@@ -372,58 +178,14 @@ async def test_knowledge_graph_builder_with_lexical_graph_config(_: Mock) -> Non
text_input = "May thy knife chip and shatter."
with patch.object(
- kg_builder.pipeline,
+ kg_builder.runner.pipeline,
"run",
return_value=PipelineResult(run_id="test_run", result=None),
) as mock_run:
await kg_builder.run_async(text=text_input)
- pipe_inputs = mock_run.call_args[0][0]
+ pipe_inputs = mock_run.call_args[1]["data"]
assert "extractor" in pipe_inputs
assert pipe_inputs["extractor"] == {
"lexical_graph_config": lexical_graph_config
}
-
-
-def test_knowledge_graph_builder_to_schema_entity_method() -> None:
- assert SimpleKGPipeline.to_schema_entity("EntityType") == SchemaEntity(
- label="EntityType"
- )
- assert SimpleKGPipeline.to_schema_entity({"label": "EntityType"}) == SchemaEntity(
- label="EntityType"
- )
- assert SimpleKGPipeline.to_schema_entity(
- {"label": "EntityType", "description": "A special entity"}
- ) == SchemaEntity(label="EntityType", description="A special entity")
- assert SimpleKGPipeline.to_schema_entity(
- {"label": "EntityType", "properties": []}
- ) == SchemaEntity(label="EntityType")
- assert SimpleKGPipeline.to_schema_entity(
- {
- "label": "EntityType",
- "properties": [{"name": "entityProperty", "type": "DATE"}],
- }
- ) == SchemaEntity(
- label="EntityType",
- properties=[SchemaProperty(name="entityProperty", type="DATE")],
- )
-
-
-def test_knowledge_graph_builder_to_schema_relation_method() -> None:
- assert SimpleKGPipeline.to_schema_relation("REL_TYPE") == SchemaRelation(
- label="REL_TYPE"
- )
- assert SimpleKGPipeline.to_schema_relation({"label": "REL_TYPE"}) == SchemaRelation(
- label="REL_TYPE"
- )
- assert SimpleKGPipeline.to_schema_relation(
- {"label": "REL_TYPE", "description": "A rel type"}
- ) == SchemaRelation(label="REL_TYPE", description="A rel type")
- assert SimpleKGPipeline.to_schema_relation(
- {"label": "REL_TYPE", "properties": []}
- ) == SchemaRelation(label="REL_TYPE")
- assert SimpleKGPipeline.to_schema_relation(
- {"label": "REL_TYPE", "properties": [{"name": "relProperty", "type": "DATE"}]}
- ) == SchemaRelation(
- label="REL_TYPE", properties=[SchemaProperty(name="relProperty", type="DATE")]
- )