Skip to content

Commit

Permalink
Update Github Actions
Browse files Browse the repository at this point in the history
  • Loading branch information
jonbesga committed Mar 4, 2024
1 parent 7a0b30a commit 4b6fb94
Show file tree
Hide file tree
Showing 12 changed files with 450 additions and 20 deletions.
10 changes: 10 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
root = true

[*]
indent_style = space
indent_size = 4
insert_final_newline = true
trim_trailing_whitespace = true
end_of_line = lf
charset = utf-8
max_line_length = 88
30 changes: 30 additions & 0 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: neo4j_genai PR
on: pull_request

jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Install Poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v4
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-root
- name: Install root project
run: poetry install --no-interaction
- name: Run tests and check coverage
run: |
poetry run coverage run -m pytest
poetry run coverage report --fail-under=85
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
dist/
**/__pycache__/*
*.py[cod]
.mypy_cache/
.mypy_cache/
.coverage
htmlcov/
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0 # Use the ref you want to point at
rev: v4.5.0
hooks:
- id: trailing-whitespace
# - id: ...
- id: end-of-file-fixer
253 changes: 247 additions & 6 deletions poetry.lock

Large diffs are not rendered by default.

18 changes: 15 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,26 @@ include = "neo4j_genai"
from = "src"

[tool.poetry.dependencies]
python = "^3.8"
python = "^3.9"
neo4j = "^5.17.0"
types-requests = "^2.31.0.20240218"
pytest = "^8.0.2"
pytest-mock = "^3.12.0"
pydantic = "^2.6.3"

[tool.poetry.group.dev.dependencies]
pylint = "^3.1.0"
mypy = "^1.8.0"
black = "^24.2.0"
pytest = "^8.0.2"
pytest-mock = "^3.12.0"
pre-commit = "^3.6.2"
coverage = "^7.4.3"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

exclude = ["**/tests/"]

[tool.black]
line-length = 88
target-version = ['py38']
Expand All @@ -40,3 +43,12 @@ exclude = '''
| dist
)/
'''

[tool.pytest.ini_options]
testpaths = ["tests"]
filterwarnings = [
"",
]

[tool.coverage.paths]
source = ["src"]
1 change: 1 addition & 0 deletions src/neo4j_genai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .client import GenAIClient
8 changes: 4 additions & 4 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List, Dict, Any, Optional
from neo4j import Driver
from neo4j.exceptions import CypherSyntaxError
from neo4j_genai.embeddings import Embeddings
from neo4j_genai.types import CreateIndexModel, SimilaritySearchModel
from .embeddings import Embeddings
from .types import CreateIndexModel, SimilaritySearchModel
from pydantic import ValidationError


Expand Down Expand Up @@ -86,7 +86,7 @@ def create_index(
"toInteger($dimensions),"
"$similarity_fn )"
)
self.database_query(query, params=index_data.dict())
self.database_query(query, params=index_data.model_dump())

def drop_index(self, name: str) -> None:
"""
Expand Down Expand Up @@ -125,6 +125,6 @@ def similarity_search(
raise ValueError("Embedding method required for text query.")
query_vector = self.embeddings.embed_query(query_text)

parameters = validated_data.dict(exclude_none=True)
parameters = validated_data.model_dump(exclude_none=True)
db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score"
return self.database_query(db_query_string, params=parameters)
3 changes: 1 addition & 2 deletions src/neo4j_genai/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from typing import List
from neo4j_genai.types import EmbeddingVector
from .types import EmbeddingVector


class Embeddings(ABC):
Expand Down
4 changes: 2 additions & 2 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, PositiveInt, root_validator
from pydantic import BaseModel, PositiveInt, model_validator
from typing import List, Literal, Optional


Expand Down Expand Up @@ -26,7 +26,7 @@ class SimilaritySearchModel(BaseModel):
query_vector: Optional[EmbeddingVector] = None
query_text: Optional[str] = None

@root_validator(pre=True)
@model_validator(mode="before")
def check_query(cls, values):
query_vector, query_text = values.get("query_vector"), values.get("query_text")
if bool(query_vector) ^ bool(query_text):
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
from neo4j_genai import GenAIClient
from unittest.mock import Mock, patch


@pytest.fixture
def driver():
return Mock()


@pytest.fixture
@patch("neo4j_genai.GenAIClient._verify_version")
def client(_verify_version_mock, driver):
return GenAIClient(driver)
121 changes: 121 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pytest
from neo4j_genai import GenAIClient
from unittest.mock import Mock, patch
from neo4j.exceptions import CypherSyntaxError


@patch(
"neo4j_genai.GenAIClient.database_query", return_value=[{"versions": ["5.11-aura"]}]
)
def test_genai_client_supported_aura_version(mock_database_query, driver):
GenAIClient(driver)
mock_database_query.assert_called_once()


@patch(
"neo4j_genai.GenAIClient.database_query", return_value=[{"versions": ["5.3-aura"]}]
)
def test_genai_client_no_supported_aura_version(driver):
with pytest.raises(ValueError):
GenAIClient(driver)


@patch(
"neo4j_genai.GenAIClient.database_query", return_value=[{"versions": ["5.11.5"]}]
)
def test_genai_client_supported_version(mock_database_query, driver):
GenAIClient(driver)
mock_database_query.assert_called_once()


@patch("neo4j_genai.GenAIClient.database_query", return_value=[{"versions": ["4.3.5"]}])
def test_genai_client_no_supported_version(driver):
with pytest.raises(ValueError):
GenAIClient(driver)


@patch("neo4j_genai.GenAIClient.database_query")
def test_create_index_happy_path(mock_database_query, client):
client.create_index("my-index", "People", "name", 1024, "cosine")
query = (
"CALL db.index.vector.createNodeIndex("
"$name,"
"$label,"
"$property,"
"toInteger($dimensions),"
"$similarity_fn )"
)
mock_database_query.assert_called_once_with(
query,
params={
"name": "my-index",
"label": "People",
"property": "name",
"dimensions": 1024,
"similarity_fn": "cosine",
},
)


def test_create_index_validation_error_dimensions(client):
with pytest.raises(ValueError) as excinfo:
client.create_index("my-index", "People", "name", "no-dim", "cosine")
assert "Error for inputs to create_index" in str(excinfo)


def test_create_index_validation_error_similarity_fn(client):
with pytest.raises(ValueError) as excinfo:
client.create_index("my-index", "People", "name", "no-dim", "algebra")
assert "Error for inputs to create_index" in str(excinfo)


@patch("neo4j_genai.GenAIClient.database_query")
def test_drop_index(mock_database_query, client):
client.drop_index("my-index")

query = "DROP INDEX $name"

mock_database_query.assert_called_with(query, params={"name": "my-index"})


def test_database_query_happy(client, driver):
class Session:
def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
pass

def run(self, query, params):
m_list = []
for i in range(3):
mock = Mock()
mock.data.return_value = i
m_list.append(mock)

return m_list

driver.session = Session
res = client.database_query("MATCH (p:$label) RETURN p", {"label": "People"})
assert res == [0, 1, 2]


def test_database_query_cypher_error(client, driver):
class Session:
def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
pass

def run(self, query, params):
raise CypherSyntaxError

driver.session = Session

with pytest.raises(ValueError):
client.database_query("MATCH (p:$label) RETURN p", {"label": "People"})


def test_similarity_search():
pass

0 comments on commit 4b6fb94

Please sign in to comment.