Skip to content

Commit

Permalink
Moved index_name definition to constructor level of VectorRetriever
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Apr 8, 2024
1 parent 6e4b3fe commit 6d7d6ec
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 31 deletions.
4 changes: 2 additions & 2 deletions examples/openai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
embedder = OpenAIEmbeddings(model="text-embedding-3-large")

# Initialize the retriever
retriever = VectorRetriever(driver, embedder)
retriever = VectorRetriever(driver, INDEX_NAME, embedder)

# Creating the index
create_vector_index(
Expand All @@ -47,4 +47,4 @@

# Perform the similarity search for a text query
query_text = "hello world"
print(retriever.search(INDEX_NAME, query_text=query_text, top_k=5))
print(retriever.search(query_text=query_text, top_k=5))
8 changes: 4 additions & 4 deletions examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ def embed_query(self, text: str) -> List[float]:

embedder = CustomEmbedder()

# Initialize the retriever
retriever = VectorRetriever(driver, embedder)

# Creating the index
create_vector_index(
driver,
Expand All @@ -37,6 +34,9 @@ def embed_query(self, text: str) -> List[float]:
similarity_fn="euclidean",
)

# Initialize the retriever
retriever = VectorRetriever(driver, INDEX_NAME, embedder)

# Upsert the query
vector = [random() for _ in range(DIMENSION)]
insert_query = (
Expand All @@ -52,4 +52,4 @@ def embed_query(self, text: str) -> List[float]:

# Perform the similarity search for a text query
query_text = "hello world"
print(retriever.search(INDEX_NAME, query_text=query_text, top_k=5))
print(retriever.search(query_text=query_text, top_k=5))
8 changes: 4 additions & 4 deletions examples/similarity_search_for_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
# Connect to Neo4j database
driver = GraphDatabase.driver(URI, auth=AUTH)

# Initialize the retriever
retriever = VectorRetriever(driver)

# Creating the index
create_vector_index(
driver,
Expand All @@ -27,6 +24,9 @@
similarity_fn="euclidean",
)

# Initialize the retriever
retriever = VectorRetriever(driver, INDEX_NAME)

# Upsert the vector
vector = [random() for _ in range(DIMENSION)]
insert_query = (
Expand All @@ -43,4 +43,4 @@

# Perform the similarity search for a vector query
query_vector = [random() for _ in range(DIMENSION)]
print(retriever.search(INDEX_NAME, query_vector=query_vector, top_k=5))
print(retriever.search(query_vector=query_vector, top_k=5))
2 changes: 1 addition & 1 deletion src/neo4j_genai/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def create_vector_index(
raise ValueError(f"Error for inputs to create_vector_index {str(e)}")

query = (
f"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:{label}) ON n.{property} OPTIONS "
f"CREATE VECTOR INDEX $name FOR (n:{label}) ON n.{property} OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
)
driver.execute_query(
Expand Down
5 changes: 3 additions & 2 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ class VectorRetriever:
def __init__(
self,
driver: Driver,
index_name: str,
embedder: Optional[Embedder] = None,
) -> None:
self.driver = driver
self._verify_version()
self.index_name = index_name
self.embedder = embedder

def _verify_version(self) -> None:
Expand Down Expand Up @@ -48,7 +50,6 @@ def _verify_version(self) -> None:

def search(
self,
name: str,
query_vector: Optional[List[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
Expand All @@ -74,7 +75,7 @@ def search(
"""
try:
validated_data = SimilaritySearchModel(
index_name=name,
index_name=self.index_name,
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ def driver():
@pytest.fixture
@patch("neo4j_genai.VectorRetriever._verify_version")
def retriever(_verify_version_mock, driver):
return VectorRetriever(driver)
return VectorRetriever(driver, "my-index")
31 changes: 14 additions & 17 deletions tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,29 @@
def test_vector_retriever_supported_aura_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.18-aura"]}], None, None]

VectorRetriever(driver=driver)
VectorRetriever(driver=driver, index_name="my-index")


def test_vector_retriever_no_supported_aura_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.3-aura"]}], None, None]

with pytest.raises(ValueError) as excinfo:
VectorRetriever(driver=driver)
VectorRetriever(driver=driver, index_name="my-index")

assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo)


def test_vector_retriever_supported_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.19.0"]}], None, None]

VectorRetriever(driver=driver)
VectorRetriever(driver=driver, index_name="my-index")


def test_vector_retriever_no_supported_version(driver):
driver.execute_query.return_value = [[{"versions": ["4.3.5"]}], None, None]

with pytest.raises(ValueError) as excinfo:
VectorRetriever(driver=driver)
VectorRetriever(driver=driver, index_name="my-index")

assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo)

Expand All @@ -38,13 +38,13 @@ def test_vector_retriever_no_supported_version(driver):
def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
custom_embeddings = MagicMock()

retriever = VectorRetriever(driver, custom_embeddings)

index_name = "my-index"
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5

retriever = VectorRetriever(driver, index_name, custom_embeddings)

retriever.driver.execute_query.return_value = [
[{"node": "dummy-node", "score": 1.0}],
None,
Expand All @@ -55,7 +55,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
YIELD node, score
"""

records = retriever.search(name=index_name, query_vector=query_vector, top_k=top_k)
records = retriever.search(query_vector=query_vector, top_k=top_k)

custom_embeddings.embed_query.assert_not_called()

Expand All @@ -77,12 +77,12 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
custom_embeddings = MagicMock()
custom_embeddings.embed_query.return_value = embed_query_vector

retriever = VectorRetriever(driver, custom_embeddings)

index_name = "my-index"
query_text = "may thy knife chip and shatter"
top_k = 5

retriever = VectorRetriever(driver, index_name, custom_embeddings)

driver.execute_query.return_value = [
[{"node": "dummy-node", "score": 1.0}],
None,
Expand All @@ -94,7 +94,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
YIELD node, score
"""

records = retriever.search(name=index_name, query_text=query_text, top_k=top_k)
records = retriever.search(query_text=query_text, top_k=top_k)

custom_embeddings.embed_query.assert_called_once_with(query_text)

Expand All @@ -111,16 +111,14 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):


def test_similarity_search_missing_embedder_for_text(retriever):
index_name = "my-index"
query_text = "may thy knife chip and shatter"
top_k = 5

with pytest.raises(ValueError, match="Embedding method required for text query"):
retriever.search(name=index_name, query_text=query_text, top_k=top_k)
retriever.search(query_text=query_text, top_k=top_k)


def test_similarity_search_both_text_and_vector(retriever):
index_name = "my-index"
query_text = "may thy knife chip and shatter"
query_vector = [1.1, 2.2, 3.3]
top_k = 5
Expand All @@ -129,7 +127,6 @@ def test_similarity_search_both_text_and_vector(retriever):
ValueError, match="You must provide exactly one of query_vector or query_text."
):
retriever.search(
name=index_name,
query_text=query_text,
query_vector=query_vector,
top_k=top_k,
Expand All @@ -140,13 +137,13 @@ def test_similarity_search_both_text_and_vector(retriever):
def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
custom_embeddings = MagicMock()

retriever = VectorRetriever(driver, custom_embeddings)

index_name = "my-index"
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5

retriever = VectorRetriever(driver, index_name, custom_embeddings)

retriever.driver.execute_query.return_value = [
[{"node": "dummy-node", "score": "adsa"}],
None,
Expand All @@ -158,7 +155,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
"""

with pytest.raises(ValueError):
retriever.search(name=index_name, query_vector=query_vector, top_k=top_k)
retriever.search(query_vector=query_vector, top_k=top_k)

custom_embeddings.embed_query.assert_not_called()

Expand Down

0 comments on commit 6d7d6ec

Please sign in to comment.