Skip to content

Commit

Permalink
Refactored create and drop index methods, add create fulltext index m…
Browse files Browse the repository at this point in the history
…ethod
  • Loading branch information
willtai committed Apr 4, 2024
1 parent 4b35592 commit 20fb738
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pydantic import ValidationError
from neo4j import Driver
from .embedder import Embedder
from .types import CreateIndexModel, SimilaritySearchModel, Neo4jRecord
from .types import VectorIndexModel, SimilaritySearchModel, Neo4jRecord


class GenAIClient:
Expand Down Expand Up @@ -72,7 +72,7 @@ def create_index(
ValueError: If validation of the input arguments fail.
"""
try:
CreateIndexModel(
VectorIndexModel(
**{
"name": name,
"label": label,
Expand Down
41 changes: 34 additions & 7 deletions src/neo4j_genai/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from neo4j import Driver
from pydantic import ValidationError
from .types import CreateIndexModel
from .types import VectorIndexModel, FulltextIndexModel


def create_vector_index(
Expand All @@ -17,7 +17,7 @@ def create_vector_index(
This method constructs a Cypher query and executes it
to create a new vector index in Neo4j.
See Cypher manual on [Create node index](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_createNodeIndex)
See Cypher manual on [Create vector index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/#indexes-vector-create)
Args:
driver (Driver): Neo4j Python driver instance.
Expand All @@ -32,8 +32,9 @@ def create_vector_index(
ValueError: If validation of the input arguments fail.
"""
try:
CreateIndexModel(
VectorIndexModel(
**{
"driver": driver,
"name": name,
"label": label,
"property": property,
Expand All @@ -42,7 +43,7 @@ def create_vector_index(
}
)
except ValidationError as e:
raise ValueError(f"Error for inputs to create_index {str(e)}")
raise ValueError(f"Error for inputs to create_vector_index {str(e)}")

query = (
f"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:{label}) ON n.{property} OPTIONS "
Expand All @@ -54,13 +55,39 @@ def create_vector_index(


def create_fulltext_index(
driver: Driver, name: str, label: str, text_node_properties: List[str] = []
driver: Driver, name: str, label: str, node_properties: List[str]
) -> None:
""" """
"""
This method constructs a Cypher query and executes it
to create a new fulltext index in Neo4j.
See Cypher manual on [Create fulltext index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/full-text-indexes/#create-full-text-indexes)
Args:
driver (Driver): Neo4j Python driver instance.
name (str): The unique name of the index.
label (str): The node label to be indexed.
node_properties (List[str]): The node properties to create the fulltext index on.
Raises:
ValueError: If validation of the input arguments fail.
"""
try:
FulltextIndexModel(
**{
"driver": driver,
"name": name,
"label": label,
"node_properties": node_properties,
}
)
except ValidationError as e:
raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}")

query = (
"CREATE FULLTEXT INDEX $name"
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]"
f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]"
)
driver.execute_query(query, {"name": name})

Expand Down
27 changes: 25 additions & 2 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Any, Literal, Optional
from pydantic import BaseModel, PositiveInt, model_validator
from pydantic import BaseModel, PositiveInt, model_validator, field_validator
from neo4j import Driver


class Neo4jRecord(BaseModel):
Expand All @@ -11,14 +12,36 @@ class EmbeddingVector(BaseModel):
vector: List[float]


class CreateIndexModel(BaseModel):
class IndexModel(BaseModel):
driver: Any

@field_validator("driver")
def check_driver_is_valid(cls, v):
if not isinstance(v, Driver):
raise ValueError("driver must be an instance of neo4j.Driver")
return v


class VectorIndexModel(IndexModel):
name: str
label: str
property: str
dimensions: PositiveInt
similarity_fn: Literal["euclidean", "cosine"]


class FulltextIndexModel(IndexModel):
name: str
label: str
node_properties: List[str]

@field_validator("node_properties")
def check_node_properties_not_empty(cls, v):
if len(v) == 0:
raise ValueError("node_properties cannot be an empty list")
return v


class SimilaritySearchModel(BaseModel):
index_name: str
top_k: PositiveInt = 5
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import pytest
from neo4j_genai import GenAIClient
from neo4j import Driver
from unittest.mock import MagicMock, patch


@pytest.fixture
def driver():
return MagicMock()
return MagicMock(spec=Driver)


@pytest.fixture
Expand Down
45 changes: 33 additions & 12 deletions tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


def test_create_vector_index_happy_path(driver):
driver.execute_query.return_value = [None, None, None]
create_query = (
"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
Expand All @@ -23,7 +22,6 @@ def test_create_vector_index_happy_path(driver):


def test_create_vector_index_ensure_escaping(driver, client):
driver.execute_query.return_value = [None, None, None]
create_query = (
"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
Expand All @@ -46,23 +44,22 @@ def test_create_vector_index_ensure_escaping(driver, client):
def test_create_vector_index_negative_dimension(driver):
with pytest.raises(ValueError) as excinfo:
create_vector_index(driver, "my-index", "People", "name", -5, "cosine")
assert "Error for inputs to create_index" in str(excinfo)
assert "Error for inputs to create_vector_index" in str(excinfo)


def test_create_index_validation_error_dimensions(driver):
def test_create_vector_index_validation_error_dimensions(driver):
with pytest.raises(ValueError) as excinfo:
create_vector_index(driver, "my-index", "People", "name", "no-dim", "cosine")
assert "Error for inputs to create_index" in str(excinfo)
assert "Error for inputs to create_vector_index" in str(excinfo)


def test_create_index_validation_error_similarity_fn(driver):
def test_create_vector_index_validation_error_similarity_fn(driver):
with pytest.raises(ValueError) as excinfo:
create_vector_index(driver, "my-index", "People", "name", 1536, "algebra")
assert "Error for inputs to create_index" in str(excinfo)
assert "Error for inputs to create_vector_index" in str(excinfo)


def test_drop_vector_index(driver):
driver.execute_query.return_value = [None, None, None]
drop_query = "DROP INDEX $name"

drop_vector_index(driver, "my-index")
Expand All @@ -74,7 +71,6 @@ def test_drop_vector_index(driver):


def test_create_fulltext_index_happy_path(driver):
driver.execute_query.return_value = [None, None, None]
label = "node-label"
text_node_properties = ["property-1", "property-2"]
create_query = (
Expand All @@ -87,7 +83,32 @@ def test_create_fulltext_index_happy_path(driver):

driver.execute_query.assert_called_once_with(
create_query,
{
"name": "my-index"
},
{"name": "my-index"},
)


def test_create_fulltext_index_empty_node_properties(driver):
label = "node-label"
node_properties = []

with pytest.raises(ValueError) as excinfo:
create_fulltext_index(driver, "my-index", label, node_properties)

assert "Error for inputs to create_fulltext_index" in str(excinfo)


def test_create_fulltext_index_ensure_escaping(driver, client):
label = "node-label"
text_node_properties = ["property-1", "property-2"]
create_query = (
"CREATE FULLTEXT INDEX $name"
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]"
)

create_fulltext_index(driver, "my-complicated-`-index", label, text_node_properties)

driver.execute_query.assert_called_once_with(
create_query,
{"name": "my-complicated-`-index"},
)

0 comments on commit 20fb738

Please sign in to comment.