Skip to content

Commit

Permalink
Changed test to mock neo4j driver's execute_driver instead of the cli…
Browse files Browse the repository at this point in the history
…ent's _database_query()

Changed test_client.py to have mocks at execute_query level

Changed precommit config to include ruff linting and formatting

Update pre-commit
  • Loading branch information
willtai committed Mar 13, 2024
1 parent f134f7d commit a262c48
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 101 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ repos:
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.2
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-format
3 changes: 2 additions & 1 deletion examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def embed_query(self, text: str) -> List[float]:
# Initialize the client
client = GenAIClient(driver, embedder)

client.drop_index(INDEX_NAME)
# Creating the index
client.create_index(
INDEX_NAME,
Expand All @@ -46,7 +47,7 @@ def embed_query(self, text: str) -> List[float]:
parameters = {
"vector": vector,
}
client.database_query(insert_query, params=parameters)
client._database_query(insert_query, params=parameters)

# Perform the similarity search for a text query
query_text = "hello world"
Expand Down
2 changes: 1 addition & 1 deletion examples/similarity_search_for_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
parameters = {
"vector": vector,
}
client.database_query(insert_query, params=parameters)
client._database_query(insert_query, params=parameters)

# Perform the similarity search for a vector query
query_vector = [random() for _ in range(DIMENSION)]
Expand Down
23 changes: 12 additions & 11 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _verify_version(self) -> None:
indexing. Raises a ValueError if the connected Neo4j version is
not supported.
"""
version = self.database_query("CALL dbms.components()")[0]["versions"][0]
version = self._database_query("CALL dbms.components()")[0]["versions"][0]
if "aura" in version:
version_tuple = (
*tuple(map(int, version.split("-")[0].split("."))),
Expand All @@ -45,7 +45,9 @@ def _verify_version(self) -> None:
"Version index is only supported in Neo4j version 5.11 or greater"
)

def database_query(self, query: str, params: Dict = {}) -> List[Dict[str, Any]]:
def _database_query(
self, query: str, params: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
This method sends a Cypher query to the connected Neo4j database
and returns the results as a list of dictionaries.
Expand All @@ -57,12 +59,11 @@ def database_query(self, query: str, params: Dict = {}) -> List[Dict[str, Any]]:
Returns:
List[Dict[str, Any]]: List of dictionaries containing the query results.
"""
with self.driver.session() as session:
try:
data = session.run(query, params)
return [r.data() for r in data]
except CypherSyntaxError as e:
raise ValueError(f"Cypher Statement is not valid\n{e}")
try:
records, _, _ = self.driver.execute_query(query, params)
return records
except CypherSyntaxError as e:
raise ValueError(f"Cypher Statement is not valid\n{e}")

def create_index(
self,
Expand Down Expand Up @@ -109,7 +110,7 @@ def create_index(
"toInteger($dimensions),"
"$similarity_fn )"
)
self.database_query(query, params=index_data.model_dump())
self._database_query(query, params=index_data.model_dump())

def drop_index(self, name: str) -> None:
"""
Expand All @@ -124,7 +125,7 @@ def drop_index(self, name: str) -> None:
parameters = {
"name": name,
}
self.database_query(query, params=parameters)
self._database_query(query, params=parameters)

def similarity_search(
self,
Expand Down Expand Up @@ -176,7 +177,7 @@ def similarity_search(
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
records = self.database_query(db_query_string, params=parameters)
records = self._database_query(db_query_string, params=parameters)

try:
return [
Expand Down
1 change: 1 addition & 0 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class CreateIndexModel(BaseModel):
name: str
label: str
property: str
# TODO: consider changing this to positive integer
dimensions: int = Field(ge=1, le=2048)
similarity_fn: Literal["euclidean", "cosine"]

Expand Down
14 changes: 10 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import pytest
from neo4j_genai import GenAIClient
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, patch
from typing import List
from neo4j_genai.embedder import Embedder


@pytest.fixture
def driver():
return Mock()
return MagicMock()


@pytest.fixture
Expand All @@ -20,8 +20,14 @@ def client(_verify_version_mock, driver):
@patch("neo4j_genai.GenAIClient._verify_version")
def client_with_embedder(_verify_version_mock, driver):
class CustomEmbedder(Embedder):
def __init__(self):
self.dimension = 1536

def embed_query(self, text: str) -> List[float]:
return [1.0 for _ in range(1536)]
return [1.0 for _ in range(self.dimension)]

def set_dimension(self, dimension: int):
self.dimension = dimension

embedder = CustomEmbedder()
return GenAIClient(driver, embedder)
return GenAIClient(driver, embedder), embedder
152 changes: 68 additions & 84 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,50 @@
import pytest
from neo4j_genai import GenAIClient
from unittest.mock import Mock, patch
from neo4j.exceptions import CypherSyntaxError


@patch(
"neo4j_genai.GenAIClient.database_query",
return_value=[{"versions": ["5.11-aura"]}],
)
def test_genai_client_supported_aura_version(mock_database_query, driver):
GenAIClient(driver)
mock_database_query.assert_called_once()
def test_genai_client_supported_aura_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.11-aura"]}], None, None]

GenAIClient(driver=driver)


@patch(
"neo4j_genai.GenAIClient.database_query",
return_value=[{"versions": ["5.3-aura"]}],
)
def test_genai_client_no_supported_aura_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.3-aura"]}], None, None]

with pytest.raises(ValueError):
GenAIClient(driver)
GenAIClient(driver=driver)


def test_genai_client_supported_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.11.5"]}], None, None]

@patch(
"neo4j_genai.GenAIClient.database_query",
return_value=[{"versions": ["5.11.5"]}],
)
def test_genai_client_supported_version(mock_database_query, driver):
GenAIClient(driver)
mock_database_query.assert_called_once()
GenAIClient(driver=driver)


@patch(
"neo4j_genai.GenAIClient.database_query",
return_value=[{"versions": ["4.3.5"]}],
)
def test_genai_client_no_supported_version(driver):
driver.execute_query.return_value = [[{"versions": ["4.3.5"]}], None, None]

with pytest.raises(ValueError):
GenAIClient(driver)
GenAIClient(driver=driver)


@patch("neo4j_genai.GenAIClient.database_query")
def test_create_index_happy_path(mock_database_query, client):
client.create_index("my-index", "People", "name", 2048, "cosine")
query = (
def test_create_index_happy_path(driver, client):
driver.execute_query.return_value = [None, None, None]
create_query = (
"CALL db.index.vector.createNodeIndex("
"$name,"
"$label,"
"$property,"
"toInteger($dimensions),"
"$similarity_fn )"
)
mock_database_query.assert_called_once_with(
query,
params={

client.create_index("my-index", "People", "name", 2048, "cosine")

driver.execute_query.assert_called_once_with(
create_query,
{
"name": "my-index",
"label": "People",
"property": "name",
Expand All @@ -63,11 +54,6 @@ def test_create_index_happy_path(mock_database_query, client):
)


def test_create_index_too_big_dimension(client):
with pytest.raises(ValueError):
client.create_index("my-index", "People", "name", 5024, "cosine")


def test_create_index_validation_error_dimensions(client):
with pytest.raises(ValueError) as excinfo:
client.create_index("my-index", "People", "name", "no-dim", "cosine")
Expand All @@ -80,35 +66,23 @@ def test_create_index_validation_error_similarity_fn(client):
assert "Error for inputs to create_index" in str(excinfo)


@patch("neo4j_genai.GenAIClient.database_query")
def test_drop_index(mock_database_query, client):
client.drop_index("my-index")
def test_drop_index(driver, client):
driver.execute_query.return_value = [None, None, None]
drop_query = "DROP INDEX $name"

query = "DROP INDEX $name"
client.drop_index("my-index")

mock_database_query.assert_called_with(query, params={"name": "my-index"})
driver.execute_query.assert_called_once_with(
drop_query,
{"name": "my-index"},
)


def test_database_query_happy(client, driver):
class Session:
def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
pass

def run(self, query, params):
m_list = []
for i in range(3):
mock = Mock()
mock.data.return_value = i
m_list.append(mock)

return m_list

driver.session = Session
res = client.database_query("MATCH (p:$label) RETURN p", {"label": "People"})
assert res == [0, 1, 2]
expected_db_result = [0, 1, 2]
driver.execute_query.return_value = [expected_db_result, None, None]
res = client._database_query("MATCH (p:$label) RETURN p", {"label": "People"})
assert res == expected_db_result


def test_database_query_cypher_error(client, driver):
Expand All @@ -125,49 +99,59 @@ def run(self, query, params):
driver.session = Session

with pytest.raises(ValueError):
client.database_query("MATCH (p:$label) RETURN p", {"label": "People"})
client._database_query("MATCH (p:$label) RETURN p", {"label": "People"})


@patch("neo4j_genai.GenAIClient.database_query")
def test_similarity_search_vector_happy_path(mock_database_query, client):
def test_similarity_search_vector_happy_path(driver, client):
index_name = "my-index"
query_vector = [1.1, 2.2, 3.3]
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5

client.similarity_search(name=index_name, query_vector=query_vector, top_k=top_k)

query = """
driver.execute_query.return_value = [
[{"node": "dummy-node", "score": 1.0}],
None,
None,
]
search_query = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
mock_database_query.assert_called_once_with(
query,
params={

client.similarity_search(name=index_name, query_vector=query_vector, top_k=top_k)

driver.execute_query.assert_called_once_with(
search_query,
{
"index_name": index_name,
"top_k": top_k,
"query_vector": query_vector,
},
)


@patch("neo4j_genai.GenAIClient.database_query")
def test_similarity_search_text_happy_path(mock_database_query, client_with_embedder):
def test_similarity_search_text_happy_path(driver, client_with_embedder):
client, embedder = client_with_embedder
index_name = "my-index"
query_text = "may thy knife chip and shatter"
query_vector = [1.0 for _ in range(1536)]
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5

client_with_embedder.similarity_search(
name=index_name, query_text=query_text, top_k=top_k
)

query = """
driver.execute_query.return_value = [
[{"node": "dummy-node", "score": 1.0}],
None,
None,
]
embedder.set_dimension(dimensions)
search_query = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
mock_database_query.assert_called_once_with(
query,
params={

client.similarity_search(name=index_name, query_text=query_text, top_k=top_k)

driver.execute_query.assert_called_once_with(
search_query,
{
"index_name": index_name,
"top_k": top_k,
"query_vector": query_vector,
Expand Down

0 comments on commit a262c48

Please sign in to comment.