-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
450 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
root = true | ||
|
||
[*] | ||
indent_style = space | ||
indent_size = 4 | ||
insert_final_newline = true | ||
trim_trailing_whitespace = true | ||
end_of_line = lf | ||
charset = utf-8 | ||
max_line_length = 88 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
name: neo4j_genai PR | ||
on: pull_request | ||
|
||
jobs: | ||
test: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Check out repository code | ||
uses: actions/checkout@v4 | ||
- name: Install Poetry | ||
uses: snok/install-poetry@v1 | ||
with: | ||
virtualenvs-create: true | ||
virtualenvs-in-project: true | ||
installer-parallel: true | ||
- name: Load cached venv | ||
id: cached-poetry-dependencies | ||
uses: actions/cache@v4 | ||
with: | ||
path: .venv | ||
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} | ||
- name: Install dependencies | ||
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' | ||
run: poetry install --no-interaction --no-root | ||
- name: Install root project | ||
run: poetry install --no-interaction | ||
- name: Run tests and check coverage | ||
run: | | ||
poetry run coverage run -m pytest | ||
poetry run coverage report --fail-under=85 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
dist/ | ||
**/__pycache__/* | ||
*.py[cod] | ||
.mypy_cache/ | ||
.mypy_cache/ | ||
.coverage | ||
htmlcov/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
repos: | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v4.5.0 # Use the ref you want to point at | ||
rev: v4.5.0 | ||
hooks: | ||
- id: trailing-whitespace | ||
# - id: ... | ||
- id: end-of-file-fixer |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .client import GenAIClient |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import pytest | ||
from neo4j_genai import GenAIClient | ||
from unittest.mock import Mock, patch | ||
|
||
|
||
@pytest.fixture | ||
def driver(): | ||
return Mock() | ||
|
||
|
||
@pytest.fixture | ||
@patch("neo4j_genai.GenAIClient._verify_version") | ||
def client(_verify_version_mock, driver): | ||
return GenAIClient(driver) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import pytest | ||
from neo4j_genai import GenAIClient | ||
from unittest.mock import Mock, patch | ||
from neo4j.exceptions import CypherSyntaxError | ||
|
||
|
||
@patch( | ||
"neo4j_genai.GenAIClient.database_query", return_value=[{"versions": ["5.11-aura"]}] | ||
) | ||
def test_genai_client_supported_aura_version(mock_database_query, driver): | ||
GenAIClient(driver) | ||
mock_database_query.assert_called_once() | ||
|
||
|
||
@patch( | ||
"neo4j_genai.GenAIClient.database_query", return_value=[{"versions": ["5.3-aura"]}] | ||
) | ||
def test_genai_client_no_supported_aura_version(driver): | ||
with pytest.raises(ValueError): | ||
GenAIClient(driver) | ||
|
||
|
||
@patch( | ||
"neo4j_genai.GenAIClient.database_query", return_value=[{"versions": ["5.11.5"]}] | ||
) | ||
def test_genai_client_supported_version(mock_database_query, driver): | ||
GenAIClient(driver) | ||
mock_database_query.assert_called_once() | ||
|
||
|
||
@patch("neo4j_genai.GenAIClient.database_query", return_value=[{"versions": ["4.3.5"]}]) | ||
def test_genai_client_no_supported_version(driver): | ||
with pytest.raises(ValueError): | ||
GenAIClient(driver) | ||
|
||
|
||
@patch("neo4j_genai.GenAIClient.database_query") | ||
def test_create_index_happy_path(mock_database_query, client): | ||
client.create_index("my-index", "People", "name", 1024, "cosine") | ||
query = ( | ||
"CALL db.index.vector.createNodeIndex(" | ||
"$name," | ||
"$label," | ||
"$property," | ||
"toInteger($dimensions)," | ||
"$similarity_fn )" | ||
) | ||
mock_database_query.assert_called_once_with( | ||
query, | ||
params={ | ||
"name": "my-index", | ||
"label": "People", | ||
"property": "name", | ||
"dimensions": 1024, | ||
"similarity_fn": "cosine", | ||
}, | ||
) | ||
|
||
|
||
def test_create_index_validation_error_dimensions(client): | ||
with pytest.raises(ValueError) as excinfo: | ||
client.create_index("my-index", "People", "name", "no-dim", "cosine") | ||
assert "Error for inputs to create_index" in str(excinfo) | ||
|
||
|
||
def test_create_index_validation_error_similarity_fn(client): | ||
with pytest.raises(ValueError) as excinfo: | ||
client.create_index("my-index", "People", "name", "no-dim", "algebra") | ||
assert "Error for inputs to create_index" in str(excinfo) | ||
|
||
|
||
@patch("neo4j_genai.GenAIClient.database_query") | ||
def test_drop_index(mock_database_query, client): | ||
client.drop_index("my-index") | ||
|
||
query = "DROP INDEX $name" | ||
|
||
mock_database_query.assert_called_with(query, params={"name": "my-index"}) | ||
|
||
|
||
def test_database_query_happy(client, driver): | ||
class Session: | ||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
pass | ||
|
||
def run(self, query, params): | ||
m_list = [] | ||
for i in range(3): | ||
mock = Mock() | ||
mock.data.return_value = i | ||
m_list.append(mock) | ||
|
||
return m_list | ||
|
||
driver.session = Session | ||
res = client.database_query("MATCH (p:$label) RETURN p", {"label": "People"}) | ||
assert res == [0, 1, 2] | ||
|
||
|
||
def test_database_query_cypher_error(client, driver): | ||
class Session: | ||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
pass | ||
|
||
def run(self, query, params): | ||
raise CypherSyntaxError | ||
|
||
driver.session = Session | ||
|
||
with pytest.raises(ValueError): | ||
client.database_query("MATCH (p:$label) RETURN p", {"label": "People"}) | ||
|
||
|
||
def test_similarity_search(): | ||
pass |