Skip to content

Commit

Permalink
Added VectorRetriever class and remove GenAIClient
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Apr 4, 2024
1 parent 20fb738 commit 32b967d
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 117 deletions.
12 changes: 7 additions & 5 deletions examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import List
from neo4j import GraphDatabase
from neo4j_genai.client import GenAIClient
from neo4j_genai import VectorRetriever

from random import random
from neo4j_genai.embedder import Embedder
from neo4j_genai.indexes import create_vector_index

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
Expand All @@ -23,11 +24,12 @@ def embed_query(self, text: str) -> List[float]:

embedder = CustomEmbedder()

# Initialize the client
client = GenAIClient(driver, embedder)
# Initialize the retriever
retriever = VectorRetriever(driver, embedder)

# Creating the index
client.create_index(
create_vector_index(
driver,
INDEX_NAME,
label="Document",
property="propertyKey",
Expand All @@ -50,4 +52,4 @@ def embed_query(self, text: str) -> List[float]:

# Perform the similarity search for a text query
query_text = "hello world"
print(client.similarity_search(INDEX_NAME, query_text=query_text, top_k=5))
print(retriever.search(INDEX_NAME, query_text=query_text, top_k=5))
13 changes: 8 additions & 5 deletions examples/similarity_search_for_vector.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from neo4j import GraphDatabase
from neo4j_genai.client import GenAIClient
from neo4j_genai import VectorRetriever

from random import random

from neo4j_genai.indexes import create_vector_index

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")

Expand All @@ -12,11 +14,12 @@
# Connect to Neo4j database
driver = GraphDatabase.driver(URI, auth=AUTH)

# Initialize the client
client = GenAIClient(driver)
# Initialize the retriever
retriever = VectorRetriever(driver)

# Creating the index
client.create_index(
create_vector_index(
driver,
INDEX_NAME,
label="Document",
property="propertyKey",
Expand All @@ -40,4 +43,4 @@

# Perform the similarity search for a vector query
query_vector = [random() for _ in range(DIMENSION)]
print(client.similarity_search(INDEX_NAME, query_vector=query_vector, top_k=5))
print(retriever.search(INDEX_NAME, query_vector=query_vector, top_k=5))
4 changes: 2 additions & 2 deletions src/neo4j_genai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .client import GenAIClient
from .retrievers import VectorRetriever


__all__ = ["GenAIClient"]
__all__ = ["VectorRetriever"]
70 changes: 4 additions & 66 deletions src/neo4j_genai/client.py → src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from pydantic import ValidationError
from neo4j import Driver
from .embedder import Embedder
from .types import VectorIndexModel, SimilaritySearchModel, Neo4jRecord
from .types import SimilaritySearchModel, Neo4jRecord


class GenAIClient:
class VectorRetriever:
"""
Provides functionality to use Neo4j's GenAI features
Provides retrieval methods using vector search over embeddings
"""

def __init__(
Expand Down Expand Up @@ -46,69 +46,7 @@ def _verify_version(self) -> None:
"This package only supports Neo4j version 5.18.1 or greater"
)

def create_index(
self,
name: str,
label: str,
property: str,
dimensions: int,
similarity_fn: str,
) -> None:
"""
This method constructs a Cypher query and executes it
to create a new vector index in Neo4j.
See Cypher manual on [Create node index](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_createNodeIndex)
Args:
name (str): The unique name of the index.
label (str): The node label to be indexed.
property (str): The property key of a node which contains embedding values.
dimensions (int): Vector embedding dimension
similarity_fn (str): case-insensitive values for the vector similarity function:
``euclidean`` or ``cosine``.
Raises:
ValueError: If validation of the input arguments fail.
"""
try:
VectorIndexModel(
**{
"name": name,
"label": label,
"property": property,
"dimensions": dimensions,
"similarity_fn": similarity_fn,
}
)
except ValidationError as e:
raise ValueError(f"Error for inputs to create_index {str(e)}")

query = (
f"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:{label}) ON n.{property} OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
)
self.driver.execute_query(
query,
{"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn},
)

def drop_index(self, name: str) -> None:
"""
This method constructs a Cypher query and executes it
to drop a vector index in Neo4j.
See Cypher manual on [Drop vector indexes](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-drop)
Args:
name (str): The name of the index to delete.
"""
query = "DROP INDEX $name"
parameters = {
"name": name,
}
self.driver.execute_query(query, parameters)

def similarity_search(
def search(
self,
name: str,
query_vector: Optional[List[float]] = None,
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from neo4j_genai import GenAIClient
from neo4j_genai import VectorRetriever
from neo4j import Driver
from unittest.mock import MagicMock, patch

Expand All @@ -10,6 +10,6 @@ def driver():


@pytest.fixture
@patch("neo4j_genai.GenAIClient._verify_version")
def client(_verify_version_mock, driver):
return GenAIClient(driver)
@patch("neo4j_genai.VectorRetriever._verify_version")
def retriever(_verify_version_mock, driver):
return VectorRetriever(driver)
4 changes: 2 additions & 2 deletions tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_create_vector_index_happy_path(driver):
)


def test_create_vector_index_ensure_escaping(driver, client):
def test_create_vector_index_ensure_escaping(driver):
create_query = (
"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_create_fulltext_index_empty_node_properties(driver):
assert "Error for inputs to create_fulltext_index" in str(excinfo)


def test_create_fulltext_index_ensure_escaping(driver, client):
def test_create_fulltext_index_ensure_escaping(driver):
label = "node-label"
text_node_properties = ["property-1", "property-2"]
create_query = (
Expand Down
60 changes: 27 additions & 33 deletions tests/test_client.py → tests/test_retrievers.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,51 @@
import pytest
from neo4j_genai import GenAIClient
from neo4j_genai.types import Neo4jRecord
from unittest.mock import patch, MagicMock
from neo4j_genai import VectorRetriever
from neo4j_genai.types import Neo4jRecord


def test_genai_client_supported_aura_version(driver):
def test_vector_retriever_supported_aura_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.18-aura"]}], None, None]

GenAIClient(driver=driver)
VectorRetriever(driver=driver)


def test_genai_client_no_supported_aura_version(driver):
def test_vector_retriever_no_supported_aura_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.3-aura"]}], None, None]

with pytest.raises(ValueError) as excinfo:
GenAIClient(driver=driver)
VectorRetriever(driver=driver)

assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo)


def test_genai_client_supported_version(driver):
def test_vector_retriever_supported_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.19.0"]}], None, None]

GenAIClient(driver=driver)
VectorRetriever(driver=driver)


def test_genai_client_no_supported_version(driver):
def test_vector_retriever_no_supported_version(driver):
driver.execute_query.return_value = [[{"versions": ["4.3.5"]}], None, None]

with pytest.raises(ValueError) as excinfo:
GenAIClient(driver=driver)
VectorRetriever(driver=driver)

assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo)


@patch("neo4j_genai.GenAIClient._verify_version")
@patch("neo4j_genai.VectorRetriever._verify_version")
def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
custom_embeddings = MagicMock()

client = GenAIClient(driver, custom_embeddings)
retriever = VectorRetriever(driver, custom_embeddings)

index_name = "my-index"
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5

client.driver.execute_query.return_value = [
retriever.driver.execute_query.return_value = [
[{"node": "dummy-node", "score": 1.0}],
None,
None,
Expand All @@ -55,13 +55,11 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
YIELD node, score
"""

records = client.similarity_search(
name=index_name, query_vector=query_vector, top_k=top_k
)
records = retriever.search(name=index_name, query_vector=query_vector, top_k=top_k)

custom_embeddings.embed_query.assert_not_called()

client.driver.execute_query.assert_called_once_with(
retriever.driver.execute_query.assert_called_once_with(
search_query,
{
"index_name": index_name,
Expand All @@ -73,13 +71,13 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
assert records == [Neo4jRecord(node="dummy-node", score=1.0)]


@patch("neo4j_genai.GenAIClient._verify_version")
@patch("neo4j_genai.VectorRetriever._verify_version")
def test_similarity_search_text_happy_path(_verify_version_mock, driver):
embed_query_vector = [1.0 for _ in range(1536)]
custom_embeddings = MagicMock()
custom_embeddings.embed_query.return_value = embed_query_vector

client = GenAIClient(driver, custom_embeddings)
retriever = VectorRetriever(driver, custom_embeddings)

index_name = "my-index"
query_text = "may thy knife chip and shatter"
Expand All @@ -96,9 +94,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
YIELD node, score
"""

records = client.similarity_search(
name=index_name, query_text=query_text, top_k=top_k
)
records = retriever.search(name=index_name, query_text=query_text, top_k=top_k)

custom_embeddings.embed_query.assert_called_once_with(query_text)

Expand All @@ -114,16 +110,16 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
assert records == [Neo4jRecord(node="dummy-node", score=1.0)]


def test_similarity_search_missing_embedder_for_text(client):
def test_similarity_search_missing_embedder_for_text(retriever):
index_name = "my-index"
query_text = "may thy knife chip and shatter"
top_k = 5

with pytest.raises(ValueError, match="Embedding method required for text query"):
client.similarity_search(name=index_name, query_text=query_text, top_k=top_k)
retriever.search(name=index_name, query_text=query_text, top_k=top_k)


def test_similarity_search_both_text_and_vector(client):
def test_similarity_search_both_text_and_vector(retriever):
index_name = "my-index"
query_text = "may thy knife chip and shatter"
query_vector = [1.1, 2.2, 3.3]
Expand All @@ -132,26 +128,26 @@ def test_similarity_search_both_text_and_vector(client):
with pytest.raises(
ValueError, match="You must provide exactly one of query_vector or query_text."
):
client.similarity_search(
retriever.search(
name=index_name,
query_text=query_text,
query_vector=query_vector,
top_k=top_k,
)


@patch("neo4j_genai.GenAIClient._verify_version")
@patch("neo4j_genai.VectorRetriever._verify_version")
def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
custom_embeddings = MagicMock()

client = GenAIClient(driver, custom_embeddings)
retriever = VectorRetriever(driver, custom_embeddings)

index_name = "my-index"
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5

client.driver.execute_query.return_value = [
retriever.driver.execute_query.return_value = [
[{"node": "dummy-node", "score": "adsa"}],
None,
None,
Expand All @@ -162,13 +158,11 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
"""

with pytest.raises(ValueError):
client.similarity_search(
name=index_name, query_vector=query_vector, top_k=top_k
)
retriever.search(name=index_name, query_vector=query_vector, top_k=top_k)

custom_embeddings.embed_query.assert_not_called()

client.driver.execute_query.assert_called_once_with(
retriever.driver.execute_query.assert_called_once_with(
search_query,
{
"index_name": index_name,
Expand Down

0 comments on commit 32b967d

Please sign in to comment.