Merge pull request #647 from andreped/pgvector-support
Added pgvector support
Showing 5 changed files with 293 additions and 0 deletions.
@@ -0,0 +1 @@

```python
from .pgvector import PG_VectorStore
```
@@ -0,0 +1,265 @@

```python
import ast
import json
import logging
import uuid

import pandas as pd
from langchain_core.documents import Document
from langchain_postgres.vectorstores import PGVector
from sqlalchemy import create_engine, text

from .. import ValidationError
from ..base import VannaBase
from ..types import TrainingPlan, TrainingPlanItem


class PG_VectorStore(VannaBase):
    def __init__(self, config=None):
        if not config or "connection_string" not in config:
            raise ValueError(
                "A valid 'config' dictionary with a 'connection_string' is required.")

        VannaBase.__init__(self, config=config)

        self.connection_string = config.get("connection_string")
        self.n_results = config.get("n_results", 10)

        if "embedding_function" in config:
            self.embedding_function = config.get("embedding_function")
        else:
            # Default embedder; note that PGVector expects a LangChain
            # Embeddings object, which a raw SentenceTransformer model is not.
            from sentence_transformers import SentenceTransformer
            self.embedding_function = SentenceTransformer("sentence-transformers/all-MiniLM-l6-v2")

        # One collection per training-data type, all backed by the same
        # langchain_pg_embedding table.
        self.sql_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="sql",
            connection=self.connection_string,
        )
        self.ddl_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="ddl",
            connection=self.connection_string,
        )
        self.documentation_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="documentation",
            connection=self.connection_string,
        )

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        _id = str(uuid.uuid4()) + "-sql"
        createdat = kwargs.get("createdat")
        doc = Document(
            page_content=question_sql_json,
            metadata={"id": _id, "createdat": createdat},
        )
        self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])

        return _id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-ddl"
        doc = Document(
            page_content=ddl,
            metadata={"id": _id},
        )
        self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-doc"
        doc = Document(
            page_content=documentation,
            metadata={"id": _id},
        )
        self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def get_collection(self, collection_name):
        match collection_name:
            case "sql":
                return self.sql_collection
            case "ddl":
                return self.ddl_collection
            case "documentation":
                return self.documentation_collection
            case _:
                raise ValueError("Specified collection does not exist.")

    # PGVector.similarity_search is synchronous, so these lookups are
    # plain methods.
    def get_similar_question_sql(self, question: str) -> list:
        documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
        return [ast.literal_eval(document.page_content) for document in documents]

    def get_related_ddl(self, question: str, **kwargs) -> list:
        documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
        return [document.page_content for document in documents]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
        return [document.page_content for document in documents]

    def train(
        self,
        question: str | None = None,
        sql: str | None = None,
        ddl: str | None = None,
        documentation: str | None = None,
        plan: TrainingPlan | None = None,
        createdat: str | None = None,
    ):
        if question and not sql:
            raise ValidationError("Please provide a SQL query.")

        if documentation:
            logging.info(f"Adding documentation: {documentation}")
            return self.add_documentation(documentation)

        if sql and question:
            return self.add_question_sql(question=question, sql=sql, createdat=createdat)

        if ddl:
            logging.info(f"Adding ddl: {ddl}")
            return self.add_ddl(ddl)

        if plan:
            for item in plan._plan:
                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
                    self.add_ddl(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
                    self.add_documentation(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
                    self.add_question_sql(question=item.item_name, sql=item.item_value)

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        # Establishing the connection
        engine = create_engine(self.connection_string)

        # Querying the 'langchain_pg_embedding' table
        query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
        df_embedding = pd.read_sql(query_embedding, engine)

        # List to accumulate the processed rows
        processed_rows = []

        # Process each row in the DataFrame
        for _, row in df_embedding.iterrows():
            custom_id = row["cmetadata"]["id"]
            document = row["document"]
            training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]

            if training_data_type == "sql":
                # Convert the document string to a dictionary
                try:
                    doc_dict = ast.literal_eval(document)
                    question = doc_dict.get("question")
                    content = doc_dict.get("sql")
                except (ValueError, SyntaxError):
                    logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
                    continue
            elif training_data_type in ["documentation", "ddl"]:
                question = None  # Default value for question
                content = document
            else:
                # If the suffix is not recognized, skip this row
                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
                continue

            # Append the processed data to the list
            processed_rows.append(
                {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
            )

        # Create a DataFrame from the list of processed rows
        df_processed = pd.DataFrame(processed_rows)

        return df_processed

    def remove_training_data(self, id: str, **kwargs) -> bool:
        # Create the database engine
        engine = create_engine(self.connection_string)

        # SQL DELETE statement
        delete_statement = text(
            """
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata ->> 'id' = :id
            """
        )

        # Connect to the database and execute the delete statement
        with engine.connect() as connection:
            # Start a transaction
            with connection.begin() as transaction:
                try:
                    result = connection.execute(delete_statement, {"id": id})
                    # Commit the transaction if the delete was successful
                    transaction.commit()
                    # Check if any row was deleted and return True or False accordingly
                    return result.rowcount > 0
                except Exception as e:
                    # Rollback the transaction in case of error
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()
                    return False

    def remove_collection(self, collection_name: str) -> bool:
        engine = create_engine(self.connection_string)

        # Determine the suffix to look for based on the collection name
        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
        suffix = suffix_map.get(collection_name)

        if not suffix:
            logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
            return False

        # SQL query to delete rows whose id carries the collection suffix;
        # the suffix is bound as a query parameter
        query = text(
            """
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata->>'id' LIKE :suffix
            """
        )

        # Execute the deletion within a transaction block
        with engine.connect() as connection:
            with connection.begin() as transaction:
                try:
                    result = connection.execute(query, {"suffix": f"%{suffix}"})
                    transaction.commit()  # Explicitly commit the transaction
                    if result.rowcount > 0:
                        logging.info(
                            f"Deleted {result.rowcount} rows from "
                            f"langchain_pg_embedding where collection is {collection_name}."
                        )
                        return True
                    else:
                        logging.info(f"No rows deleted for collection {collection_name}.")
                        return False
                except Exception as e:
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()  # Rollback in case of error
                    return False

    def generate_embedding(self, *args, **kwargs):
        pass

    def submit_prompt(self, *args, **kwargs):
        pass

    def system_message(self, message: str) -> dict:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        return {"role": "assistant", "content": message}
```
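
Not part of this diff, but for orientation: vanna vector stores are normally combined, mixin-style, with an LLM class. A minimal usage sketch under that assumption follows — the OpenAI_Chat import path, its config keys, and the HuggingFaceEmbeddings embedder are assumptions drawn from vanna's usual pattern, not anything this PR specifies:

```python
# Hypothetical wiring following vanna's usual mixin pattern; not part of this PR.
from langchain_huggingface import HuggingFaceEmbeddings  # assumed embedder choice
from vanna.openai import OpenAI_Chat  # assumed import path
from vanna.pgvector import PG_VectorStore


class MyVanna(PG_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        PG_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)


vn = MyVanna(config={
    # pgvector side: required connection string, optional result count,
    # and a LangChain Embeddings object for PGVector
    "connection_string": "postgresql+psycopg://user:pass@localhost:5432/vanna",
    "n_results": 10,
    "embedding_function": HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"),
    # LLM side: assumed OpenAI_Chat keys
    "api_key": "sk-...",
    "model": "gpt-4",
})

vn.train(ddl="CREATE TABLE customers (id INT PRIMARY KEY, name TEXT)")
print(vn.get_training_data())
```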
@@ -0,0 +1,24 @@

```python
import os

from dotenv import load_dotenv

from vanna.pgvector import PG_VectorStore

load_dotenv()


def get_vanna_connection_string():
    server = os.environ.get("PG_SERVER")
    driver = "psycopg"
    port = 5434
    database = os.environ.get("PG_DATABASE")
    username = os.environ.get("PG_USERNAME")
    password = os.environ.get("PG_PASSWORD")

    return f"postgresql+{driver}://{username}:{password}@{server}:{port}/{database}"


def test_pgvector():
    connection_string = get_vanna_connection_string()
    pgclient = PG_VectorStore(config={"connection_string": connection_string})
    assert pgclient is not None
```
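
For anyone running this test locally: `load_dotenv()` pulls the connection settings from the environment, so a `.env` file along the following lines would be needed. The variable names come from the code above; the values are pure placeholders, and note that port 5434 and the psycopg driver are hardcoded in the test:

```
PG_SERVER=localhost
PG_DATABASE=vanna
PG_USERNAME=postgres
PG_PASSWORD=postgres
```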
@zainhoda Just saw that you merged this PR. Did you mean to keep this test_pgvector.py? I only had it there for the PR draft; I'd think it belongs somewhere else, no?