diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index 4d65e01fe..82580a4d7 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -2,7 +2,7 @@ from pydantic import ValidationError from neo4j import Driver from .embedder import Embedder -from .types import CreateIndexModel, SimilaritySearchModel, Neo4jRecord +from .types import VectorIndexModel, SimilaritySearchModel, Neo4jRecord class GenAIClient: @@ -72,7 +72,7 @@ def create_index( ValueError: If validation of the input arguments fail. """ try: - CreateIndexModel( + VectorIndexModel( **{ "name": name, "label": label, diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index ba6b6014a..916040056 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -2,7 +2,7 @@ from neo4j import Driver from pydantic import ValidationError -from .types import CreateIndexModel +from .types import VectorIndexModel, FulltextIndexModel def create_vector_index( @@ -32,8 +32,9 @@ def create_vector_index( ValueError: If validation of the input arguments fail. """ try: - CreateIndexModel( + VectorIndexModel( **{ + "driver": driver, "name": name, "label": label, "property": property, @@ -42,7 +43,7 @@ def create_vector_index( } ) except ValidationError as e: - raise ValueError(f"Error for inputs to create_index {str(e)}") + raise ValueError(f"Error for inputs to create_vector_index {str(e)}") query = ( f"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:{label}) ON n.{property} OPTIONS " @@ -54,13 +55,25 @@ def create_vector_index( def create_fulltext_index( - driver: Driver, name: str, label: str, text_node_properties: List[str] = [] + driver: Driver, name: str, label: str, node_properties: List[str] ) -> None: """ """ + try: + FulltextIndexModel( + **{ + "driver": driver, + "name": name, + "label": label, + "node_properties": node_properties, + } + ) + except ValidationError as e: + raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}") + query = ( "CREATE FULLTEXT INDEX $name" f"FOR (n:`{label}`) ON EACH " - f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]" + f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]" ) driver.execute_query(query, {"name": name}) diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index f7ee8e621..91db6db74 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -1,5 +1,6 @@ from typing import List, Any, Literal, Optional -from pydantic import BaseModel, PositiveInt, model_validator +from pydantic import BaseModel, PositiveInt, model_validator, field_validator +from neo4j import Driver class Neo4jRecord(BaseModel): @@ -11,7 +12,17 @@ class EmbeddingVector(BaseModel): vector: List[float] -class CreateIndexModel(BaseModel): +class IndexModel(BaseModel): + driver: Any + + @field_validator("driver") + def check_driver_is_valid(cls, v): + if not isinstance(v, Driver): + raise ValueError("driver must be an instance of neo4j.Driver") + return v + + +class VectorIndexModel(IndexModel): name: str label: str property: str @@ -19,6 +30,18 @@ class CreateIndexModel(BaseModel): similarity_fn: Literal["euclidean", "cosine"] +class FulltextIndexModel(IndexModel): + name: str + label: str + node_properties: List[str] + + @field_validator("node_properties") + def check_node_properties_not_empty(cls, v): + if len(v) == 0: + raise ValueError("node_properties cannot be an empty list") + return v + + class SimilaritySearchModel(BaseModel): index_name: str top_k: PositiveInt = 5 diff --git a/tests/conftest.py b/tests/conftest.py index b3331dec0..f481cdb43 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,12 @@ import pytest from neo4j_genai import GenAIClient +from neo4j import Driver from unittest.mock import MagicMock, patch @pytest.fixture def driver(): - return MagicMock() + return MagicMock(spec=Driver) @pytest.fixture diff --git a/tests/test_indexes.py b/tests/test_indexes.py index f042a7921..6b0ae6223 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -8,7 +8,6 @@ def test_create_vector_index_happy_path(driver): - driver.execute_query.return_value = [None, None, None] create_query = ( "CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS " "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" @@ -23,7 +22,6 @@ def test_create_vector_index_happy_path(driver): def test_create_vector_index_ensure_escaping(driver, client): - driver.execute_query.return_value = [None, None, None] create_query = ( "CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS " "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" @@ -46,23 +44,22 @@ def test_create_vector_index_ensure_escaping(driver, client): def test_create_vector_index_negative_dimension(driver): with pytest.raises(ValueError) as excinfo: create_vector_index(driver, "my-index", "People", "name", -5, "cosine") - assert "Error for inputs to create_index" in str(excinfo) + assert "Error for inputs to create_vector_index" in str(excinfo) -def test_create_index_validation_error_dimensions(driver): +def test_create_vector_index_validation_error_dimensions(driver): with pytest.raises(ValueError) as excinfo: create_vector_index(driver, "my-index", "People", "name", "no-dim", "cosine") - assert "Error for inputs to create_index" in str(excinfo) + assert "Error for inputs to create_vector_index" in str(excinfo) -def test_create_index_validation_error_similarity_fn(driver): +def test_create_vector_index_validation_error_similarity_fn(driver): with pytest.raises(ValueError) as excinfo: create_vector_index(driver, "my-index", "People", "name", 1536, "algebra") - assert "Error for inputs to create_index" in str(excinfo) + assert "Error for inputs to create_vector_index" in str(excinfo) def test_drop_vector_index(driver): - driver.execute_query.return_value = [None, None, None] drop_query = "DROP INDEX $name" drop_vector_index(driver, "my-index") @@ -74,7 +71,6 @@ def test_drop_vector_index(driver): def test_create_fulltext_index_happy_path(driver): - driver.execute_query.return_value = [None, None, None] label = "node-label" text_node_properties = ["property-1", "property-2"] create_query = ( @@ -87,7 +83,32 @@ def test_create_fulltext_index_happy_path(driver): driver.execute_query.assert_called_once_with( create_query, - { - "name": "my-index" - }, + {"name": "my-index"}, + ) + + +def test_create_fulltext_index_empty_node_properties(driver): + label = "node-label" + node_properties = [] + + with pytest.raises(ValueError) as excinfo: + create_fulltext_index(driver, "my-index", label, node_properties) + + assert "Error for inputs to create_fulltext_index" in str(excinfo) + + +def test_create_fulltext_index_ensure_escaping(driver, client): + label = "node-label" + text_node_properties = ["property-1", "property-2"] + create_query = ( + "CREATE FULLTEXT INDEX $name" + f"FOR (n:`{label}`) ON EACH " + f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]" + ) + + create_fulltext_index(driver, "my-complicated-`-index", label, text_node_properties) + + driver.execute_query.assert_called_once_with( + create_query, + {"name": "my-complicated-`-index"}, )