Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added VectorRetriever class and removed GenAIClient #6

Merged
merged 3 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions examples/openai_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from neo4j import GraphDatabase
from neo4j_genai import VectorRetriever

from random import random
from neo4j_genai.indexes import create_vector_index

from langchain_openai import OpenAIEmbeddings

# Connection settings handed to GraphDatabase.driver() below.
# NOTE(review): these look like local-development defaults — replace them
# with your own server address and credentials before running.
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")

# Name of the vector index that is created and then searched by this script.
INDEX_NAME = "embedding-name-large"
# Number of dimensions configured on the index; also the length of the random
# vector stored on the sample node.
# NOTE(review): presumably chosen to match the output size of the
# "text-embedding-3-large" model used below — confirm against the model docs.
DIMENSION = 3072

# Connect to Neo4j database
driver = GraphDatabase.driver(URI, auth=AUTH)


# Create Embedder object
# NOTE(review): OpenAIEmbeddings presumably needs an OpenAI API key to be
# available (e.g. via the environment) — confirm before running.
embedder = OpenAIEmbeddings(model="text-embedding-3-large")
willtai marked this conversation as resolved.
Show resolved Hide resolved

# Run the example end to end. The try/finally guarantees the driver's
# connections are released even if one of the steps below raises.
try:
    # Initialize the retriever; text queries are embedded with `embedder`
    retriever = VectorRetriever(driver, embedder)

    # Create the vector index on the `propertyKey` property of Document nodes
    create_vector_index(
        driver,
        INDEX_NAME,
        label="Document",
        property="propertyKey",
        dimensions=DIMENSION,
        similarity_fn="cosine",
    )

    # Upsert a Document node that carries a random placeholder embedding
    vector = [random() for _ in range(DIMENSION)]
    # Every fragment except the last ends with a space so that adjacent
    # Cypher clauses never run together once the literals are concatenated.
    insert_query = (
        "MERGE (n:Document) "
        "WITH n "
        "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector) "
        "RETURN n"
    )
    parameters = {
        "vector": vector,
    }
    driver.execute_query(insert_query, parameters)

    # Perform the similarity search for a text query
    query_text = "hello world"
    print(retriever.search(INDEX_NAME, query_text=query_text, top_k=5))
finally:
    # Close the driver opened at the top of the script
    driver.close()
willtai marked this conversation as resolved.
Show resolved Hide resolved
oskarhane marked this conversation as resolved.
Show resolved Hide resolved
12 changes: 7 additions & 5 deletions examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import List
from neo4j import GraphDatabase
from neo4j_genai.client import GenAIClient
from neo4j_genai import VectorRetriever

from random import random
from neo4j_genai.embedder import Embedder
from neo4j_genai.indexes import create_vector_index

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
Expand All @@ -23,11 +24,12 @@ def embed_query(self, text: str) -> List[float]:

embedder = CustomEmbedder()

# Initialize the client
client = GenAIClient(driver, embedder)
# Initialize the retriever
retriever = VectorRetriever(driver, embedder)

# Creating the index
client.create_index(
create_vector_index(
driver,
INDEX_NAME,
label="Document",
property="propertyKey",
Expand All @@ -50,4 +52,4 @@ def embed_query(self, text: str) -> List[float]:

# Perform the similarity search for a text query
query_text = "hello world"
print(client.similarity_search(INDEX_NAME, query_text=query_text, top_k=5))
print(retriever.search(INDEX_NAME, query_text=query_text, top_k=5))
13 changes: 8 additions & 5 deletions examples/similarity_search_for_vector.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from neo4j import GraphDatabase
from neo4j_genai.client import GenAIClient
from neo4j_genai import VectorRetriever

from random import random

from neo4j_genai.indexes import create_vector_index

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")

Expand All @@ -12,11 +14,12 @@
# Connect to Neo4j database
driver = GraphDatabase.driver(URI, auth=AUTH)

# Initialize the client
client = GenAIClient(driver)
# Initialize the retriever
retriever = VectorRetriever(driver)

# Creating the index
client.create_index(
create_vector_index(
driver,
INDEX_NAME,
label="Document",
property="propertyKey",
Expand All @@ -40,4 +43,4 @@

# Perform the similarity search for a vector query
query_vector = [random() for _ in range(DIMENSION)]
print(client.similarity_search(INDEX_NAME, query_vector=query_vector, top_k=5))
print(retriever.search(INDEX_NAME, query_vector=query_vector, top_k=5))
602 changes: 597 additions & 5 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ include = "neo4j_genai"
from = "src"

[tool.poetry.dependencies]
python = "^3.8"
python = "^3.8.1"
neo4j = "^5.17.0"
types-requests = "^2.31.0.20240218"
pydantic = "^2.6.3"
langchain-openai = "^0.1.1"

[tool.poetry.group.dev.dependencies]
pylint = "^3.1.0"
Expand Down
4 changes: 2 additions & 2 deletions src/neo4j_genai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .client import GenAIClient
from .retrievers import VectorRetriever


__all__ = ["GenAIClient"]
__all__ = ["VectorRetriever"]
70 changes: 4 additions & 66 deletions src/neo4j_genai/client.py → src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from pydantic import ValidationError
from neo4j import Driver
from .embedder import Embedder
from .types import VectorIndexModel, SimilaritySearchModel, Neo4jRecord
from .types import SimilaritySearchModel, Neo4jRecord


class GenAIClient:
class VectorRetriever:
stellasia marked this conversation as resolved.
Show resolved Hide resolved
"""
Provides functionality to use Neo4j's GenAI features
Provides retrieval methods using vector search over embeddings
"""

def __init__(
Expand Down Expand Up @@ -46,69 +46,7 @@ def _verify_version(self) -> None:
"This package only supports Neo4j version 5.18.1 or greater"
)

def create_index(
self,
name: str,
label: str,
property: str,
dimensions: int,
similarity_fn: str,
) -> None:
"""
This method constructs a Cypher query and executes it
to create a new vector index in Neo4j.

See Cypher manual on [Create node index](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_createNodeIndex)

Args:
name (str): The unique name of the index.
label (str): The node label to be indexed.
property (str): The property key of a node which contains embedding values.
dimensions (int): Vector embedding dimension
similarity_fn (str): case-insensitive values for the vector similarity function:
``euclidean`` or ``cosine``.

Raises:
ValueError: If validation of the input arguments fail.
"""
try:
VectorIndexModel(
**{
"name": name,
"label": label,
"property": property,
"dimensions": dimensions,
"similarity_fn": similarity_fn,
}
)
except ValidationError as e:
raise ValueError(f"Error for inputs to create_index {str(e)}")

query = (
f"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:{label}) ON n.{property} OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
)
self.driver.execute_query(
query,
{"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn},
)

def drop_index(self, name: str) -> None:
"""
This method constructs a Cypher query and executes it
to drop a vector index in Neo4j.
See Cypher manual on [Drop vector indexes](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-drop)

Args:
name (str): The name of the index to delete.
"""
query = "DROP INDEX $name"
parameters = {
"name": name,
}
self.driver.execute_query(query, parameters)

def similarity_search(
def search(
self,
name: str,
query_vector: Optional[List[float]] = None,
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from neo4j_genai import GenAIClient
from neo4j_genai import VectorRetriever
from neo4j import Driver
from unittest.mock import MagicMock, patch

Expand All @@ -10,6 +10,6 @@ def driver():


@pytest.fixture
@patch("neo4j_genai.GenAIClient._verify_version")
def client(_verify_version_mock, driver):
return GenAIClient(driver)
@patch("neo4j_genai.VectorRetriever._verify_version")
def retriever(_verify_version_mock, driver):
return VectorRetriever(driver)
4 changes: 2 additions & 2 deletions tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_create_vector_index_happy_path(driver):
)


def test_create_vector_index_ensure_escaping(driver, client):
def test_create_vector_index_ensure_escaping(driver):
create_query = (
"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_create_fulltext_index_empty_node_properties(driver):
assert "Error for inputs to create_fulltext_index" in str(excinfo)


def test_create_fulltext_index_ensure_escaping(driver, client):
def test_create_fulltext_index_ensure_escaping(driver):
label = "node-label"
text_node_properties = ["property-1", "property-2"]
create_query = (
Expand Down
Loading